seba-365 - implemented dep
Change-Id: Ia6226d50e7615935a0c8876809a687427ff88c22
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/connection.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/connection.go
new file mode 100644
index 0000000..d59f5b5
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/connection.go
@@ -0,0 +1,96 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+ "context"
+ "net"
+
+ "strings"
+
+ "github.com/mongodb/mongo-go-driver/x/network/command"
+ "github.com/mongodb/mongo-go-driver/x/network/connection"
+ "github.com/mongodb/mongo-go-driver/x/network/description"
+ "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
+)
+
+// sconn is a wrapper around a connection.Connection. This type is returned by
+// a Server so that it can track network errors and when a non-timeout network
+// error is returned, the pool on the server can be cleared.
+type sconn struct {
+ connection.Connection // wrapped connection; ReadWireMessage and WriteWireMessage are intercepted below
+ s *Server // owning Server whose description processErr updates
+ id uint64 // NOTE(review): not set by Server.Connection in server.go — confirm it is still used
+}
+
+var notMasterCodes = []int32{10107, 13435} // command error codes that isNotMasterError treats as "not master"
+var recoveringCodes = []int32{11600, 11602, 13436, 189, 91} // command error codes that isRecoveringError treats as "node is recovering"
+
+// ReadWireMessage reads a reply and reports the read error, or else the reply's decoded command error, to processErr.
+func (sc *sconn) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
+	wm, err := sc.Connection.ReadWireMessage(ctx)
+	observed := err
+	if observed == nil {
+		observed = command.DecodeError(wm)
+	}
+	sc.processErr(observed)
+	return wm, err
+}
+
+func (sc *sconn) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
+	writeErr := sc.Connection.WriteWireMessage(ctx, wm) // delegate to the wrapped connection
+	sc.processErr(writeErr) // a nil error matches neither type assertion in processErr
+	return writeErr
+}
+
+// processErr marks the owning server Unknown for not-master/recovering command errors and non-timeout network errors.
+func (sc *sconn) processErr(err error) {
+	// TODO(GODRIVER-524) handle the rest of sdam error handling
+	markUnknown := func() {
+		unknown := sc.s.Description()
+		unknown.Kind = description.Unknown
+		unknown.LastError = err
+		sc.s.updateDescription(unknown, false)
+	}
+
+	// Invalidate server description if not master or node recovering error occurs
+	if cmdErr, isCmd := err.(command.Error); isCmd && (isRecoveringError(cmdErr) || isNotMasterError(cmdErr)) {
+		markUnknown()
+	}
+
+	netErr, isNet := err.(connection.NetworkError)
+	if !isNet {
+		return
+	}
+
+	// Timeouts and context cancellation/expiry leave the description untouched.
+	if timeoutErr, ok := netErr.Wrapped.(net.Error); ok && timeoutErr.Timeout() {
+		return
+	}
+	if netErr.Wrapped == context.Canceled || netErr.Wrapped == context.DeadlineExceeded {
+		return
+	}
+	markUnknown()
+}
+
+// isRecoveringError reports whether err has a code in recoveringCodes or a message containing "node is recovering".
+func isRecoveringError(err command.Error) bool {
+	known := false
+	for i := 0; i < len(recoveringCodes) && !known; i++ {
+		known = recoveringCodes[i] == err.Code
+	}
+	return known || strings.Contains(err.Error(), "node is recovering")
+}
+
+// isNotMasterError reports whether err has a code in notMasterCodes or a message containing "not master".
+func isNotMasterError(err command.Error) bool {
+	known := false
+	for i := 0; i < len(notMasterCodes) && !known; i++ {
+		known = notMasterCodes[i] == err.Code
+	}
+	return known || strings.Contains(err.Error(), "not master")
+}
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/fsm.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/fsm.go
new file mode 100644
index 0000000..3682b57
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/fsm.go
@@ -0,0 +1,350 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/mongodb/mongo-go-driver/bson/primitive"
+ "github.com/mongodb/mongo-go-driver/x/network/address"
+ "github.com/mongodb/mongo-go-driver/x/network/description"
+)
+
+var supportedWireVersions = description.NewVersionRange(2, 6) // wire version range that apply() accepts from a server
+var minSupportedMongoDBVersion = "2.6" // only used in the "wire version too old" error built by apply()
+
+type fsm struct {
+ description.Topology // current topology; rebuilt with a copied server slice at the start of every apply
+ SetName string // replica set name; taken from the first replica set member seen while still empty
+ maxElectionID primitive.ObjectID // election ID recorded by updateRSFromPrimary from the last accepted primary
+ maxSetVersion uint32 // highest set version recorded by updateRSFromPrimary
+}
+
+func newFSM() *fsm {
+	return &fsm{} // zero value: no servers, no set name, no election ID or set version recorded yet
+}
+
+// apply should operate on immutable TopologyDescriptions and Descriptions. This way we don't have to
+// lock for the entire time we're applying server description.
+func (f *fsm) apply(s description.Server) (description.Topology, error) {
+
+ newServers := make([]description.Server, len(f.Servers)) // private copy: later edits must not touch the slice of a Topology returned earlier
+ copy(newServers, f.Servers)
+
+ oldMinutes := f.SessionTimeoutMinutes // saved because the literal below resets every field except Kind and Servers
+ f.Topology = description.Topology{
+ Kind: f.Kind,
+ Servers: newServers,
+ }
+
+ // For data bearing servers, set SessionTimeoutMinutes to the lowest among them
+ if oldMinutes == 0 {
+ // If timeout currently 0, check all servers to see if any still don't have a timeout
+ // If they all have timeout, pick the lowest.
+ timeout := s.SessionTimeoutMinutes
+ for _, server := range f.Servers { // these are the previously stored descriptions, not s
+ if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
+ timeout = server.SessionTimeoutMinutes
+ }
+ }
+ f.SessionTimeoutMinutes = timeout
+ } else {
+ if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
+ f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
+ } else {
+ f.SessionTimeoutMinutes = oldMinutes
+ }
+ }
+
+ if _, ok := f.findServer(s.Addr); !ok { // descriptions for addresses outside the topology are ignored
+ return f.Topology, nil
+ }
+
+ if s.WireVersion != nil { // reject servers whose wire version range does not overlap supportedWireVersions
+ if s.WireVersion.Max < supportedWireVersions.Min {
+ return description.Topology{}, fmt.Errorf(
+ "server at %s reports wire version %d, but this version of the Go driver requires "+
+ "at least %d (MongoDB %s)",
+ s.Addr.String(),
+ s.WireVersion.Max,
+ supportedWireVersions.Min,
+ minSupportedMongoDBVersion,
+ )
+ }
+
+ if s.WireVersion.Min > supportedWireVersions.Max {
+ return description.Topology{}, fmt.Errorf(
+ "server at %s requires wire version %d, but this version of the Go driver only "+
+ "supports up to %d",
+ s.Addr.String(),
+ s.WireVersion.Min,
+ supportedWireVersions.Max,
+ )
+ }
+ }
+
+ switch f.Kind { // dispatch on the current topology kind; each handler may change f.Kind and f.Servers
+ case description.Unknown:
+ f.applyToUnknown(s)
+ case description.Sharded:
+ f.applyToSharded(s)
+ case description.ReplicaSetNoPrimary:
+ f.applyToReplicaSetNoPrimary(s)
+ case description.ReplicaSetWithPrimary:
+ f.applyToReplicaSetWithPrimary(s)
+ case description.Single:
+ f.applyToSingle(s)
+ }
+
+ return f.Topology, nil
+}
+
+func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) {
+	switch kind := s.Kind; kind {
+	case description.RSPrimary:
+		f.updateRSFromPrimary(s)
+	case description.RSSecondary, description.RSArbiter, description.RSMember:
+		f.updateRSWithoutPrimary(s)
+	case description.Unknown, description.RSGhost:
+		f.replaceServer(s)
+	case description.Standalone, description.Mongos:
+		f.removeServerByAddr(s.Addr) // not a replica set member: drop it from the topology
+	}
+}
+
+func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) {
+	switch kind := s.Kind; kind {
+	case description.RSPrimary:
+		f.updateRSFromPrimary(s)
+	case description.RSSecondary, description.RSArbiter, description.RSMember:
+		f.updateRSWithPrimaryFromMember(s)
+	case description.Unknown, description.RSGhost:
+		f.replaceServer(s)
+		f.checkIfHasPrimary() // the replaced entry may have been the primary
+	case description.Standalone, description.Mongos:
+		f.removeServerByAddr(s.Addr)
+		f.checkIfHasPrimary() // the removed entry may have been the primary
+	}
+}
+
+func (f *fsm) applyToSharded(s description.Server) {
+	switch kind := s.Kind; kind {
+	case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
+		f.removeServerByAddr(s.Addr) // only Mongos and Unknown entries are kept in a sharded topology
+	case description.Mongos, description.Unknown:
+		f.replaceServer(s)
+	}
+}
+
+func (f *fsm) applyToSingle(s description.Server) {
+ switch s.Kind {
+ case description.Unknown:
+ f.replaceServer(s)
+ case description.Standalone, description.Mongos:
+ if f.SetName != "" {
+ f.removeServerByAddr(s.Addr)
+ return
+ }
+
+ f.replaceServer(s)
+ case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
+ if f.SetName != "" && f.SetName != s.SetName {
+ f.removeServerByAddr(s.Addr)
+ return
+ }
+
+ f.replaceServer(s)
+ }
+}
+
+func (f *fsm) applyToUnknown(s description.Server) {
+	switch kind := s.Kind; kind {
+	case description.Unknown, description.RSGhost:
+		f.replaceServer(s) // topology kind stays Unknown
+	case description.Standalone:
+		f.updateUnknownWithStandalone(s)
+	case description.Mongos:
+		f.setKind(description.Sharded)
+		f.replaceServer(s)
+	case description.RSPrimary:
+		f.updateRSFromPrimary(s)
+	case description.RSSecondary, description.RSArbiter, description.RSMember:
+		f.setKind(description.ReplicaSetNoPrimary)
+		f.updateRSWithoutPrimary(s)
+	}
+}
+
+func (f *fsm) checkIfHasPrimary() {
+	kind := description.ReplicaSetNoPrimary // used when no server currently reports RSPrimary
+	if _, hasPrimary := f.findPrimary(); hasPrimary {
+		kind = description.ReplicaSetWithPrimary
+	}
+	f.setKind(kind)
+}
+
+func (f *fsm) updateRSFromPrimary(s description.Server) { // applies a description whose Kind is RSPrimary
+ if f.SetName == "" {
+ f.SetName = s.SetName
+ } else if f.SetName != s.SetName { // primary of a different replica set: drop it
+ f.removeServerByAddr(s.Addr)
+ f.checkIfHasPrimary()
+ return
+ }
+
+ if s.SetVersion != 0 && !bytes.Equal(s.ElectionID[:], primitive.NilObjectID[:]) { // only when both values are reported
+ // NOTE(review): treats the primary as stale if EITHER value is older — confirm against the SDAM spec's combined comparison.
+ if f.maxSetVersion > s.SetVersion || bytes.Compare(f.maxElectionID[:], s.ElectionID[:]) == 1 {
+ f.replaceServer(description.Server{
+ Addr: s.Addr,
+ LastError: fmt.Errorf("was a primary, but its set version or election id is stale"),
+ })
+ f.checkIfHasPrimary()
+ return
+ }
+
+ f.maxElectionID = s.ElectionID
+ }
+
+ if s.SetVersion > f.maxSetVersion {
+ f.maxSetVersion = s.SetVersion
+ }
+
+ if j, ok := f.findPrimary(); ok { // demote the previously known primary to an address-only entry
+ f.setServer(j, description.Server{
+ Addr: f.Servers[j].Addr,
+ LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
+ })
+ }
+
+ f.replaceServer(s)
+
+ for j := len(f.Servers) - 1; j >= 0; j-- { // walk backwards so removals do not shift unvisited indexes
+ found := false
+ for _, member := range s.Members {
+ if member == f.Servers[j].Addr {
+ found = true
+ break
+ }
+ }
+ if !found { // the primary's member list is authoritative: drop servers it does not list
+ f.removeServer(j)
+ }
+ }
+
+ for _, member := range s.Members { // add members this topology has not seen yet
+ if _, ok := f.findServer(member); !ok {
+ f.addServer(member)
+ }
+ }
+
+ f.checkIfHasPrimary()
+}
+
+func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
+	// A member of another set, or one whose address differs from its canonical address, is dropped.
+	wrongSet := f.SetName != s.SetName
+	nonCanonical := s.Addr != s.CanonicalAddr
+	if wrongSet || nonCanonical {
+		f.removeServerByAddr(s.Addr)
+		f.checkIfHasPrimary()
+		return
+	}
+
+	f.replaceServer(s)
+
+	// Replacing s may have overwritten the entry that used to be the primary.
+	_, hasPrimary := f.findPrimary()
+	if hasPrimary {
+		return
+	}
+	f.setKind(description.ReplicaSetNoPrimary)
+}
+
+func (f *fsm) updateRSWithoutPrimary(s description.Server) {
+	switch {
+	case f.SetName == "":
+		f.SetName = s.SetName // the first member seen names the set
+	case f.SetName != s.SetName:
+		f.removeServerByAddr(s.Addr)
+		return
+	}
+
+	for i := range s.Members { // add members this topology has not seen yet
+		if _, known := f.findServer(s.Members[i]); !known {
+			f.addServer(s.Members[i])
+		}
+	}
+
+	if s.Addr == s.CanonicalAddr {
+		f.replaceServer(s)
+		return
+	}
+	f.removeServerByAddr(s.Addr)
+}
+
+func (f *fsm) updateUnknownWithStandalone(s description.Server) {
+	// A standalone is accepted only while it is the sole server; otherwise it is dropped.
+	if len(f.Servers) <= 1 {
+		f.setKind(description.Single)
+		f.replaceServer(s)
+		return
+	}
+	f.removeServerByAddr(s.Addr)
+}
+
+func (f *fsm) addServer(addr address.Address) {
+	// Only the canonicalized address is filled in; every other field keeps its zero value.
+	placeholder := description.Server{Addr: addr.Canonicalize()}
+	f.Servers = append(f.Servers, placeholder)
+}
+
+// findPrimary returns the index of the first server whose Kind is RSPrimary, or (0, false) if there is none.
+func (f *fsm) findPrimary() (int, bool) {
+	for i := range f.Servers {
+		if f.Servers[i].Kind == description.RSPrimary {
+			return i, true
+		}
+	}
+	return 0, false
+}
+
+// findServer returns the index of the server whose Addr equals the canonicalized addr, or (0, false) if absent.
+func (f *fsm) findServer(addr address.Address) (int, bool) {
+	want := addr.Canonicalize()
+	for i := range f.Servers {
+		if f.Servers[i].Addr == want {
+			return i, true
+		}
+	}
+	return 0, false
+}
+
+func (f *fsm) removeServer(i int) {
+	f.Servers = f.Servers[:i+copy(f.Servers[i:], f.Servers[i+1:])] // shift the tail left over index i, in place, and shorten by one
+}
+
+func (f *fsm) removeServerByAddr(addr address.Address) {
+	if idx, found := f.findServer(addr); found { // addresses not in the topology are silently ignored
+		f.removeServer(idx)
+	}
+}
+
+func (f *fsm) replaceServer(s description.Server) bool {
+	idx, found := f.findServer(s.Addr)
+	if found {
+		f.setServer(idx, s)
+	}
+	return found // false: s.Addr is not in the topology and nothing was changed
+}
+
+func (f *fsm) setServer(i int, s description.Server) {
+	f.Topology.Servers[i] = s // overwrite in place; i must be a valid index
+}
+
+func (f *fsm) setKind(k description.TopologyKind) {
+	f.Topology.Kind = k // Kind lives on the embedded description.Topology
+}
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server.go
new file mode 100644
index 0000000..3a7ace2
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server.go
@@ -0,0 +1,506 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mongodb/mongo-go-driver/event"
+ "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth"
+ "github.com/mongodb/mongo-go-driver/x/network/address"
+ "github.com/mongodb/mongo-go-driver/x/network/command"
+ "github.com/mongodb/mongo-go-driver/x/network/connection"
+ "github.com/mongodb/mongo-go-driver/x/network/description"
+)
+
+const minHeartbeatInterval = 500 * time.Millisecond // period of the rate-limiting ticker in update()
+const connectionSemaphoreSize = math.MaxInt64 // NOTE(review): not referenced anywhere in server.go — confirm it is still needed
+
+// ErrServerClosed occurs when an attempt to get a connection is made after
+// the server has been closed.
+var ErrServerClosed = errors.New("server is closed")
+
+// ErrServerConnected occurs when an attempt to connect is made after a server
+// has already been connected.
+var ErrServerConnected = errors.New("server is connected")
+
+// SelectedServer represents a specific server that was selected during server selection.
+// It contains the kind of the topology it was selected from.
+type SelectedServer struct {
+ *Server
+
+ Kind description.TopologyKind
+}
+
+// Description returns a description of the server as of the last heartbeat.
+func (ss *SelectedServer) Description() description.SelectedServer {
+ sdesc := ss.Server.Description()
+ return description.SelectedServer{
+ Server: sdesc,
+ Kind: ss.Kind,
+ }
+}
+
+// These constants represent the connection states of a server.
+const (
+ disconnected int32 = iota // zero value of Server.connectionstate; also stored when Disconnect completes
+ disconnecting // set by Disconnect while it shuts the server down
+ connected // set by Connect
+ connecting // NOTE(review): never assigned in server.go
+)
+
+// connectionStateString returns the display name of a connection state, or "" for an unrecognized value.
+func connectionStateString(state int32) string {
+	// Keyed by the state constants so the names follow the constants' values.
+	names := [...]string{
+		disconnected:  "Disconnected",
+		disconnecting: "Disconnecting",
+		connected:     "Connected",
+		connecting:    "Connecting",
+	}
+	if state < 0 || int(state) >= len(names) {
+		return ""
+	}
+	return names[state]
+}
+
+// Server is a single server within a topology.
+type Server struct {
+ cfg *serverConfig // built by newServerConfig from the ServerOptions
+ address address.Address // address heartbeats and the pool connect to
+
+ connectionstate int32 // one of disconnected/disconnecting/connected/connecting; changed via sync/atomic
+ done chan struct{} // unbuffered; Disconnect sends on it to stop update()
+ checkNow chan struct{} // buffer of one; RequestImmediateCheck signals an early heartbeat
+ closewg sync.WaitGroup // Disconnect waits on it for update() to return
+ pool connection.Pool // connection pool handed out by Connection()
+
+ desc atomic.Value // holds a description.Server
+
+ averageRTTSet bool // consulted by updateAverageRTT to choose between seeding and smoothing
+ averageRTT time.Duration // round-trip time estimate maintained by updateAverageRTT
+
+ subLock sync.Mutex // guards subscribers, currentSubscriberID and subscriptionsClosed
+ subscribers map[uint64]chan description.Server // subscription channels keyed by subscriber id
+ currentSubscriberID uint64 // id given to the next subscriber
+
+ subscriptionsClosed bool // set by update() once it has closed every subscriber channel
+}
+
+// ConnectServer creates a new Server with NewServer and immediately connects it.
+// A nil Server is returned together with the error if either step fails.
+func ConnectServer(ctx context.Context, addr address.Address, opts ...ServerOption) (*Server, error) {
+	srvr, err := NewServer(addr, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	if connectErr := srvr.Connect(ctx); connectErr != nil {
+		return nil, connectErr
+	}
+	return srvr, nil
+}
+
+// NewServer creates a new server. The mongodb server at the address will be monitored
+// on an internal monitoring goroutine.
+func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
+ cfg, err := newServerConfig(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ s := &Server{
+ cfg: cfg,
+ address: addr,
+
+ done: make(chan struct{}),
+ checkNow: make(chan struct{}, 1),
+
+ subscribers: make(map[uint64]chan description.Server),
+ }
+ s.desc.Store(description.Server{Addr: addr})
+
+ var maxConns uint64
+ if cfg.maxConns == 0 {
+ maxConns = math.MaxInt64
+ } else {
+ maxConns = uint64(cfg.maxConns)
+ }
+
+ s.pool, err = connection.NewPool(addr, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...)
+ if err != nil {
+ return nil, err
+ }
+
+ return s, nil
+}
+
+// Connect initializes the Server by starting the background monitoring goroutine and connecting the pool.
+// This method must be called before a Server can be used; it returns ErrServerConnected unless the Server is disconnected.
+func (s *Server) Connect(ctx context.Context) error {
+	if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
+		return ErrServerConnected
+	}
+	s.desc.Store(description.Server{Addr: s.address})
+	s.closewg.Add(1) // must happen before the go statement: update() calls closewg.Done() when it exits
+	go s.update()
+	return s.pool.Connect(ctx)
+}
+
+// Disconnect closes sockets to the server referenced by this Server.
+// Subscriptions to this Server will be closed. Disconnect will shutdown
+// any monitoring goroutines, close the idle connection pool, and will
+// wait until all the in use connections have been returned to the connection
+// pool and are closed before returning. If the context expires via
+// cancellation, deadline, or timeout before the in use connections have been
+// returned, the in use connections will be closed, resulting in the failure of
+// any in flight read or write operations. If this method returns with no
+// errors, all connections associated with this Server have been closed.
+func (s *Server) Disconnect(ctx context.Context) error {
+ if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) { // only a connected Server can be disconnected
+ return ErrServerClosed
+ }
+
+ // For every call to Connect there must be at least 1 goroutine that is
+ // waiting on the done channel.
+ s.done <- struct{}{} // unbuffered send: blocks until update() (or its recover handler) receives it
+ err := s.pool.Disconnect(ctx)
+ if err != nil {
+ return err // early return: connectionstate stays disconnecting and closewg is not waited on
+ }
+
+ s.closewg.Wait() // wait for update() to finish closing subscriptions and its heartbeat connection
+ atomic.StoreInt32(&s.connectionstate, disconnected)
+
+ return nil
+}
+
+// Connection gets a connection to the server.
+func (s *Server) Connection(ctx context.Context) (connection.Connection, error) {
+ if atomic.LoadInt32(&s.connectionstate) != connected { // refuse unless Connect has run and Disconnect has not
+ return nil, ErrServerClosed
+ }
+ conn, desc, err := s.pool.Get(ctx)
+ if err != nil {
+ if _, ok := err.(*auth.Error); ok {
+ // authentication error --> drain connection
+ _ = s.pool.Drain()
+ }
+ if _, ok := err.(*connection.NetworkError); ok { // NOTE(review): sconn.processErr asserts the value type connection.NetworkError — confirm which form pool.Get returns
+ // update description to unknown and clears the connection pool
+ if desc != nil {
+ desc.Kind = description.Unknown
+ desc.LastError = err
+ s.updateDescription(*desc, false)
+ } else {
+ _ = s.pool.Drain()
+ }
+ }
+ return nil, err
+ }
+ if desc != nil {
+ go s.updateDescription(*desc, false) // publish the description returned by pool.Get without blocking the caller
+ }
+ sc := &sconn{Connection: conn, s: s} // wrap so read/write errors on this connection reach processErr
+ return sc, nil
+}
+
+// Description returns the description most recently stored by updateDescription (or the address-only one set at creation).
+func (s *Server) Description() description.Server {
+ return s.desc.Load().(description.Server) // desc only ever holds description.Server values
+}
+
+// SelectedDescription returns a description.SelectedServer with a Kind of
+// Single. This can be used when performing tasks like monitoring a batch
+// of servers and you want to run one off commands against those servers.
+func (s *Server) SelectedDescription() description.SelectedServer {
+ sdesc := s.Description()
+ return description.SelectedServer{
+ Server: sdesc,
+ Kind: description.Single,
+ }
+}
+
+// Subscribe registers a new subscriber and returns its ServerSubscription.
+// The subscription channel has a buffer of one and already holds the current
+// description. ErrSubscribeAfterClosed is returned if the server is not connected.
+func (s *Server) Subscribe() (*ServerSubscription, error) {
+	if atomic.LoadInt32(&s.connectionstate) != connected {
+		return nil, ErrSubscribeAfterClosed
+	}
+
+	// Seed the channel before taking the lock, exactly one value fits its buffer.
+	updates := make(chan description.Server, 1)
+	updates <- s.Description()
+
+	s.subLock.Lock()
+	defer s.subLock.Unlock()
+	if s.subscriptionsClosed {
+		return nil, ErrSubscribeAfterClosed
+	}
+
+	sub := &ServerSubscription{
+		C:  updates,
+		s:  s,
+		id: s.currentSubscriberID,
+	}
+	s.subscribers[sub.id] = updates
+	s.currentSubscriberID++
+	return sub, nil
+}
+
+// RequestImmediateCheck asks the monitoring goroutine to send a heartbeat
+// instead of waiting for the next heartbeat tick. It never blocks.
+func (s *Server) RequestImmediateCheck() {
+	select {
+	default: // a request is already pending in checkNow's one-slot buffer
+	case s.checkNow <- struct{}{}:
+	}
+}
+
+// update handles performing heartbeats and updating any subscribers of the
+// newest description.Server retrieved.
+func (s *Server) update() {
+ defer s.closewg.Done() // registered first, so it runs last: Disconnect's closewg.Wait returns only after all cleanup
+ heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
+ rateLimiter := time.NewTicker(minHeartbeatInterval)
+ defer heartbeatTicker.Stop()
+ defer rateLimiter.Stop()
+ checkNow := s.checkNow
+ done := s.done
+
+ var doneOnce bool // true once closeServer has run, i.e. the done signal was already consumed
+ defer func() {
+ if r := recover(); r != nil {
+ if doneOnce {
+ return
+ }
+ // We keep this goroutine alive attempting to read from the done channel.
+ <-done // otherwise Disconnect's unbuffered send on done would block forever
+ }
+ }()
+
+ var conn connection.Connection
+ var desc description.Server
+
+ desc, conn = s.heartbeat(nil) // first heartbeat runs immediately, before any ticker fires
+ s.updateDescription(desc, true)
+
+ closeServer := func() { // shutdown path: close every subscription, then the heartbeat connection
+ doneOnce = true
+ s.subLock.Lock()
+ for id, c := range s.subscribers {
+ close(c)
+ delete(s.subscribers, id)
+ }
+ s.subscriptionsClosed = true
+ s.subLock.Unlock()
+ if conn == nil {
+ return
+ }
+ conn.Close()
+ }
+ for {
+ select { // wait for the next scheduled heartbeat, an explicit request, or shutdown
+ case <-heartbeatTicker.C:
+ case <-checkNow:
+ case <-done:
+ closeServer()
+ return
+ }
+
+ select { // then wait for the rate limiter so heartbeats are at least minHeartbeatInterval apart
+ case <-rateLimiter.C:
+ case <-done:
+ closeServer()
+ return
+ }
+
+ desc, conn = s.heartbeat(conn) // reuse the previous heartbeat connection when there is one
+ s.updateDescription(desc, false)
+ }
+}
+
+// updateDescription stores desc as the server's current description, hands it
+// to every subscriber, and drains the connection pool when the server has
+// become Unknown. initial marks the first description produced for the
+// server, for which the pool is never drained.
+func (s *Server) updateDescription(desc description.Server, initial bool) {
+	defer func() {
+		// Swallow any panic raised below. NOTE(review): a panic between Lock and
+		// Unlock would leave subLock held, since the unlock is not deferred.
+		_ = recover()
+	}()
+	s.desc.Store(desc)
+
+	// Subscriber channels are created with a buffer of one: remove an unread
+	// description first, then deliver the new one.
+	s.subLock.Lock()
+	for id := range s.subscribers {
+		updates := s.subscribers[id]
+		select {
+		case <-updates:
+		default:
+		}
+		updates <- desc
+	}
+	s.subLock.Unlock()
+
+	// We don't clear the pool on the first update on the description.
+	if initial || desc.Kind != description.Unknown {
+		return
+	}
+	// The error returned by Drain is explicitly discarded.
+	_ = s.pool.Drain()
+}
+
+// heartbeat sends a heartbeat to the server using the given connection. The connection can be nil.
+func (s *Server) heartbeat(conn connection.Connection) (description.Server, connection.Connection) {
+ const maxRetry = 2 // total attempts, counting the first one
+ var saved error // last failure; reported as LastError if no attempt succeeds
+ var desc description.Server
+ var set bool // true once an isMaster round trip has produced desc
+ var err error
+ ctx := context.Background()
+
+ for i := 1; i <= maxRetry; i++ {
+ if conn != nil && conn.Expired() { // an expired connection is discarded and replaced below
+ conn.Close()
+ conn = nil
+ }
+
+ if conn == nil {
+ opts := []connection.Option{
+ connection.WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
+ connection.WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
+ connection.WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
+ }
+ opts = append(opts, s.cfg.connectionOpts...) // appended after the timeouts, before the two overrides below
+ // We override whatever handshaker is currently attached to the options with an empty
+ // one because need to make sure we don't do auth.
+ opts = append(opts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
+ return nil
+ }))
+
+ // Override any command monitors specified in options with nil to avoid monitoring heartbeats.
+ opts = append(opts, connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
+ return nil
+ }))
+ conn, _, err = connection.New(ctx, s.address, opts...)
+ if err != nil {
+ saved = err
+ if conn != nil {
+ conn.Close()
+ }
+ conn = nil
+ continue // counts as one attempt
+ }
+ }
+
+ now := time.Now() // start of the round-trip measurement
+
+ isMasterCmd := &command.IsMaster{Compressors: s.cfg.compressionOpts}
+ isMaster, err := isMasterCmd.RoundTrip(ctx, conn) // note: this err shadows the outer err inside the loop body
+ if err != nil {
+ saved = err
+ conn.Close()
+ conn = nil
+ continue // the next attempt opens a fresh connection
+ }
+
+ clusterTime := isMaster.ClusterTime
+ if s.cfg.clock != nil {
+ s.cfg.clock.AdvanceClusterTime(clusterTime)
+ }
+
+ delay := time.Since(now)
+ desc = description.NewServer(s.address, isMaster).SetAverageRTT(s.updateAverageRTT(delay))
+ desc.HeartbeatInterval = s.cfg.heartbeatInterval
+ set = true
+
+ break
+ }
+
+ if !set { // every attempt failed: report an address-only description carrying the last error
+ desc = description.Server{
+ Addr: s.address,
+ LastError: saved,
+ }
+ }
+
+ return desc, conn // conn is nil when the last attempt failed
+}
+
+// updateAverageRTT folds delay into the server's round-trip estimate (exponentially weighted, alpha 0.2) and returns it.
+func (s *Server) updateAverageRTT(delay time.Duration) time.Duration {
+	if alpha := 0.2; s.averageRTTSet {
+		s.averageRTT = time.Duration(alpha*float64(delay) + (1-alpha)*float64(s.averageRTT))
+	} else {
+		s.averageRTT, s.averageRTTSet = delay, true // first sample seeds the average; without the flag smoothing never ran
+	}
+	return s.averageRTT
+}
+
+// Drain will drain the connection pool of this server. It exists so that the
+// pool does not need to be exposed directly: a client that gets an error from
+// reading or writing can drain this server's pool through the Server itself.
+func (s *Server) Drain() error {
+	drainErr := s.pool.Drain()
+	return drainErr
+}
+
+// String implements the Stringer interface: address, kind and state, plus tags, average RTT and last error when present.
+func (s *Server) String() string {
+	desc := s.Description()
+	state := atomic.LoadInt32(&s.connectionstate) // read once, atomically, like every other access to connectionstate
+	str := fmt.Sprintf("Addr: %s, Type: %s, State: %s", s.address, desc.Kind, connectionStateString(state))
+	if len(desc.Tags) != 0 {
+		str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
+	}
+	if state == connected {
+		str += fmt.Sprintf(", Average RTT: %d", s.averageRTT) // label was misspelled "Avergage"
+	}
+	if desc.LastError != nil {
+		str += fmt.Sprintf(", Last error: %s", desc.LastError)
+	}
+
+	return str
+}
+
+// ServerSubscription represents a subscription to the description.Server updates for
+// a specific server.
+type ServerSubscription struct {
+ C <-chan description.Server // receives descriptions; closed by Unsubscribe or when the Server shuts down
+ s *Server // Server this subscription is registered with
+ id uint64 // key of this subscription in the Server's subscribers map
+}
+
+// Unsubscribe removes this subscription from its Server and closes the
+// subscription channel. It does nothing if the Server has already closed all
+// subscriptions or this subscription was already removed. It always returns nil.
+func (ss *ServerSubscription) Unsubscribe() error {
+	srv := ss.s
+	srv.subLock.Lock()
+	defer srv.subLock.Unlock()
+
+	// Once the Server has shut down, update() has already closed and removed every channel.
+	if srv.subscriptionsClosed {
+		return nil
+	}
+
+	if ch, subscribed := srv.subscribers[ss.id]; subscribed {
+		close(ch)
+		delete(srv.subscribers, ss.id)
+	}
+	return nil
+}
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server_options.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server_options.go
new file mode 100644
index 0000000..0ebbecf
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/server_options.go
@@ -0,0 +1,121 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+ "time"
+
+ "github.com/mongodb/mongo-go-driver/bson"
+ "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
+ "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
+ "github.com/mongodb/mongo-go-driver/x/network/connection"
+)
+
+var defaultRegistry = bson.NewRegistryBuilder().Build() // registry newServerConfig uses unless WithRegistry replaces it
+
+type serverConfig struct {
+ clock *session.ClusterClock // optional; advanced with the cluster time of each heartbeat reply
+ compressionOpts []string // compressors sent in the heartbeat isMaster command
+ connectionOpts []connection.Option // options for pooled and heartbeat connections
+ appname string // NOTE(review): no option in this file sets it — confirm it is still used
+ heartbeatInterval time.Duration // period of the heartbeat ticker (default 10s)
+ heartbeatTimeout time.Duration // connect/read/write timeout of heartbeat connections (default 10s)
+ maxConns uint16 // pool size limit; 0 means no limit (default 100)
+ maxIdleConns uint16 // idle connection limit passed to the pool (default 100)
+ registry *bsoncodec.Registry // defaults to defaultRegistry
+}
+
+func newServerConfig(opts ...ServerOption) (*serverConfig, error) {
+	// Start from the defaults; each option may override them.
+	cfg := new(serverConfig)
+	cfg.heartbeatInterval = 10 * time.Second
+	cfg.heartbeatTimeout = 10 * time.Second
+	cfg.maxConns = 100
+	cfg.maxIdleConns = 100
+	cfg.registry = defaultRegistry
+
+	// Apply the options in order, stopping at the first failure.
+	for i := range opts {
+		if err := opts[i](cfg); err != nil {
+			return nil, err
+		}
+	}
+
+	return cfg, nil
+}
+
+// ServerOption configures a server by mutating its serverConfig; a non-nil error makes newServerConfig fail.
+type ServerOption func(*serverConfig) error
+
+// WithConnectionOptions replaces the server's connection options with the result of fn applied to the current ones.
+func WithConnectionOptions(fn func(...connection.Option) []connection.Option) ServerOption {
+	return func(c *serverConfig) error {
+		c.connectionOpts = fn(c.connectionOpts...)
+		return nil
+	}
+}
+
+// WithCompressionOptions replaces the server's compressor list with the result of fn applied to the current one.
+func WithCompressionOptions(fn func(...string) []string) ServerOption {
+	return func(c *serverConfig) error {
+		c.compressionOpts = fn(c.compressionOpts...)
+		return nil
+	}
+}
+
+// WithHeartbeatInterval sets the heartbeat interval to the result of fn applied to the current interval.
+func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
+	return func(c *serverConfig) error {
+		c.heartbeatInterval = fn(c.heartbeatInterval)
+		return nil
+	}
+}
+
+// WithHeartbeatTimeout configures the connect, read and write timeout used for heartbeat connections.
+func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
+	setTimeout := func(c *serverConfig) error {
+		c.heartbeatTimeout = fn(c.heartbeatTimeout)
+		return nil
+	}
+	return setTimeout
+}
+
+// WithMaxConnections configures the maximum number of connections to allow for
+// a given server. A maximum of 0 means there is no upper limit.
+func WithMaxConnections(fn func(uint16) uint16) ServerOption {
+	setMax := func(c *serverConfig) error {
+		c.maxConns = fn(c.maxConns)
+		return nil
+	}
+	return setMax
+}
+
+// WithMaxIdleConnections configures the maximum number of idle connections allowed for the server.
+func WithMaxIdleConnections(fn func(uint16) uint16) ServerOption {
+	setMaxIdle := func(c *serverConfig) error {
+		c.maxIdleConns = fn(c.maxIdleConns)
+		return nil
+	}
+	return setMaxIdle
+}
+
+// WithClock sets the ClusterClock the server advances on each heartbeat to the result of fn applied to the current one.
+func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption {
+	return func(c *serverConfig) error {
+		c.clock = fn(c.clock)
+		return nil
+	}
+}
+
+// WithRegistry configures the registry for the server to use when creating cursors.
+func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption {
+	setRegistry := func(c *serverConfig) error {
+		c.registry = fn(c.registry)
+		return nil
+	}
+	return setRegistry
+}
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology.go
new file mode 100644
index 0000000..09a319c
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology.go
@@ -0,0 +1,471 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+// Package topology contains types that handle the discovery, monitoring, and selection
+// of servers. This package is designed to expose enough inner workings of service discovery
+// and monitoring to allow low level applications to have fine grained control, while hiding
+// most of the detailed implementation of the algorithms.
+package topology
+
+import (
+ "context"
+ "errors"
+ "math/rand"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "fmt"
+
+ "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
+ "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
+ "github.com/mongodb/mongo-go-driver/x/network/address"
+ "github.com/mongodb/mongo-go-driver/x/network/description"
+)
+
+// ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
+// closed Server or Topology.
+var ErrSubscribeAfterClosed = errors.New("cannot subscribe after close")
+
+// ErrTopologyClosed is returned when a user attempts to call a method on a
+// closed Topology.
+var ErrTopologyClosed = errors.New("topology is closed")
+
+// ErrTopologyConnected is returned when a user attempts to connect to an
+// already connected Topology.
+var ErrTopologyConnected = errors.New("topology is connected or connecting")
+
+// ErrServerSelectionTimeout is returned from server selection when the server
+// selection process took longer than allowed by the timeout.
+var ErrServerSelectionTimeout = errors.New("server selection timeout")
+
+// MonitorMode represents the way in which a server is monitored.
+type MonitorMode uint8
+
+// These constants are the available monitoring modes.
+const (
+ AutomaticMode MonitorMode = iota // zero value of MonitorMode
+ SingleMode // New sets the topology kind to description.Single when this mode is configured
+)
+
+// Topology represents a MongoDB deployment.
+type Topology struct {
+ registry *bsoncodec.Registry
+
+ connectionstate int32 // accessed through sync/atomic; moves between disconnected, connecting, connected and disconnecting
+
+ cfg *config
+
+ desc atomic.Value // holds a description.Topology
+
+ done chan struct{} // Disconnect sends on this channel to stop the update goroutine
+
+ fsm *fsm
+ changes chan description.Server // server descriptions forwarded by the goroutines started in addServer
+ changeswg sync.WaitGroup // tracks the update goroutine
+
+ SessionPool *session.Pool // created by Connect from a subscription to this topology
+
+ // This should really be encapsulated into its own type. This will likely
+ // require a redesign so we can share a minimum of data between the
+ // subscribers and the topology.
+ subscribers map[uint64]chan description.Topology
+ currentSubscriberID uint64
+ subscriptionsClosed bool
+ subLock sync.Mutex
+
+ // We should redesign how we connect and handle individual servers. This is
+ // too difficult to maintain and it's rather easy to accidentally access
+ // the servers without acquiring the lock or checking if the servers are
+ // closed. This lock should also be an RWMutex.
+ serversLock sync.Mutex
+ serversClosed bool
+ servers map[address.Address]*Server
+
+ wg sync.WaitGroup // tracks the per-server forwarding goroutines started in addServer
+}
+
+// New creates a new, not yet connected topology from the given options. It returns the error of the first option that fails.
+func New(opts ...Option) (*Topology, error) {
+ cfg, err := newConfig(opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ t := &Topology{
+ cfg: cfg,
+ done: make(chan struct{}),
+ fsm: newFSM(),
+ changes: make(chan description.Server),
+ subscribers: make(map[uint64]chan description.Topology),
+ servers: make(map[address.Address]*Server),
+ }
+ t.desc.Store(description.Topology{}) // start with an empty description so Description never sees an unset value
+
+ if cfg.replicaSetName != "" {
+ t.fsm.SetName = cfg.replicaSetName
+ t.fsm.Kind = description.ReplicaSetNoPrimary
+ }
+
+ if cfg.mode == SingleMode { // checked last, so single mode overrides the replica set kind set above
+ t.fsm.Kind = description.Single
+ }
+
+ return t, nil
+}
+
+// Connect initializes a Topology and starts the monitoring process. This function
+// must be called to properly monitor the topology.
+func (t *Topology) Connect(ctx context.Context) error {
+ if !atomic.CompareAndSwapInt32(&t.connectionstate, disconnected, connecting) {
+ return ErrTopologyConnected
+ }
+
+ t.desc.Store(description.Topology{})
+ var err error
+ t.serversLock.Lock()
+ for _, a := range t.cfg.seedList {
+ addr := address.Address(a).Canonicalize()
+ t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr})
+ err = t.addServer(ctx, addr)
+ }
+ t.serversLock.Unlock()
+
+ go t.update()
+ t.changeswg.Add(1)
+
+ t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
+
+ atomic.StoreInt32(&t.connectionstate, connected)
+
+ // After connection, make a subscription to keep the pool updated
+ sub, err := t.Subscribe()
+ t.SessionPool = session.NewPool(sub.C)
+ return err
+}
+
+// Disconnect closes the topology. It stops the monitoring thread and
+// closes all open subscriptions. It returns ErrTopologyClosed if the topology is not connected.
+func (t *Topology) Disconnect(ctx context.Context) error {
+ if !atomic.CompareAndSwapInt32(&t.connectionstate, connected, disconnecting) {
+ return ErrTopologyClosed
+ }
+
+ t.serversLock.Lock()
+ t.serversClosed = true // makes apply stop adding or removing servers
+ for addr, server := range t.servers {
+ t.removeServer(ctx, addr, server)
+ }
+ t.serversLock.Unlock()
+
+ t.wg.Wait() // wait for the per-server forwarding goroutines started in addServer
+ t.done <- struct{}{} // tell the update goroutine to close all subscriptions and return
+ t.changeswg.Wait() // wait for the update goroutine to finish
+
+ t.desc.Store(description.Topology{})
+
+ atomic.StoreInt32(&t.connectionstate, disconnected)
+ return nil
+}
+
+// Description returns the most recently stored description of the topology, or an empty description.Topology if the stored value has a different type.
+func (t *Topology) Description() description.Topology {
+ td, ok := t.desc.Load().(description.Topology)
+ if !ok {
+ td = description.Topology{}
+ }
+ return td
+}
+
+// Subscribe returns a Subscription on which all updated description.Topologys
+// will be sent. The channel of the subscription will have a buffer size of one,
+// and will be pre-populated with the current description.Topology. It fails if the topology is not connected or its subscriptions are closed.
+func (t *Topology) Subscribe() (*Subscription, error) {
+ if atomic.LoadInt32(&t.connectionstate) != connected {
+ return nil, errors.New("cannot subscribe to Topology that is not connected")
+ }
+ ch := make(chan description.Topology, 1)
+ td, ok := t.desc.Load().(description.Topology)
+ if !ok {
+ td = description.Topology{}
+ }
+ ch <- td // does not block: the channel is new and has room for one value
+
+ t.subLock.Lock()
+ defer t.subLock.Unlock()
+ if t.subscriptionsClosed {
+ return nil, ErrSubscribeAfterClosed
+ }
+ id := t.currentSubscriberID
+ t.subscribers[id] = ch
+ t.currentSubscriberID++
+
+ return &Subscription{
+ C: ch,
+ t: t,
+ id: id,
+ }, nil
+}
+
+// RequestImmediateCheck will send heartbeats to all the servers in the
+// topology right away, instead of waiting for the heartbeat timeout. It does nothing unless the topology is connected.
+func (t *Topology) RequestImmediateCheck() {
+ if atomic.LoadInt32(&t.connectionstate) != connected {
+ return
+ }
+ t.serversLock.Lock()
+ for _, server := range t.servers {
+ server.RequestImmediateCheck()
+ }
+ t.serversLock.Unlock()
+}
+
+// SupportsSessions returns true if the topology description has a non-zero SessionTimeoutMinutes and its kind is not description.Single.
+func (t *Topology) SupportsSessions() bool {
+ return t.Description().SessionTimeoutMinutes != 0 && t.Description().Kind != description.Single
+}
+
+// SelectServer selects a server given a selector. SelectServer complies with the
+// server selection spec, and will time out after serverSelectionTimeout or when the
+// parent context is done. It returns ErrTopologyClosed if the topology is not connected.
+func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (*SelectedServer, error) {
+ if atomic.LoadInt32(&t.connectionstate) != connected {
+ return nil, ErrTopologyClosed
+ }
+ var ssTimeoutCh <-chan time.Time // stays nil when no timeout is configured, so it never fires in a select
+
+ if t.cfg.serverSelectionTimeout > 0 {
+ ssTimeout := time.NewTimer(t.cfg.serverSelectionTimeout)
+ ssTimeoutCh = ssTimeout.C
+ defer ssTimeout.Stop()
+ }
+
+ sub, err := t.Subscribe()
+ if err != nil {
+ return nil, err
+ }
+ defer sub.Unsubscribe()
+
+ for {
+ suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh)
+ if err != nil {
+ return nil, err
+ }
+
+ selected := suitable[rand.Intn(len(suitable))] // selectServer only returns without error when suitable is non-empty
+ selectedS, err := t.FindServer(selected)
+ switch {
+ case err != nil:
+ return nil, err
+ case selectedS != nil:
+ return selectedS, nil
+ default:
+ // We don't have an actual server for the provided description.
+ // This could happen for a number of reasons, including that the
+ // server has since stopped being a part of this topology, or that
+ // the server selector returned no suitable servers.
+ }
+ }
+}
+
+// FindServer will attempt to find a server that fits the given server description.
+// This method will return nil, nil if a matching server could not be found, and ErrTopologyClosed if the topology is not connected.
+func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
+ if atomic.LoadInt32(&t.connectionstate) != connected {
+ return nil, ErrTopologyClosed
+ }
+ t.serversLock.Lock()
+ defer t.serversLock.Unlock()
+ server, ok := t.servers[selected.Addr] // servers are looked up by address only
+ if !ok {
+ return nil, nil
+ }
+
+ desc := t.Description()
+ return &SelectedServer{
+ Server: server,
+ Kind: desc.Kind,
+ }, nil
+}
+
+func wrapServerSelectionError(err error, t *Topology) error { // builds a new error carrying err's text and the topology's String output
+ return fmt.Errorf("server selection error: %v\ncurrent topology: %s", err, t.String())
+}
+
+// selectServer is the core piece of server selection. It handles getting
+// topology descriptions and running server selection on those descriptions until the selector yields at least one server, ctx is done, or timeoutCh fires.
+func (t *Topology) selectServer(ctx context.Context, subscriptionCh <-chan description.Topology, ss description.ServerSelector, timeoutCh <-chan time.Time) ([]description.Server, error) {
+ var current description.Topology
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timeoutCh:
+ return nil, wrapServerSelectionError(ErrServerSelectionTimeout, t)
+ case current = <-subscriptionCh: // wait for the next topology description
+ }
+
+ var allowed []description.Server
+ for _, s := range current.Servers {
+ if s.Kind != description.Unknown { // servers of unknown kind are never offered to the selector
+ allowed = append(allowed, s)
+ }
+ }
+
+ suitable, err := ss.SelectServer(current, allowed)
+ if err != nil {
+ return nil, wrapServerSelectionError(err, t)
+ }
+
+ if len(suitable) > 0 {
+ return suitable, nil
+ }
+
+ t.RequestImmediateCheck() // nothing suitable yet: ask every server for a fresh check and wait for the next description
+ }
+}
+
+func (t *Topology) update() { // run as a goroutine by Connect; applies server changes and publishes the resulting topology to subscribers
+ defer t.changeswg.Done()
+ defer func() {
+ // After a recovered panic, block until a value arrives on done; presumably so Disconnect's send does not block forever (TODO confirm).
+ if r := recover(); r != nil {
+ <-t.done
+ }
+ }()
+
+ for {
+ select {
+ case change := <-t.changes:
+ current, err := t.apply(context.TODO(), change)
+ if err != nil {
+ continue // a change that cannot be applied is dropped without notifying subscribers
+ }
+
+ t.desc.Store(current)
+ t.subLock.Lock()
+ for _, ch := range t.subscribers {
+ // We drain the description if there's one in the channel
+ select {
+ case <-ch:
+ default:
+ }
+ ch <- current
+ }
+ t.subLock.Unlock()
+ case <-t.done:
+ t.subLock.Lock()
+ for id, ch := range t.subscribers {
+ close(ch)
+ delete(t.subscribers, id)
+ }
+ t.subscriptionsClosed = true // Subscribe now returns ErrSubscribeAfterClosed and Unsubscribe becomes a no-op
+ t.subLock.Unlock()
+ return
+ }
+ }
+}
+
+func (t *Topology) apply(ctx context.Context, desc description.Server) (description.Topology, error) { // feeds desc to the fsm, then adds and removes servers to match the new topology
+ var err error
+ prev := t.fsm.Topology
+
+ current, err := t.fsm.apply(desc)
+ if err != nil {
+ return description.Topology{}, err
+ }
+
+ diff := description.DiffTopology(prev, current)
+ t.serversLock.Lock()
+ if t.serversClosed {
+ t.serversLock.Unlock()
+ return description.Topology{}, nil // servers are closed: report an empty topology without an error
+ }
+
+ for _, removed := range diff.Removed {
+ if s, ok := t.servers[removed.Addr]; ok {
+ t.removeServer(ctx, removed.Addr, s)
+ }
+ }
+
+ for _, added := range diff.Added {
+ _ = t.addServer(ctx, added.Addr) // errors from adding a server are deliberately discarded here
+ }
+ t.serversLock.Unlock()
+ return current, nil
+}
+
+func (t *Topology) addServer(ctx context.Context, addr address.Address) error { // callers in this file hold serversLock; a no-op if addr is already present
+ if _, ok := t.servers[addr]; ok {
+ return nil
+ }
+
+ svr, err := ConnectServer(ctx, addr, t.cfg.serverOpts...)
+ if err != nil {
+ return err
+ }
+
+ t.servers[addr] = svr
+ var sub *ServerSubscription
+ sub, err = svr.Subscribe()
+ if err != nil {
+ return err // NOTE(review): svr stays in t.servers without a forwarding goroutine on this path
+ }
+
+ t.wg.Add(1)
+ go func() {
+ for c := range sub.C { // forward every server description to the update goroutine until sub.C is closed
+ t.changes <- c
+ }
+
+ t.wg.Done()
+ }()
+
+ return nil
+}
+
+func (t *Topology) removeServer(ctx context.Context, addr address.Address, server *Server) { // callers in this file hold serversLock
+ _ = server.Disconnect(ctx) // the disconnect error is discarded; the server is removed from the map regardless
+ delete(t.servers, addr)
+}
+
+// String implements the Stringer interface. It reports the topology kind followed by the String output of each server, one per line.
+func (t *Topology) String() string {
+ desc := t.Description()
+ str := fmt.Sprintf("Type: %s\nServers:\n", desc.Kind)
+ for _, s := range t.servers { // NOTE(review): t.servers is read here without holding serversLock; map iteration order is unspecified
+ str += s.String() + "\n"
+ }
+ return str
+}
+
+// Subscription is a subscription to updates to the description of the Topology that created this
+// Subscription.
+type Subscription struct {
+ C <-chan description.Topology // receives topology descriptions; closed by Unsubscribe or when the topology shuts down
+ t *Topology // the topology that created this subscription
+ id uint64 // key of this subscription's channel in the topology's subscribers map
+}
+
+// Unsubscribe unsubscribes this Subscription from updates and closes the
+// subscription channel. It always returns nil and is safe to call more than once.
+func (s *Subscription) Unsubscribe() error {
+ s.t.subLock.Lock()
+ defer s.t.subLock.Unlock()
+ if s.t.subscriptionsClosed {
+ return nil // the update goroutine has already closed and removed every subscriber channel
+ }
+
+ ch, ok := s.t.subscribers[s.id]
+ if !ok {
+ return nil // already unsubscribed
+ }
+
+ close(ch)
+ delete(s.t.subscribers, s.id)
+
+ return nil
+}
diff --git a/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology_options.go b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology_options.go
new file mode 100644
index 0000000..9fa98e6
--- /dev/null
+++ b/vendor/github.com/mongodb/mongo-go-driver/x/mongo/driver/topology/topology_options.go
@@ -0,0 +1,269 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+ "bytes"
+ "strings"
+ "time"
+
+ "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth"
+ "github.com/mongodb/mongo-go-driver/x/network/command"
+ "github.com/mongodb/mongo-go-driver/x/network/compressor"
+ "github.com/mongodb/mongo-go-driver/x/network/connection"
+ "github.com/mongodb/mongo-go-driver/x/network/connstring"
+)
+
+// Option is a configuration option for a topology. It mutates the config it is given and may return an error, which aborts newConfig.
+type Option func(*config) error
+
+type config struct { // settings collected from Options and consumed by New and Connect
+ mode MonitorMode
+ replicaSetName string
+ seedList []string // defaults to "localhost:27017" in newConfig
+ serverOpts []ServerOption // passed to ConnectServer for every server that is added
+ cs connstring.ConnString
+ serverSelectionTimeout time.Duration // defaults to 30 seconds in newConfig
+}
+
+func newConfig(opts ...Option) (*config, error) { // builds a config with defaults, then applies opts in order and stops at the first error
+ cfg := &config{
+ seedList: []string{"localhost:27017"},
+ serverSelectionTimeout: 30 * time.Second,
+ }
+
+ for _, opt := range opts {
+ err := opt(cfg)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return cfg, nil
+}
+
+// WithConnString configures the topology using the connection string returned by fn, translating its settings into topology, server and connection options.
+func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option {
+ return func(c *config) error {
+ cs := fn(c.cs)
+ c.cs = cs
+
+ if cs.ServerSelectionTimeoutSet {
+ c.serverSelectionTimeout = cs.ServerSelectionTimeout
+ }
+
+ var connOpts []connection.Option // collected here and attached to the server options at the end
+
+ if cs.AppName != "" {
+ connOpts = append(connOpts, connection.WithAppName(func(string) string { return cs.AppName }))
+ }
+
+ switch cs.Connect {
+ case connstring.SingleConnect:
+ c.mode = SingleMode
+ }
+
+ c.seedList = cs.Hosts // overwrites any previously configured seed list
+
+ if cs.ConnectTimeout > 0 {
+ c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) // the connect timeout doubles as the heartbeat timeout
+ connOpts = append(connOpts, connection.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
+ }
+
+ if cs.SocketTimeoutSet {
+ connOpts = append(
+ connOpts,
+ connection.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
+ connection.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
+ )
+ }
+
+ if cs.HeartbeatInterval > 0 {
+ c.serverOpts = append(c.serverOpts, WithHeartbeatInterval(func(time.Duration) time.Duration { return cs.HeartbeatInterval }))
+ }
+
+ if cs.MaxConnIdleTime > 0 {
+ connOpts = append(connOpts, connection.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime }))
+ }
+
+ if cs.MaxPoolSizeSet {
+ c.serverOpts = append(c.serverOpts, WithMaxConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
+ c.serverOpts = append(c.serverOpts, WithMaxIdleConnections(func(uint16) uint16 { return cs.MaxPoolSize })) // the idle limit is set to the same value as the pool size
+ }
+
+ if cs.ReplicaSet != "" {
+ c.replicaSetName = cs.ReplicaSet
+ }
+
+ var x509Username string // stays empty unless a client certificate is loaded below
+ if cs.SSL {
+ tlsConfig := connection.NewTLSConfig()
+
+ if cs.SSLCaFileSet {
+ err := tlsConfig.AddCACertFromFile(cs.SSLCaFile)
+ if err != nil {
+ return err
+ }
+ }
+
+ if cs.SSLInsecure {
+ tlsConfig.SetInsecure(true)
+ }
+
+ if cs.SSLClientCertificateKeyFileSet {
+ if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
+ tlsConfig.SetClientCertDecryptPassword(cs.SSLClientCertificateKeyPassword)
+ }
+ s, err := tlsConfig.AddClientCertFromFile(cs.SSLClientCertificateKeyFile)
+ if err != nil {
+ return err
+ }
+
+ // The Go x509 package gives the subject with the pairs in reverse order that we want.
+ pairs := strings.Split(s, ",")
+ b := bytes.NewBufferString("")
+
+ for i := len(pairs) - 1; i >= 0; i-- { // rebuild the subject with its comma-separated pairs reversed
+ b.WriteString(pairs[i])
+
+ if i > 0 {
+ b.WriteString(",")
+ }
+ }
+
+ x509Username = b.String()
+ }
+
+ connOpts = append(connOpts, connection.WithTLSConfig(func(*connection.TLSConfig) *connection.TLSConfig { return tlsConfig }))
+ }
+
+ if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI {
+ cred := &auth.Cred{
+ Source: "admin",
+ Username: cs.Username,
+ Password: cs.Password,
+ PasswordSet: cs.PasswordSet,
+ Props: cs.AuthMechanismProperties,
+ }
+
+ if cs.AuthSource != "" {
+ cred.Source = cs.AuthSource
+ } else {
+ switch cs.AuthMechanism {
+ case auth.MongoDBX509:
+ if cred.Username == "" {
+ cred.Username = x509Username
+ }
+ fallthrough // X509 shares the "$external" source assigned in the next case
+ case auth.GSSAPI, auth.PLAIN:
+ cred.Source = "$external"
+ default:
+ cred.Source = cs.Database // NOTE(review): replaces the "admin" default even if cs.Database is empty; confirm intended
+ }
+ }
+
+ authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred)
+ if err != nil {
+ return err
+ }
+
+ connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
+ options := &auth.HandshakeOptions{
+ AppName: cs.AppName,
+ Authenticator: authenticator,
+ Compressors: cs.Compressors,
+ }
+ if cs.AuthMechanism == "" {
+ // Required for SASL mechanism negotiation during handshake
+ options.DBUser = cred.Source + "." + cred.Username
+ }
+ return auth.Handshaker(h, options)
+ }))
+ } else {
+ // We need to add a non-auth Handshaker to the connection options
+ connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
+ return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors}
+ }))
+ }
+
+ if len(cs.Compressors) > 0 {
+ comp := make([]compressor.Compressor, 0, len(cs.Compressors))
+
+ for _, c := range cs.Compressors { // NOTE: c shadows the *config parameter inside this loop
+ switch c {
+ case "snappy":
+ comp = append(comp, compressor.CreateSnappy())
+ case "zlib":
+ zlibComp, err := compressor.CreateZlib(cs.ZlibLevel)
+ if err != nil {
+ return err
+ }
+
+ comp = append(comp, zlibComp)
+ } // names other than "snappy" and "zlib" add no compressor
+ }
+
+ connOpts = append(connOpts, connection.WithCompressors(func(compressors []compressor.Compressor) []compressor.Compressor {
+ return append(compressors, comp...)
+ }))
+
+ c.serverOpts = append(c.serverOpts, WithCompressionOptions(func(opts ...string) []string {
+ return append(opts, cs.Compressors...)
+ }))
+ }
+
+ if len(connOpts) > 0 {
+ c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
+ return append(opts, connOpts...)
+ }))
+ }
+
+ return nil
+ }
+}
+
+// WithMode configures the topology's monitor mode; fn receives the current mode and returns the one to store.
+func WithMode(fn func(MonitorMode) MonitorMode) Option {
+ return func(cfg *config) error {
+ cfg.mode = fn(cfg.mode)
+ return nil
+ }
+}
+
+// WithReplicaSetName configures the topology's default replica set name; fn receives the current name and returns the one to store.
+func WithReplicaSetName(fn func(string) string) Option {
+ return func(cfg *config) error {
+ cfg.replicaSetName = fn(cfg.replicaSetName)
+ return nil
+ }
+}
+
+// WithSeedList configures a topology's seed list; fn receives the current seeds and returns the list to store.
+func WithSeedList(fn func(...string) []string) Option {
+ return func(cfg *config) error {
+ cfg.seedList = fn(cfg.seedList...)
+ return nil
+ }
+}
+
+// WithServerOptions configures a topology's server options for when a new server
+// needs to be created; fn receives the current options and returns the list to store.
+func WithServerOptions(fn func(...ServerOption) []ServerOption) Option {
+ return func(cfg *config) error {
+ cfg.serverOpts = fn(cfg.serverOpts...)
+ return nil
+ }
+}
+
+// WithServerSelectionTimeout configures a topology's server selection timeout; fn receives the current timeout and returns the one to store.
+// A server selection timeout of 0 means there is no timeout for server selection.
+func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option {
+ return func(cfg *config) error {
+ cfg.serverSelectionTimeout = fn(cfg.serverSelectionTimeout)
+ return nil
+ }
+}