Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package topology |
| 8 | |
| 9 | import ( |
| 10 | "bytes" |
| 11 | "strings" |
| 12 | "time" |
| 13 | |
| 14 | "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth" |
| 15 | "github.com/mongodb/mongo-go-driver/x/network/command" |
| 16 | "github.com/mongodb/mongo-go-driver/x/network/compressor" |
| 17 | "github.com/mongodb/mongo-go-driver/x/network/connection" |
| 18 | "github.com/mongodb/mongo-go-driver/x/network/connstring" |
| 19 | ) |
| 20 | |
| 21 | // Option is a configuration option for a topology. |
| 22 | type Option func(*config) error |
| 23 | |
| 24 | type config struct { |
| 25 | mode MonitorMode |
| 26 | replicaSetName string |
| 27 | seedList []string |
| 28 | serverOpts []ServerOption |
| 29 | cs connstring.ConnString |
| 30 | serverSelectionTimeout time.Duration |
| 31 | } |
| 32 | |
| 33 | func newConfig(opts ...Option) (*config, error) { |
| 34 | cfg := &config{ |
| 35 | seedList: []string{"localhost:27017"}, |
| 36 | serverSelectionTimeout: 30 * time.Second, |
| 37 | } |
| 38 | |
| 39 | for _, opt := range opts { |
| 40 | err := opt(cfg) |
| 41 | if err != nil { |
| 42 | return nil, err |
| 43 | } |
| 44 | } |
| 45 | |
| 46 | return cfg, nil |
| 47 | } |
| 48 | |
| 49 | // WithConnString configures the topology using the connection string. |
| 50 | func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option { |
| 51 | return func(c *config) error { |
| 52 | cs := fn(c.cs) |
| 53 | c.cs = cs |
| 54 | |
| 55 | if cs.ServerSelectionTimeoutSet { |
| 56 | c.serverSelectionTimeout = cs.ServerSelectionTimeout |
| 57 | } |
| 58 | |
| 59 | var connOpts []connection.Option |
| 60 | |
| 61 | if cs.AppName != "" { |
| 62 | connOpts = append(connOpts, connection.WithAppName(func(string) string { return cs.AppName })) |
| 63 | } |
| 64 | |
| 65 | switch cs.Connect { |
| 66 | case connstring.SingleConnect: |
| 67 | c.mode = SingleMode |
| 68 | } |
| 69 | |
| 70 | c.seedList = cs.Hosts |
| 71 | |
| 72 | if cs.ConnectTimeout > 0 { |
| 73 | c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) |
| 74 | connOpts = append(connOpts, connection.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout })) |
| 75 | } |
| 76 | |
| 77 | if cs.SocketTimeoutSet { |
| 78 | connOpts = append( |
| 79 | connOpts, |
| 80 | connection.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), |
| 81 | connection.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }), |
| 82 | ) |
| 83 | } |
| 84 | |
| 85 | if cs.HeartbeatInterval > 0 { |
| 86 | c.serverOpts = append(c.serverOpts, WithHeartbeatInterval(func(time.Duration) time.Duration { return cs.HeartbeatInterval })) |
| 87 | } |
| 88 | |
| 89 | if cs.MaxConnIdleTime > 0 { |
| 90 | connOpts = append(connOpts, connection.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime })) |
| 91 | } |
| 92 | |
| 93 | if cs.MaxPoolSizeSet { |
| 94 | c.serverOpts = append(c.serverOpts, WithMaxConnections(func(uint16) uint16 { return cs.MaxPoolSize })) |
| 95 | c.serverOpts = append(c.serverOpts, WithMaxIdleConnections(func(uint16) uint16 { return cs.MaxPoolSize })) |
| 96 | } |
| 97 | |
| 98 | if cs.ReplicaSet != "" { |
| 99 | c.replicaSetName = cs.ReplicaSet |
| 100 | } |
| 101 | |
| 102 | var x509Username string |
| 103 | if cs.SSL { |
| 104 | tlsConfig := connection.NewTLSConfig() |
| 105 | |
| 106 | if cs.SSLCaFileSet { |
| 107 | err := tlsConfig.AddCACertFromFile(cs.SSLCaFile) |
| 108 | if err != nil { |
| 109 | return err |
| 110 | } |
| 111 | } |
| 112 | |
| 113 | if cs.SSLInsecure { |
| 114 | tlsConfig.SetInsecure(true) |
| 115 | } |
| 116 | |
| 117 | if cs.SSLClientCertificateKeyFileSet { |
| 118 | if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil { |
| 119 | tlsConfig.SetClientCertDecryptPassword(cs.SSLClientCertificateKeyPassword) |
| 120 | } |
| 121 | s, err := tlsConfig.AddClientCertFromFile(cs.SSLClientCertificateKeyFile) |
| 122 | if err != nil { |
| 123 | return err |
| 124 | } |
| 125 | |
| 126 | // The Go x509 package gives the subject with the pairs in reverse order that we want. |
| 127 | pairs := strings.Split(s, ",") |
| 128 | b := bytes.NewBufferString("") |
| 129 | |
| 130 | for i := len(pairs) - 1; i >= 0; i-- { |
| 131 | b.WriteString(pairs[i]) |
| 132 | |
| 133 | if i > 0 { |
| 134 | b.WriteString(",") |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | x509Username = b.String() |
| 139 | } |
| 140 | |
| 141 | connOpts = append(connOpts, connection.WithTLSConfig(func(*connection.TLSConfig) *connection.TLSConfig { return tlsConfig })) |
| 142 | } |
| 143 | |
| 144 | if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI { |
| 145 | cred := &auth.Cred{ |
| 146 | Source: "admin", |
| 147 | Username: cs.Username, |
| 148 | Password: cs.Password, |
| 149 | PasswordSet: cs.PasswordSet, |
| 150 | Props: cs.AuthMechanismProperties, |
| 151 | } |
| 152 | |
| 153 | if cs.AuthSource != "" { |
| 154 | cred.Source = cs.AuthSource |
| 155 | } else { |
| 156 | switch cs.AuthMechanism { |
| 157 | case auth.MongoDBX509: |
| 158 | if cred.Username == "" { |
| 159 | cred.Username = x509Username |
| 160 | } |
| 161 | fallthrough |
| 162 | case auth.GSSAPI, auth.PLAIN: |
| 163 | cred.Source = "$external" |
| 164 | default: |
| 165 | cred.Source = cs.Database |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred) |
| 170 | if err != nil { |
| 171 | return err |
| 172 | } |
| 173 | |
| 174 | connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { |
| 175 | options := &auth.HandshakeOptions{ |
| 176 | AppName: cs.AppName, |
| 177 | Authenticator: authenticator, |
| 178 | Compressors: cs.Compressors, |
| 179 | } |
| 180 | if cs.AuthMechanism == "" { |
| 181 | // Required for SASL mechanism negotiation during handshake |
| 182 | options.DBUser = cred.Source + "." + cred.Username |
| 183 | } |
| 184 | return auth.Handshaker(h, options) |
| 185 | })) |
| 186 | } else { |
| 187 | // We need to add a non-auth Handshaker to the connection options |
| 188 | connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker { |
| 189 | return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors} |
| 190 | })) |
| 191 | } |
| 192 | |
| 193 | if len(cs.Compressors) > 0 { |
| 194 | comp := make([]compressor.Compressor, 0, len(cs.Compressors)) |
| 195 | |
| 196 | for _, c := range cs.Compressors { |
| 197 | switch c { |
| 198 | case "snappy": |
| 199 | comp = append(comp, compressor.CreateSnappy()) |
| 200 | case "zlib": |
| 201 | zlibComp, err := compressor.CreateZlib(cs.ZlibLevel) |
| 202 | if err != nil { |
| 203 | return err |
| 204 | } |
| 205 | |
| 206 | comp = append(comp, zlibComp) |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | connOpts = append(connOpts, connection.WithCompressors(func(compressors []compressor.Compressor) []compressor.Compressor { |
| 211 | return append(compressors, comp...) |
| 212 | })) |
| 213 | |
| 214 | c.serverOpts = append(c.serverOpts, WithCompressionOptions(func(opts ...string) []string { |
| 215 | return append(opts, cs.Compressors...) |
| 216 | })) |
| 217 | } |
| 218 | |
| 219 | if len(connOpts) > 0 { |
| 220 | c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connection.Option) []connection.Option { |
| 221 | return append(opts, connOpts...) |
| 222 | })) |
| 223 | } |
| 224 | |
| 225 | return nil |
| 226 | } |
| 227 | } |
| 228 | |
| 229 | // WithMode configures the topology's monitor mode. |
| 230 | func WithMode(fn func(MonitorMode) MonitorMode) Option { |
| 231 | return func(cfg *config) error { |
| 232 | cfg.mode = fn(cfg.mode) |
| 233 | return nil |
| 234 | } |
| 235 | } |
| 236 | |
| 237 | // WithReplicaSetName configures the topology's default replica set name. |
| 238 | func WithReplicaSetName(fn func(string) string) Option { |
| 239 | return func(cfg *config) error { |
| 240 | cfg.replicaSetName = fn(cfg.replicaSetName) |
| 241 | return nil |
| 242 | } |
| 243 | } |
| 244 | |
| 245 | // WithSeedList configures a topology's seed list. |
| 246 | func WithSeedList(fn func(...string) []string) Option { |
| 247 | return func(cfg *config) error { |
| 248 | cfg.seedList = fn(cfg.seedList...) |
| 249 | return nil |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | // WithServerOptions configures a topology's server options for when a new server |
| 254 | // needs to be created. |
| 255 | func WithServerOptions(fn func(...ServerOption) []ServerOption) Option { |
| 256 | return func(cfg *config) error { |
| 257 | cfg.serverOpts = fn(cfg.serverOpts...) |
| 258 | return nil |
| 259 | } |
| 260 | } |
| 261 | |
| 262 | // WithServerSelectionTimeout configures a topology's server selection timeout. |
| 263 | // A server selection timeout of 0 means there is no timeout for server selection. |
| 264 | func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option { |
| 265 | return func(cfg *config) error { |
| 266 | cfg.serverSelectionTimeout = fn(cfg.serverSelectionTimeout) |
| 267 | return nil |
| 268 | } |
| 269 | } |