blob: 405d507a117b3cb0aedc534889b6eb07d9b0227e [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package session
8
9import (
10 "errors"
11
12 "github.com/mongodb/mongo-go-driver/bson"
13 "github.com/mongodb/mongo-go-driver/bson/primitive"
14 "github.com/mongodb/mongo-go-driver/mongo/readconcern"
15 "github.com/mongodb/mongo-go-driver/mongo/readpref"
16 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
17 "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
18)
19
// Errors returned by the Client session and its transaction state machine.
var (
	// ErrSessionEnded is returned when a client session is used after a call to EndSession.
	ErrSessionEnded = errors.New("ended session was used")

	// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started.
	ErrNoTransactStarted = errors.New("no transaction started")

	// ErrTransactInProgress is returned if StartTransaction is called when a transaction is in progress.
	ErrTransactInProgress = errors.New("transaction already in progress")

	// ErrAbortAfterCommit is returned when abort is called after a commit.
	ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")

	// ErrAbortTwice is returned if abort is called after the transaction is already aborted.
	ErrAbortTwice = errors.New("cannot call abortTransaction twice")

	// ErrCommitAfterAbort is returned if commit is called after an abort.
	ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")

	// ErrUnackWCUnsupported is returned if an unacknowledged write concern is specified for a transaction.
	ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
)
40
// Type describes the type of a client session: Explicit or Implicit.
type Type uint8

// These constants are the valid types for a client session.
const (
	// Explicit is the zero value of Type.
	Explicit Type = 0
	// Implicit follows Explicit.
	Implicit Type = 1
)
49
// state indicates the state of a client session's transaction FSM.
type state uint8

// Client Session states.
const (
	// None means no transaction is active. ApplyCommand returns a Committed or
	// Aborted session to None.
	None state = iota
	// Starting is set by StartTransaction; no command has been applied yet.
	Starting
	// InProgress is entered from Starting by ApplyCommand.
	InProgress
	// Committed is set by CommitTransaction.
	Committed
	// Aborted is set by AbortTransaction.
	Aborted
)
61
// Client is a session for clients to run commands.
//
// It embeds the pooled *Server session obtained in NewClientSession and adds
// client-side state on top of it: cluster and operation times, causal
// consistency, and the transaction state machine advanced by StartTransaction,
// CommitTransaction, AbortTransaction and ApplyCommand.
type Client struct {
	*Server
	// ClientID is the clientID passed to NewClientSession.
	ClientID uuid.UUID
	// ClusterTime is the highest $clusterTime document seen so far; see
	// AdvanceClusterTime and MaxClusterTime.
	ClusterTime bson.Raw
	// Consistent enables causal consistency. NewClientSession defaults it to
	// true unless the merged ClientOptions.CausalConsistency overrides it.
	Consistent bool
	// OperationTime is the latest operation time seen so far; see
	// AdvanceOperationTime.
	OperationTime *primitive.Timestamp
	// SessionType is Explicit or Implicit.
	SessionType Type
	// Terminated is set by EndSession. Once set, AdvanceClusterTime,
	// AdvanceOperationTime and UpdateUseTime return ErrSessionEnded.
	Terminated bool
	// RetryingCommit is reset by StartTransaction and clearTransactionOpts.
	// NOTE(review): presumably set by callers re-running commitTransaction — confirm.
	RetryingCommit bool
	// Committing makes ApplyCommand leave the transaction state unchanged.
	Committing bool
	// Aborting is cleared by clearTransactionOpts; it is not read in this file.
	Aborting bool
	// RetryWrite is not referenced in this file.
	// NOTE(review): presumably marks the current write as retryable — confirm.
	RetryWrite bool

	// Options for the current transaction. StartTransaction sets them from its
	// TransactionOptions, falling back to the defaults below; they are reset
	// to nil by clearTransactionOpts.
	CurrentRc *readconcern.ReadConcern
	CurrentRp *readpref.ReadPref
	CurrentWc *writeconcern.WriteConcern

	// Default transaction options, taken from the merged ClientOptions
	// (DefaultReadConcern, DefaultReadPreference, DefaultWriteConcern) in
	// NewClientSession.
	transactionRc *readconcern.ReadConcern
	transactionRp *readpref.ReadPref
	transactionWc *writeconcern.WriteConcern

	// pool receives the embedded server session back in EndSession.
	pool *Pool
	// state is the current transaction FSM state.
	state state
}
90
91func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
92 if clusterTime == nil {
93 return 0, 0
94 }
95
96 clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
97 if err != nil {
98 return 0, 0
99 }
100
101 timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
102 if err != nil {
103 return 0, 0
104 }
105
106 return timestampVal.Timestamp()
107}
108
109// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time.
110func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
111 epoch1, ord1 := getClusterTime(ct1)
112 epoch2, ord2 := getClusterTime(ct2)
113
114 if epoch1 > epoch2 {
115 return ct1
116 } else if epoch1 < epoch2 {
117 return ct2
118 } else if ord1 > ord2 {
119 return ct1
120 } else if ord1 < ord2 {
121 return ct2
122 }
123
124 return ct1
125}
126
127// NewClientSession creates a Client.
128func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) {
129 c := &Client{
130 Consistent: true, // set default
131 ClientID: clientID,
132 SessionType: sessionType,
133 pool: pool,
134 }
135
136 mergedOpts := mergeClientOptions(opts...)
137 if mergedOpts.CausalConsistency != nil {
138 c.Consistent = *mergedOpts.CausalConsistency
139 }
140 if mergedOpts.DefaultReadPreference != nil {
141 c.transactionRp = mergedOpts.DefaultReadPreference
142 }
143 if mergedOpts.DefaultReadConcern != nil {
144 c.transactionRc = mergedOpts.DefaultReadConcern
145 }
146 if mergedOpts.DefaultWriteConcern != nil {
147 c.transactionWc = mergedOpts.DefaultWriteConcern
148 }
149
150 servSess, err := pool.GetSession()
151 if err != nil {
152 return nil, err
153 }
154
155 c.Server = servSess
156
157 return c, nil
158}
159
160// AdvanceClusterTime updates the session's cluster time.
161func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
162 if c.Terminated {
163 return ErrSessionEnded
164 }
165 c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
166 return nil
167}
168
169// AdvanceOperationTime updates the session's operation time.
170func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error {
171 if c.Terminated {
172 return ErrSessionEnded
173 }
174
175 if c.OperationTime == nil {
176 c.OperationTime = opTime
177 return nil
178 }
179
180 if opTime.T > c.OperationTime.T {
181 c.OperationTime = opTime
182 } else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
183 c.OperationTime = opTime
184 }
185
186 return nil
187}
188
189// UpdateUseTime updates the session's last used time.
190// Must be called whenver this session is used to send a command to the server.
191func (c *Client) UpdateUseTime() error {
192 if c.Terminated {
193 return ErrSessionEnded
194 }
195 c.updateUseTime()
196 return nil
197}
198
199// EndSession ends the session.
200func (c *Client) EndSession() {
201 if c.Terminated {
202 return
203 }
204
205 c.Terminated = true
206 c.pool.ReturnSession(c.Server)
207
208 return
209}
210
211// TransactionInProgress returns true if the client session is in an active transaction.
212func (c *Client) TransactionInProgress() bool {
213 return c.state == InProgress
214}
215
216// TransactionStarting returns true if the client session is starting a transaction.
217func (c *Client) TransactionStarting() bool {
218 return c.state == Starting
219}
220
221// TransactionRunning returns true if the client session has started the transaction
222// and it hasn't been committed or aborted
223func (c *Client) TransactionRunning() bool {
224 return c.state == Starting || c.state == InProgress
225}
226
227// TransactionCommitted returns true of the client session just committed a transaciton.
228func (c *Client) TransactionCommitted() bool {
229 return c.state == Committed
230}
231
232// CheckStartTransaction checks to see if allowed to start transaction and returns
233// an error if not allowed
234func (c *Client) CheckStartTransaction() error {
235 if c.state == InProgress || c.state == Starting {
236 return ErrTransactInProgress
237 }
238 return nil
239}
240
241// StartTransaction initializes the transaction options and advances the state machine.
242// It does not contact the server to start the transaction.
243func (c *Client) StartTransaction(opts *TransactionOptions) error {
244 err := c.CheckStartTransaction()
245 if err != nil {
246 return err
247 }
248
249 c.IncrementTxnNumber()
250 c.RetryingCommit = false
251
252 if opts != nil {
253 c.CurrentRc = opts.ReadConcern
254 c.CurrentRp = opts.ReadPreference
255 c.CurrentWc = opts.WriteConcern
256 }
257
258 if c.CurrentRc == nil {
259 c.CurrentRc = c.transactionRc
260 }
261
262 if c.CurrentRp == nil {
263 c.CurrentRp = c.transactionRp
264 }
265
266 if c.CurrentWc == nil {
267 c.CurrentWc = c.transactionWc
268 }
269
270 if !writeconcern.AckWrite(c.CurrentWc) {
271 c.clearTransactionOpts()
272 return ErrUnackWCUnsupported
273 }
274
275 c.state = Starting
276 return nil
277}
278
279// CheckCommitTransaction checks to see if allowed to commit transaction and returns
280// an error if not allowed.
281func (c *Client) CheckCommitTransaction() error {
282 if c.state == None {
283 return ErrNoTransactStarted
284 } else if c.state == Aborted {
285 return ErrCommitAfterAbort
286 }
287 return nil
288}
289
290// CommitTransaction updates the state for a successfully committed transaction and returns
291// an error if not permissible. It does not actually perform the commit.
292func (c *Client) CommitTransaction() error {
293 err := c.CheckCommitTransaction()
294 if err != nil {
295 return err
296 }
297 c.state = Committed
298 return nil
299}
300
301// CheckAbortTransaction checks to see if allowed to abort transaction and returns
302// an error if not allowed.
303func (c *Client) CheckAbortTransaction() error {
304 if c.state == None {
305 return ErrNoTransactStarted
306 } else if c.state == Committed {
307 return ErrAbortAfterCommit
308 } else if c.state == Aborted {
309 return ErrAbortTwice
310 }
311 return nil
312}
313
314// AbortTransaction updates the state for a successfully committed transaction and returns
315// an error if not permissible. It does not actually perform the abort.
316func (c *Client) AbortTransaction() error {
317 err := c.CheckAbortTransaction()
318 if err != nil {
319 return err
320 }
321 c.state = Aborted
322 c.clearTransactionOpts()
323 return nil
324}
325
326// ApplyCommand advances the state machine upon command execution.
327func (c *Client) ApplyCommand() {
328 if c.Committing {
329 // Do not change state if committing after already committed
330 return
331 }
332 if c.state == Starting {
333 c.state = InProgress
334 } else if c.state == Committed || c.state == Aborted {
335 c.clearTransactionOpts()
336 c.state = None
337 }
338}
339
340func (c *Client) clearTransactionOpts() {
341 c.RetryingCommit = false
342 c.Aborting = false
343 c.Committing = false
344 c.CurrentWc = nil
345 c.CurrentRp = nil
346 c.CurrentRc = nil
347}