blob: 9fa98e66e50862b6f3b35dd9ba32c01f0663b47c [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package topology
8
9import (
10 "bytes"
11 "strings"
12 "time"
13
14 "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth"
15 "github.com/mongodb/mongo-go-driver/x/network/command"
16 "github.com/mongodb/mongo-go-driver/x/network/compressor"
17 "github.com/mongodb/mongo-go-driver/x/network/connection"
18 "github.com/mongodb/mongo-go-driver/x/network/connstring"
19)
20
21// Option is a configuration option for a topology.
22type Option func(*config) error
23
24type config struct {
25 mode MonitorMode
26 replicaSetName string
27 seedList []string
28 serverOpts []ServerOption
29 cs connstring.ConnString
30 serverSelectionTimeout time.Duration
31}
32
33func newConfig(opts ...Option) (*config, error) {
34 cfg := &config{
35 seedList: []string{"localhost:27017"},
36 serverSelectionTimeout: 30 * time.Second,
37 }
38
39 for _, opt := range opts {
40 err := opt(cfg)
41 if err != nil {
42 return nil, err
43 }
44 }
45
46 return cfg, nil
47}
48
49// WithConnString configures the topology using the connection string.
50func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option {
51 return func(c *config) error {
52 cs := fn(c.cs)
53 c.cs = cs
54
55 if cs.ServerSelectionTimeoutSet {
56 c.serverSelectionTimeout = cs.ServerSelectionTimeout
57 }
58
59 var connOpts []connection.Option
60
61 if cs.AppName != "" {
62 connOpts = append(connOpts, connection.WithAppName(func(string) string { return cs.AppName }))
63 }
64
65 switch cs.Connect {
66 case connstring.SingleConnect:
67 c.mode = SingleMode
68 }
69
70 c.seedList = cs.Hosts
71
72 if cs.ConnectTimeout > 0 {
73 c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
74 connOpts = append(connOpts, connection.WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
75 }
76
77 if cs.SocketTimeoutSet {
78 connOpts = append(
79 connOpts,
80 connection.WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
81 connection.WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
82 )
83 }
84
85 if cs.HeartbeatInterval > 0 {
86 c.serverOpts = append(c.serverOpts, WithHeartbeatInterval(func(time.Duration) time.Duration { return cs.HeartbeatInterval }))
87 }
88
89 if cs.MaxConnIdleTime > 0 {
90 connOpts = append(connOpts, connection.WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime }))
91 }
92
93 if cs.MaxPoolSizeSet {
94 c.serverOpts = append(c.serverOpts, WithMaxConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
95 c.serverOpts = append(c.serverOpts, WithMaxIdleConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
96 }
97
98 if cs.ReplicaSet != "" {
99 c.replicaSetName = cs.ReplicaSet
100 }
101
102 var x509Username string
103 if cs.SSL {
104 tlsConfig := connection.NewTLSConfig()
105
106 if cs.SSLCaFileSet {
107 err := tlsConfig.AddCACertFromFile(cs.SSLCaFile)
108 if err != nil {
109 return err
110 }
111 }
112
113 if cs.SSLInsecure {
114 tlsConfig.SetInsecure(true)
115 }
116
117 if cs.SSLClientCertificateKeyFileSet {
118 if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
119 tlsConfig.SetClientCertDecryptPassword(cs.SSLClientCertificateKeyPassword)
120 }
121 s, err := tlsConfig.AddClientCertFromFile(cs.SSLClientCertificateKeyFile)
122 if err != nil {
123 return err
124 }
125
126 // The Go x509 package gives the subject with the pairs in reverse order that we want.
127 pairs := strings.Split(s, ",")
128 b := bytes.NewBufferString("")
129
130 for i := len(pairs) - 1; i >= 0; i-- {
131 b.WriteString(pairs[i])
132
133 if i > 0 {
134 b.WriteString(",")
135 }
136 }
137
138 x509Username = b.String()
139 }
140
141 connOpts = append(connOpts, connection.WithTLSConfig(func(*connection.TLSConfig) *connection.TLSConfig { return tlsConfig }))
142 }
143
144 if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI {
145 cred := &auth.Cred{
146 Source: "admin",
147 Username: cs.Username,
148 Password: cs.Password,
149 PasswordSet: cs.PasswordSet,
150 Props: cs.AuthMechanismProperties,
151 }
152
153 if cs.AuthSource != "" {
154 cred.Source = cs.AuthSource
155 } else {
156 switch cs.AuthMechanism {
157 case auth.MongoDBX509:
158 if cred.Username == "" {
159 cred.Username = x509Username
160 }
161 fallthrough
162 case auth.GSSAPI, auth.PLAIN:
163 cred.Source = "$external"
164 default:
165 cred.Source = cs.Database
166 }
167 }
168
169 authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred)
170 if err != nil {
171 return err
172 }
173
174 connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
175 options := &auth.HandshakeOptions{
176 AppName: cs.AppName,
177 Authenticator: authenticator,
178 Compressors: cs.Compressors,
179 }
180 if cs.AuthMechanism == "" {
181 // Required for SASL mechanism negotiation during handshake
182 options.DBUser = cred.Source + "." + cred.Username
183 }
184 return auth.Handshaker(h, options)
185 }))
186 } else {
187 // We need to add a non-auth Handshaker to the connection options
188 connOpts = append(connOpts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
189 return &command.Handshake{Client: command.ClientDoc(cs.AppName), Compressors: cs.Compressors}
190 }))
191 }
192
193 if len(cs.Compressors) > 0 {
194 comp := make([]compressor.Compressor, 0, len(cs.Compressors))
195
196 for _, c := range cs.Compressors {
197 switch c {
198 case "snappy":
199 comp = append(comp, compressor.CreateSnappy())
200 case "zlib":
201 zlibComp, err := compressor.CreateZlib(cs.ZlibLevel)
202 if err != nil {
203 return err
204 }
205
206 comp = append(comp, zlibComp)
207 }
208 }
209
210 connOpts = append(connOpts, connection.WithCompressors(func(compressors []compressor.Compressor) []compressor.Compressor {
211 return append(compressors, comp...)
212 }))
213
214 c.serverOpts = append(c.serverOpts, WithCompressionOptions(func(opts ...string) []string {
215 return append(opts, cs.Compressors...)
216 }))
217 }
218
219 if len(connOpts) > 0 {
220 c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
221 return append(opts, connOpts...)
222 }))
223 }
224
225 return nil
226 }
227}
228
229// WithMode configures the topology's monitor mode.
230func WithMode(fn func(MonitorMode) MonitorMode) Option {
231 return func(cfg *config) error {
232 cfg.mode = fn(cfg.mode)
233 return nil
234 }
235}
236
237// WithReplicaSetName configures the topology's default replica set name.
238func WithReplicaSetName(fn func(string) string) Option {
239 return func(cfg *config) error {
240 cfg.replicaSetName = fn(cfg.replicaSetName)
241 return nil
242 }
243}
244
245// WithSeedList configures a topology's seed list.
246func WithSeedList(fn func(...string) []string) Option {
247 return func(cfg *config) error {
248 cfg.seedList = fn(cfg.seedList...)
249 return nil
250 }
251}
252
253// WithServerOptions configures a topology's server options for when a new server
254// needs to be created.
255func WithServerOptions(fn func(...ServerOption) []ServerOption) Option {
256 return func(cfg *config) error {
257 cfg.serverOpts = fn(cfg.serverOpts...)
258 return nil
259 }
260}
261
262// WithServerSelectionTimeout configures a topology's server selection timeout.
263// A server selection timeout of 0 means there is no timeout for server selection.
264func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option {
265 return func(cfg *config) error {
266 cfg.serverSelectionTimeout = fn(cfg.serverSelectionTimeout)
267 return nil
268 }
269}