| // Copyright (C) MongoDB, Inc. 2017-present. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| // not use this file except in compliance with the License. You may obtain |
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| |
| package mongo |
| |
| import ( |
| "context" |
| "errors" |
| |
| "github.com/mongodb/mongo-go-driver/bson" |
| "github.com/mongodb/mongo-go-driver/bson/primitive" |
| "github.com/mongodb/mongo-go-driver/mongo/options" |
| "github.com/mongodb/mongo-go-driver/x/mongo/driver" |
| "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" |
| "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" |
| "github.com/mongodb/mongo-go-driver/x/network/command" |
| "github.com/mongodb/mongo-go-driver/x/network/description" |
| ) |
| |
var (
	// ErrWrongClient is the error returned when a method is given a session that
	// was created by a different client than the one the method is called on.
	ErrWrongClient = errors.New("session was not created by this client")
)
| |
| // SessionContext is a hybrid interface. It combines a context.Context with |
| // a mongo.Session. This type can be used as a regular context.Context or |
| // Session type. It is not goroutine safe and should not be used in multiple goroutines concurrently. |
| type SessionContext interface { |
| context.Context |
| Session |
| } |
| |
| type sessionContext struct { |
| context.Context |
| Session |
| } |
| |
// sessionKey is the context key under which contextWithSession stores a
// Session and from which sessionFromContext reads it back. It is an
// unexported type, so keys defined in other packages cannot collide with it.
type sessionKey struct{}
| |
| // Session is the interface that represents a sequential set of operations executed. |
| // Instances of this interface can be used to use transactions against the server |
| // and to enable causally consistent behavior for applications. |
| type Session interface { |
| EndSession(context.Context) |
| StartTransaction(...*options.TransactionOptions) error |
| AbortTransaction(context.Context) error |
| CommitTransaction(context.Context) error |
| ClusterTime() bson.Raw |
| AdvanceClusterTime(bson.Raw) error |
| OperationTime() *primitive.Timestamp |
| AdvanceOperationTime(*primitive.Timestamp) error |
| session() |
| } |
| |
| // sessionImpl represents a set of sequential operations executed by an application that are related in some way. |
| type sessionImpl struct { |
| *session.Client |
| topo *topology.Topology |
| didCommitAfterStart bool // true if commit was called after start with no other operations |
| } |
| |
| // EndSession ends the session. |
| func (s *sessionImpl) EndSession(ctx context.Context) { |
| if s.TransactionInProgress() { |
| // ignore all errors aborting during an end session |
| _ = s.AbortTransaction(ctx) |
| } |
| s.Client.EndSession() |
| } |
| |
| // StartTransaction starts a transaction for this session. |
| func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error { |
| err := s.CheckStartTransaction() |
| if err != nil { |
| return err |
| } |
| |
| s.didCommitAfterStart = false |
| |
| topts := options.MergeTransactionOptions(opts...) |
| coreOpts := &session.TransactionOptions{ |
| ReadConcern: topts.ReadConcern, |
| ReadPreference: topts.ReadPreference, |
| WriteConcern: topts.WriteConcern, |
| } |
| |
| return s.Client.StartTransaction(coreOpts) |
| } |
| |
| // AbortTransaction aborts the session's transaction, returning any errors and error codes |
| func (s *sessionImpl) AbortTransaction(ctx context.Context) error { |
| err := s.CheckAbortTransaction() |
| if err != nil { |
| return err |
| } |
| |
| cmd := command.AbortTransaction{ |
| Session: s.Client, |
| } |
| |
| s.Aborting = true |
| _, err = driver.AbortTransaction(ctx, cmd, s.topo, description.WriteSelector()) |
| |
| _ = s.Client.AbortTransaction() |
| return err |
| } |
| |
| // CommitTransaction commits the sesson's transaction. |
| func (s *sessionImpl) CommitTransaction(ctx context.Context) error { |
| err := s.CheckCommitTransaction() |
| if err != nil { |
| return err |
| } |
| |
| // Do not run the commit command if the transaction is in started state |
| if s.TransactionStarting() || s.didCommitAfterStart { |
| s.didCommitAfterStart = true |
| return s.Client.CommitTransaction() |
| } |
| |
| if s.Client.TransactionCommitted() { |
| s.RetryingCommit = true |
| } |
| |
| cmd := command.CommitTransaction{ |
| Session: s.Client, |
| } |
| |
| // Hack to ensure that session stays in committed state |
| if s.TransactionCommitted() { |
| s.Committing = true |
| defer func() { |
| s.Committing = false |
| }() |
| } |
| _, err = driver.CommitTransaction(ctx, cmd, s.topo, description.WriteSelector()) |
| if err == nil { |
| return s.Client.CommitTransaction() |
| } |
| return err |
| } |
| |
| func (s *sessionImpl) ClusterTime() bson.Raw { |
| return s.Client.ClusterTime |
| } |
| |
| func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error { |
| return s.Client.AdvanceClusterTime(d) |
| } |
| |
| func (s *sessionImpl) OperationTime() *primitive.Timestamp { |
| return s.Client.OperationTime |
| } |
| |
| func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error { |
| return s.Client.AdvanceOperationTime(ts) |
| } |
| |
| func (*sessionImpl) session() { |
| } |
| |
| // sessionFromContext checks for a sessionImpl in the argued context and returns the session if it |
| // exists |
| func sessionFromContext(ctx context.Context) *session.Client { |
| s := ctx.Value(sessionKey{}) |
| if ses, ok := s.(*sessionImpl); ses != nil && ok { |
| return ses.Client |
| } |
| |
| return nil |
| } |
| |
| func contextWithSession(ctx context.Context, sess Session) SessionContext { |
| return &sessionContext{ |
| Context: context.WithValue(ctx, sessionKey{}, sess), |
| Session: sess, |
| } |
| } |