// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package concurrency
16
17import (
18 "context"
19 "math"
20
21 v3 "go.etcd.io/etcd/clientv3"
22)
23
// STM is an interface for software transactional memory.
//
// The exported methods are what the user-supplied apply function works with;
// the unexported methods are driven by runSTM around each attempt.
type STM interface {
	// Get returns the value for a key and inserts the key in the txn's read set.
	// If Get fails, it aborts the transaction with an error, never returning.
	Get(key ...string) string
	// Put adds a value for a key to the write set.
	Put(key, val string, opts ...v3.OpOption)
	// Rev returns the revision of a key in the read set.
	Rev(key string) int64
	// Del deletes a key.
	Del(key string)

	// commit attempts to apply the txn's changes to the server.
	// A nil response means the txn did not succeed; runSTM then retries.
	commit() *v3.TxnResponse
	// reset clears the transaction state; runSTM calls it before every
	// attempt of the apply function.
	reset()
}
40
// Isolation is an enumeration of transactional isolation levels which
// describes how transactions should interfere and conflict.
type Isolation int

const (
	// SerializableSnapshot provides serializable isolation and also checks
	// for write conflicts.
	SerializableSnapshot Isolation = iota
	// Serializable reads within the same transaction attempt return data
	// from the revision of the first read.
	Serializable
	// RepeatableReads reads within the same transaction attempt always
	// return the same data.
	RepeatableReads
	// ReadCommitted reads keys from any committed revision.
	ReadCommitted
)
58
// stmError safely passes STM errors through panic to the STM error channel.
// fetch and commit panic with an stmError when a transaction request fails;
// runSTM recovers it and returns the wrapped error to the caller.
type stmError struct{ err error }
61
// stmOptions collects the settings applied by stmOption functions in NewSTM.
type stmOptions struct {
	// iso is the isolation level; mkSTM uses it to pick the STM implementation.
	iso Isolation
	// ctx is the context used for the transaction's requests; NewSTM
	// initializes it to the client's context.
	ctx context.Context
	// prefetch lists keys that NewSTM reads with Get before calling apply.
	prefetch []string
}
67
// stmOption is a functional option that mutates the stmOptions used by NewSTM.
type stmOption func(*stmOptions)
69
70// WithIsolation specifies the transaction isolation level.
71func WithIsolation(lvl Isolation) stmOption {
72 return func(so *stmOptions) { so.iso = lvl }
73}
74
75// WithAbortContext specifies the context for permanently aborting the transaction.
76func WithAbortContext(ctx context.Context) stmOption {
77 return func(so *stmOptions) { so.ctx = ctx }
78}
79
80// WithPrefetch is a hint to prefetch a list of keys before trying to apply.
81// If an STM transaction will unconditionally fetch a set of keys, prefetching
82// those keys will save the round-trip cost from requesting each key one by one
83// with Get().
84func WithPrefetch(keys ...string) stmOption {
85 return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
86}
87
88// NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
89func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
90 opts := &stmOptions{ctx: c.Ctx()}
91 for _, f := range so {
92 f(opts)
93 }
94 if len(opts.prefetch) != 0 {
95 f := apply
96 apply = func(s STM) error {
97 s.Get(opts.prefetch...)
98 return f(s)
99 }
100 }
101 return runSTM(mkSTM(c, opts), apply)
102}
103
104func mkSTM(c *v3.Client, opts *stmOptions) STM {
105 switch opts.iso {
106 case SerializableSnapshot:
107 s := &stmSerializable{
108 stm: stm{client: c, ctx: opts.ctx},
109 prefetch: make(map[string]*v3.GetResponse),
110 }
111 s.conflicts = func() []v3.Cmp {
112 return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
113 }
114 return s
115 case Serializable:
116 s := &stmSerializable{
117 stm: stm{client: c, ctx: opts.ctx},
118 prefetch: make(map[string]*v3.GetResponse),
119 }
120 s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
121 return s
122 case RepeatableReads:
123 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
124 s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
125 return s
126 case ReadCommitted:
127 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
128 s.conflicts = func() []v3.Cmp { return nil }
129 return s
130 default:
131 panic("unsupported stm")
132 }
133}
134
// stmResponse carries the outcome of a transaction from the goroutine started
// by runSTM back to its caller: the response of a successful commit, or the
// error that ended the transaction.
type stmResponse struct {
	resp *v3.TxnResponse
	err  error
}
139
140func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
141 outc := make(chan stmResponse, 1)
142 go func() {
143 defer func() {
144 if r := recover(); r != nil {
145 e, ok := r.(stmError)
146 if !ok {
147 // client apply panicked
148 panic(r)
149 }
150 outc <- stmResponse{nil, e.err}
151 }
152 }()
153 var out stmResponse
154 for {
155 s.reset()
156 if out.err = apply(s); out.err != nil {
157 break
158 }
159 if out.resp = s.commit(); out.resp != nil {
160 break
161 }
162 }
163 outc <- out
164 }()
165 r := <-outc
166 return r.resp, r.err
167}
168
// stm implements repeatable-read software transactional memory over etcd
type stm struct {
	// client issues the transaction requests
	client *v3.Client
	// ctx is passed to every transaction request made by fetch and commit
	ctx context.Context
	// rset holds read key values and revisions
	rset readSet
	// wset holds overwritten keys and their values
	wset writeSet
	// getOpts are the opts used for gets
	getOpts []v3.OpOption
	// conflicts computes the current conflicts on the txn
	conflicts func() []v3.Cmp
}
182
// stmPut is a pending write recorded in the write set.
type stmPut struct {
	// val is the value Get reports for the key while the write is pending
	// (the empty string for a delete).
	val string
	// op is the put or delete operation sent to the server on commit.
	op v3.Op
}
187
// readSet maps each key read by the transaction to the response that was
// fetched for it.
type readSet map[string]*v3.GetResponse
189
190func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
191 for i, resp := range txnresp.Responses {
192 rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
193 }
194}
195
196// first returns the store revision from the first fetch
197func (rs readSet) first() int64 {
198 ret := int64(math.MaxInt64 - 1)
199 for _, resp := range rs {
200 if rev := resp.Header.Revision; rev < ret {
201 ret = rev
202 }
203 }
204 return ret
205}
206
207// cmps guards the txn from updates to read set
208func (rs readSet) cmps() []v3.Cmp {
209 cmps := make([]v3.Cmp, 0, len(rs))
210 for k, rk := range rs {
211 cmps = append(cmps, isKeyCurrent(k, rk))
212 }
213 return cmps
214}
215
// writeSet maps each key written or deleted by the transaction to its
// pending operation.
type writeSet map[string]stmPut
217
218func (ws writeSet) get(keys ...string) *stmPut {
219 for _, key := range keys {
220 if wv, ok := ws[key]; ok {
221 return &wv
222 }
223 }
224 return nil
225}
226
227// cmps returns a cmp list testing no writes have happened past rev
228func (ws writeSet) cmps(rev int64) []v3.Cmp {
229 cmps := make([]v3.Cmp, 0, len(ws))
230 for key := range ws {
231 cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
232 }
233 return cmps
234}
235
236// puts is the list of ops for all pending writes
237func (ws writeSet) puts() []v3.Op {
238 puts := make([]v3.Op, 0, len(ws))
239 for _, v := range ws {
240 puts = append(puts, v.op)
241 }
242 return puts
243}
244
245func (s *stm) Get(keys ...string) string {
246 if wv := s.wset.get(keys...); wv != nil {
247 return wv.val
248 }
249 return respToValue(s.fetch(keys...))
250}
251
252func (s *stm) Put(key, val string, opts ...v3.OpOption) {
253 s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
254}
255
256func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
257
258func (s *stm) Rev(key string) int64 {
259 if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
260 return resp.Kvs[0].ModRevision
261 }
262 return 0
263}
264
265func (s *stm) commit() *v3.TxnResponse {
266 txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
267 if err != nil {
268 panic(stmError{err})
269 }
270 if txnresp.Succeeded {
271 return txnresp
272 }
273 return nil
274}
275
276func (s *stm) fetch(keys ...string) *v3.GetResponse {
277 if len(keys) == 0 {
278 return nil
279 }
280 ops := make([]v3.Op, len(keys))
281 for i, key := range keys {
282 if resp, ok := s.rset[key]; ok {
283 return resp
284 }
285 ops[i] = v3.OpGet(key, s.getOpts...)
286 }
287 txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
288 if err != nil {
289 panic(stmError{err})
290 }
291 s.rset.add(keys, txnresp)
292 return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
293}
294
295func (s *stm) reset() {
296 s.rset = make(map[string]*v3.GetResponse)
297 s.wset = make(map[string]stmPut)
298}
299
// stmSerializable extends stm by pinning reads after the first one to the
// revision of that first read.
type stmSerializable struct {
	stm
	// prefetch holds responses loaded by a failed commit; Get moves entries
	// from it into the read set instead of fetching those keys again.
	prefetch map[string]*v3.GetResponse
}
304
305func (s *stmSerializable) Get(keys ...string) string {
306 if wv := s.wset.get(keys...); wv != nil {
307 return wv.val
308 }
309 firstRead := len(s.rset) == 0
310 for _, key := range keys {
311 if resp, ok := s.prefetch[key]; ok {
312 delete(s.prefetch, key)
313 s.rset[key] = resp
314 }
315 }
316 resp := s.stm.fetch(keys...)
317 if firstRead {
318 // txn's base revision is defined by the first read
319 s.getOpts = []v3.OpOption{
320 v3.WithRev(resp.Header.Revision),
321 v3.WithSerializable(),
322 }
323 }
324 return respToValue(resp)
325}
326
327func (s *stmSerializable) Rev(key string) int64 {
328 s.Get(key)
329 return s.stm.Rev(key)
330}
331
332func (s *stmSerializable) gets() ([]string, []v3.Op) {
333 keys := make([]string, 0, len(s.rset))
334 ops := make([]v3.Op, 0, len(s.rset))
335 for k := range s.rset {
336 keys = append(keys, k)
337 ops = append(ops, v3.OpGet(k))
338 }
339 return keys, ops
340}
341
342func (s *stmSerializable) commit() *v3.TxnResponse {
343 keys, getops := s.gets()
344 txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
345 // use Else to prefetch keys in case of conflict to save a round trip
346 txnresp, err := txn.Else(getops...).Commit()
347 if err != nil {
348 panic(stmError{err})
349 }
350 if txnresp.Succeeded {
351 return txnresp
352 }
353 // load prefetch with Else data
354 s.rset.add(keys, txnresp)
355 s.prefetch = s.rset
356 s.getOpts = nil
357 return nil
358}
359
360func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
361 if len(r.Kvs) != 0 {
362 return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
363 }
364 return v3.Compare(v3.ModRevision(k), "=", 0)
365}
366
367func respToValue(resp *v3.GetResponse) string {
368 if resp == nil || len(resp.Kvs) == 0 {
369 return ""
370 }
371 return string(resp.Kvs[0].Value)
372}
373
374// NewSTMRepeatable is deprecated.
375func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
376 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
377}
378
379// NewSTMSerializable is deprecated.
380func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
381 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
382}
383
384// NewSTMReadCommitted is deprecated.
385func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
386 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
387}