blob: 381714d86bcb7cef8061d74834f33e1c845c5571 [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10 "context"
11 "errors"
12
13 "github.com/mongodb/mongo-go-driver/bson"
14 "github.com/mongodb/mongo-go-driver/bson/primitive"
15 "github.com/mongodb/mongo-go-driver/mongo/options"
16 "github.com/mongodb/mongo-go-driver/x/mongo/driver"
17 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
18 "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
19 "github.com/mongodb/mongo-go-driver/x/network/command"
20 "github.com/mongodb/mongo-go-driver/x/network/description"
21)
22
// ErrWrongClient is the sentinel error reported when a session is handed to a
// method belonging to a client other than the one that created the session.
var ErrWrongClient = errors.New("session was not created by this client")
26
27// SessionContext is a hybrid interface. It combines a context.Context with
28// a mongo.Session. This type can be used as a regular context.Context or
29// Session type. It is not goroutine safe and should not be used in multiple goroutines concurrently.
30type SessionContext interface {
31 context.Context
32 Session
33}
34
35type sessionContext struct {
36 context.Context
37 Session
38}
39
// sessionKey is the context key type under which a Session is stored in a
// context (see contextWithSession and sessionFromContext). It carries no data;
// being an unexported type, it cannot collide with keys from other packages.
type sessionKey struct{}
42
43// Session is the interface that represents a sequential set of operations executed.
44// Instances of this interface can be used to use transactions against the server
45// and to enable causally consistent behavior for applications.
46type Session interface {
47 EndSession(context.Context)
48 StartTransaction(...*options.TransactionOptions) error
49 AbortTransaction(context.Context) error
50 CommitTransaction(context.Context) error
51 ClusterTime() bson.Raw
52 AdvanceClusterTime(bson.Raw) error
53 OperationTime() *primitive.Timestamp
54 AdvanceOperationTime(*primitive.Timestamp) error
55 session()
56}
57
58// sessionImpl represents a set of sequential operations executed by an application that are related in some way.
59type sessionImpl struct {
60 *session.Client
61 topo *topology.Topology
62 didCommitAfterStart bool // true if commit was called after start with no other operations
63}
64
65// EndSession ends the session.
66func (s *sessionImpl) EndSession(ctx context.Context) {
67 if s.TransactionInProgress() {
68 // ignore all errors aborting during an end session
69 _ = s.AbortTransaction(ctx)
70 }
71 s.Client.EndSession()
72}
73
74// StartTransaction starts a transaction for this session.
75func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error {
76 err := s.CheckStartTransaction()
77 if err != nil {
78 return err
79 }
80
81 s.didCommitAfterStart = false
82
83 topts := options.MergeTransactionOptions(opts...)
84 coreOpts := &session.TransactionOptions{
85 ReadConcern: topts.ReadConcern,
86 ReadPreference: topts.ReadPreference,
87 WriteConcern: topts.WriteConcern,
88 }
89
90 return s.Client.StartTransaction(coreOpts)
91}
92
93// AbortTransaction aborts the session's transaction, returning any errors and error codes
94func (s *sessionImpl) AbortTransaction(ctx context.Context) error {
95 err := s.CheckAbortTransaction()
96 if err != nil {
97 return err
98 }
99
100 cmd := command.AbortTransaction{
101 Session: s.Client,
102 }
103
104 s.Aborting = true
105 _, err = driver.AbortTransaction(ctx, cmd, s.topo, description.WriteSelector())
106
107 _ = s.Client.AbortTransaction()
108 return err
109}
110
111// CommitTransaction commits the sesson's transaction.
112func (s *sessionImpl) CommitTransaction(ctx context.Context) error {
113 err := s.CheckCommitTransaction()
114 if err != nil {
115 return err
116 }
117
118 // Do not run the commit command if the transaction is in started state
119 if s.TransactionStarting() || s.didCommitAfterStart {
120 s.didCommitAfterStart = true
121 return s.Client.CommitTransaction()
122 }
123
124 if s.Client.TransactionCommitted() {
125 s.RetryingCommit = true
126 }
127
128 cmd := command.CommitTransaction{
129 Session: s.Client,
130 }
131
132 // Hack to ensure that session stays in committed state
133 if s.TransactionCommitted() {
134 s.Committing = true
135 defer func() {
136 s.Committing = false
137 }()
138 }
139 _, err = driver.CommitTransaction(ctx, cmd, s.topo, description.WriteSelector())
140 if err == nil {
141 return s.Client.CommitTransaction()
142 }
143 return err
144}
145
146func (s *sessionImpl) ClusterTime() bson.Raw {
147 return s.Client.ClusterTime
148}
149
150func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error {
151 return s.Client.AdvanceClusterTime(d)
152}
153
154func (s *sessionImpl) OperationTime() *primitive.Timestamp {
155 return s.Client.OperationTime
156}
157
158func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error {
159 return s.Client.AdvanceOperationTime(ts)
160}
161
162func (*sessionImpl) session() {
163}
164
165// sessionFromContext checks for a sessionImpl in the argued context and returns the session if it
166// exists
167func sessionFromContext(ctx context.Context) *session.Client {
168 s := ctx.Value(sessionKey{})
169 if ses, ok := s.(*sessionImpl); ses != nil && ok {
170 return ses.Client
171 }
172
173 return nil
174}
175
176func contextWithSession(ctx context.Context, sess Session) SessionContext {
177 return &sessionContext{
178 Context: context.WithValue(ctx, sessionKey{}, sess),
179 Session: sess,
180 }
181}