blob: dc9926ec6c2b530e85952cbe8ce4f9afabc73dc4 [file] [log] [blame]
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package ordering
16
17import (
18 "context"
19 "sync"
20
21 "github.com/coreos/etcd/clientv3"
22)
23
24// kvOrdering ensures that serialized requests do not return
25// get with revisions less than the previous
26// returned revision.
27type kvOrdering struct {
28 clientv3.KV
29 orderViolationFunc OrderViolationFunc
30 prevRev int64
31 revMu sync.RWMutex
32}
33
34func NewKV(kv clientv3.KV, orderViolationFunc OrderViolationFunc) *kvOrdering {
35 return &kvOrdering{kv, orderViolationFunc, 0, sync.RWMutex{}}
36}
37
38func (kv *kvOrdering) getPrevRev() int64 {
39 kv.revMu.RLock()
40 defer kv.revMu.RUnlock()
41 return kv.prevRev
42}
43
44func (kv *kvOrdering) setPrevRev(currRev int64) {
45 kv.revMu.Lock()
46 defer kv.revMu.Unlock()
47 if currRev > kv.prevRev {
48 kv.prevRev = currRev
49 }
50}
51
52func (kv *kvOrdering) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
53 // prevRev is stored in a local variable in order to record the prevRev
54 // at the beginning of the Get operation, because concurrent
55 // access to kvOrdering could change the prevRev field in the
56 // middle of the Get operation.
57 prevRev := kv.getPrevRev()
58 op := clientv3.OpGet(key, opts...)
59 for {
60 r, err := kv.KV.Do(ctx, op)
61 if err != nil {
62 return nil, err
63 }
64 resp := r.Get()
65 if resp.Header.Revision == prevRev {
66 return resp, nil
67 } else if resp.Header.Revision > prevRev {
68 kv.setPrevRev(resp.Header.Revision)
69 return resp, nil
70 }
71 err = kv.orderViolationFunc(op, r, prevRev)
72 if err != nil {
73 return nil, err
74 }
75 }
76}
77
78func (kv *kvOrdering) Txn(ctx context.Context) clientv3.Txn {
79 return &txnOrdering{
80 kv.KV.Txn(ctx),
81 kv,
82 ctx,
83 sync.Mutex{},
84 []clientv3.Cmp{},
85 []clientv3.Op{},
86 []clientv3.Op{},
87 }
88}
89
90// txnOrdering ensures that serialized requests do not return
91// txn responses with revisions less than the previous
92// returned revision.
93type txnOrdering struct {
94 clientv3.Txn
95 *kvOrdering
96 ctx context.Context
97 mu sync.Mutex
98 cmps []clientv3.Cmp
99 thenOps []clientv3.Op
100 elseOps []clientv3.Op
101}
102
103func (txn *txnOrdering) If(cs ...clientv3.Cmp) clientv3.Txn {
104 txn.mu.Lock()
105 defer txn.mu.Unlock()
106 txn.cmps = cs
107 txn.Txn.If(cs...)
108 return txn
109}
110
111func (txn *txnOrdering) Then(ops ...clientv3.Op) clientv3.Txn {
112 txn.mu.Lock()
113 defer txn.mu.Unlock()
114 txn.thenOps = ops
115 txn.Txn.Then(ops...)
116 return txn
117}
118
119func (txn *txnOrdering) Else(ops ...clientv3.Op) clientv3.Txn {
120 txn.mu.Lock()
121 defer txn.mu.Unlock()
122 txn.elseOps = ops
123 txn.Txn.Else(ops...)
124 return txn
125}
126
127func (txn *txnOrdering) Commit() (*clientv3.TxnResponse, error) {
128 // prevRev is stored in a local variable in order to record the prevRev
129 // at the beginning of the Commit operation, because concurrent
130 // access to txnOrdering could change the prevRev field in the
131 // middle of the Commit operation.
132 prevRev := txn.getPrevRev()
133 opTxn := clientv3.OpTxn(txn.cmps, txn.thenOps, txn.elseOps)
134 for {
135 opResp, err := txn.KV.Do(txn.ctx, opTxn)
136 if err != nil {
137 return nil, err
138 }
139 txnResp := opResp.Txn()
140 if txnResp.Header.Revision >= prevRev {
141 txn.setPrevRev(txnResp.Header.Revision)
142 return txnResp, nil
143 }
144 err = txn.orderViolationFunc(opTxn, opResp, prevRev)
145 if err != nil {
146 return nil, err
147 }
148 }
149}