| // Copyright 2016 The etcd Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package concurrency |
| |
| import ( |
| "context" |
| "math" |
| |
| v3 "go.etcd.io/etcd/clientv3" |
| ) |
| |
| // STM is an interface for software transactional memory. |
| type STM interface { |
| // Get returns the value for a key and inserts the key in the txn's read set. |
| // If Get fails, it aborts the transaction with an error, never returning. |
| Get(key ...string) string |
| // Put adds a value for a key to the write set. |
| Put(key, val string, opts ...v3.OpOption) |
| // Rev returns the revision of a key in the read set. |
| Rev(key string) int64 |
| // Del deletes a key. |
| Del(key string) |
| |
| // commit attempts to apply the txn's changes to the server. |
| commit() *v3.TxnResponse |
| reset() |
| } |
| |
// Isolation is an enumeration of transactional isolation levels which
// describes how transactions should interfere and conflict.
type Isolation int

// The isolation levels are numbered from zero in declaration order, so the
// zero value of Isolation is SerializableSnapshot.
const (
	// SerializableSnapshot provides serializable isolation and also checks
	// for write conflicts.
	SerializableSnapshot = Isolation(iota)
	// Serializable makes reads within the same transaction attempt return
	// data as of the revision of the first read.
	Serializable
	// RepeatableReads makes reads within the same transaction attempt always
	// return the same data.
	RepeatableReads
	// ReadCommitted lets reads return data from any committed revision.
	ReadCommitted
)
| |
// stmError wraps an error raised inside an STM attempt so it can be thrown
// with panic and told apart from unrelated panics when it is recovered.
type stmError struct {
	// err is the underlying failure handed back to the STM caller.
	err error
}
| |
| type stmOptions struct { |
| iso Isolation |
| ctx context.Context |
| prefetch []string |
| } |
| |
| type stmOption func(*stmOptions) |
| |
| // WithIsolation specifies the transaction isolation level. |
| func WithIsolation(lvl Isolation) stmOption { |
| return func(so *stmOptions) { so.iso = lvl } |
| } |
| |
| // WithAbortContext specifies the context for permanently aborting the transaction. |
| func WithAbortContext(ctx context.Context) stmOption { |
| return func(so *stmOptions) { so.ctx = ctx } |
| } |
| |
| // WithPrefetch is a hint to prefetch a list of keys before trying to apply. |
| // If an STM transaction will unconditionally fetch a set of keys, prefetching |
| // those keys will save the round-trip cost from requesting each key one by one |
| // with Get(). |
| func WithPrefetch(keys ...string) stmOption { |
| return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) } |
| } |
| |
| // NewSTM initiates a new STM instance, using serializable snapshot isolation by default. |
| func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) { |
| opts := &stmOptions{ctx: c.Ctx()} |
| for _, f := range so { |
| f(opts) |
| } |
| if len(opts.prefetch) != 0 { |
| f := apply |
| apply = func(s STM) error { |
| s.Get(opts.prefetch...) |
| return f(s) |
| } |
| } |
| return runSTM(mkSTM(c, opts), apply) |
| } |
| |
| func mkSTM(c *v3.Client, opts *stmOptions) STM { |
| switch opts.iso { |
| case SerializableSnapshot: |
| s := &stmSerializable{ |
| stm: stm{client: c, ctx: opts.ctx}, |
| prefetch: make(map[string]*v3.GetResponse), |
| } |
| s.conflicts = func() []v3.Cmp { |
| return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...) |
| } |
| return s |
| case Serializable: |
| s := &stmSerializable{ |
| stm: stm{client: c, ctx: opts.ctx}, |
| prefetch: make(map[string]*v3.GetResponse), |
| } |
| s.conflicts = func() []v3.Cmp { return s.rset.cmps() } |
| return s |
| case RepeatableReads: |
| s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} |
| s.conflicts = func() []v3.Cmp { return s.rset.cmps() } |
| return s |
| case ReadCommitted: |
| s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} |
| s.conflicts = func() []v3.Cmp { return nil } |
| return s |
| default: |
| panic("unsupported stm") |
| } |
| } |
| |
| type stmResponse struct { |
| resp *v3.TxnResponse |
| err error |
| } |
| |
| func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) { |
| outc := make(chan stmResponse, 1) |
| go func() { |
| defer func() { |
| if r := recover(); r != nil { |
| e, ok := r.(stmError) |
| if !ok { |
| // client apply panicked |
| panic(r) |
| } |
| outc <- stmResponse{nil, e.err} |
| } |
| }() |
| var out stmResponse |
| for { |
| s.reset() |
| if out.err = apply(s); out.err != nil { |
| break |
| } |
| if out.resp = s.commit(); out.resp != nil { |
| break |
| } |
| } |
| outc <- out |
| }() |
| r := <-outc |
| return r.resp, r.err |
| } |
| |
| // stm implements repeatable-read software transactional memory over etcd |
| type stm struct { |
| client *v3.Client |
| ctx context.Context |
| // rset holds read key values and revisions |
| rset readSet |
| // wset holds overwritten keys and their values |
| wset writeSet |
| // getOpts are the opts used for gets |
| getOpts []v3.OpOption |
| // conflicts computes the current conflicts on the txn |
| conflicts func() []v3.Cmp |
| } |
| |
| type stmPut struct { |
| val string |
| op v3.Op |
| } |
| |
| type readSet map[string]*v3.GetResponse |
| |
| func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { |
| for i, resp := range txnresp.Responses { |
| rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) |
| } |
| } |
| |
| // first returns the store revision from the first fetch |
| func (rs readSet) first() int64 { |
| ret := int64(math.MaxInt64 - 1) |
| for _, resp := range rs { |
| if rev := resp.Header.Revision; rev < ret { |
| ret = rev |
| } |
| } |
| return ret |
| } |
| |
| // cmps guards the txn from updates to read set |
| func (rs readSet) cmps() []v3.Cmp { |
| cmps := make([]v3.Cmp, 0, len(rs)) |
| for k, rk := range rs { |
| cmps = append(cmps, isKeyCurrent(k, rk)) |
| } |
| return cmps |
| } |
| |
| type writeSet map[string]stmPut |
| |
| func (ws writeSet) get(keys ...string) *stmPut { |
| for _, key := range keys { |
| if wv, ok := ws[key]; ok { |
| return &wv |
| } |
| } |
| return nil |
| } |
| |
| // cmps returns a cmp list testing no writes have happened past rev |
| func (ws writeSet) cmps(rev int64) []v3.Cmp { |
| cmps := make([]v3.Cmp, 0, len(ws)) |
| for key := range ws { |
| cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev)) |
| } |
| return cmps |
| } |
| |
| // puts is the list of ops for all pending writes |
| func (ws writeSet) puts() []v3.Op { |
| puts := make([]v3.Op, 0, len(ws)) |
| for _, v := range ws { |
| puts = append(puts, v.op) |
| } |
| return puts |
| } |
| |
| func (s *stm) Get(keys ...string) string { |
| if wv := s.wset.get(keys...); wv != nil { |
| return wv.val |
| } |
| return respToValue(s.fetch(keys...)) |
| } |
| |
| func (s *stm) Put(key, val string, opts ...v3.OpOption) { |
| s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)} |
| } |
| |
| func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} } |
| |
| func (s *stm) Rev(key string) int64 { |
| if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 { |
| return resp.Kvs[0].ModRevision |
| } |
| return 0 |
| } |
| |
| func (s *stm) commit() *v3.TxnResponse { |
| txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit() |
| if err != nil { |
| panic(stmError{err}) |
| } |
| if txnresp.Succeeded { |
| return txnresp |
| } |
| return nil |
| } |
| |
| func (s *stm) fetch(keys ...string) *v3.GetResponse { |
| if len(keys) == 0 { |
| return nil |
| } |
| ops := make([]v3.Op, len(keys)) |
| for i, key := range keys { |
| if resp, ok := s.rset[key]; ok { |
| return resp |
| } |
| ops[i] = v3.OpGet(key, s.getOpts...) |
| } |
| txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit() |
| if err != nil { |
| panic(stmError{err}) |
| } |
| s.rset.add(keys, txnresp) |
| return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange()) |
| } |
| |
| func (s *stm) reset() { |
| s.rset = make(map[string]*v3.GetResponse) |
| s.wset = make(map[string]stmPut) |
| } |
| |
| type stmSerializable struct { |
| stm |
| prefetch map[string]*v3.GetResponse |
| } |
| |
| func (s *stmSerializable) Get(keys ...string) string { |
| if wv := s.wset.get(keys...); wv != nil { |
| return wv.val |
| } |
| firstRead := len(s.rset) == 0 |
| for _, key := range keys { |
| if resp, ok := s.prefetch[key]; ok { |
| delete(s.prefetch, key) |
| s.rset[key] = resp |
| } |
| } |
| resp := s.stm.fetch(keys...) |
| if firstRead { |
| // txn's base revision is defined by the first read |
| s.getOpts = []v3.OpOption{ |
| v3.WithRev(resp.Header.Revision), |
| v3.WithSerializable(), |
| } |
| } |
| return respToValue(resp) |
| } |
| |
| func (s *stmSerializable) Rev(key string) int64 { |
| s.Get(key) |
| return s.stm.Rev(key) |
| } |
| |
| func (s *stmSerializable) gets() ([]string, []v3.Op) { |
| keys := make([]string, 0, len(s.rset)) |
| ops := make([]v3.Op, 0, len(s.rset)) |
| for k := range s.rset { |
| keys = append(keys, k) |
| ops = append(ops, v3.OpGet(k)) |
| } |
| return keys, ops |
| } |
| |
| func (s *stmSerializable) commit() *v3.TxnResponse { |
| keys, getops := s.gets() |
| txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...) |
| // use Else to prefetch keys in case of conflict to save a round trip |
| txnresp, err := txn.Else(getops...).Commit() |
| if err != nil { |
| panic(stmError{err}) |
| } |
| if txnresp.Succeeded { |
| return txnresp |
| } |
| // load prefetch with Else data |
| s.rset.add(keys, txnresp) |
| s.prefetch = s.rset |
| s.getOpts = nil |
| return nil |
| } |
| |
| func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { |
| if len(r.Kvs) != 0 { |
| return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision) |
| } |
| return v3.Compare(v3.ModRevision(k), "=", 0) |
| } |
| |
| func respToValue(resp *v3.GetResponse) string { |
| if resp == nil || len(resp.Kvs) == 0 { |
| return "" |
| } |
| return string(resp.Kvs[0].Value) |
| } |
| |
| // NewSTMRepeatable is deprecated. |
| func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { |
| return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads)) |
| } |
| |
| // NewSTMSerializable is deprecated. |
| func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { |
| return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable)) |
| } |
| |
| // NewSTMReadCommitted is deprecated. |
| func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { |
| return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted)) |
| } |