blob: 3a7ace2d5ae92b92058a367340d9b689c4539b09 [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package topology
8
9import (
10 "context"
11 "errors"
12 "fmt"
13 "math"
14 "sync"
15 "sync/atomic"
16 "time"
17
18 "github.com/mongodb/mongo-go-driver/event"
19 "github.com/mongodb/mongo-go-driver/x/mongo/driver/auth"
20 "github.com/mongodb/mongo-go-driver/x/network/address"
21 "github.com/mongodb/mongo-go-driver/x/network/command"
22 "github.com/mongodb/mongo-go-driver/x/network/connection"
23 "github.com/mongodb/mongo-go-driver/x/network/description"
24)
25
const (
	// minHeartbeatInterval is the period of the rate-limiting ticker that the
	// monitoring goroutine waits on before issuing each follow-up heartbeat.
	minHeartbeatInterval = 500 * time.Millisecond

	// connectionSemaphoreSize equals math.MaxInt64. It is not referenced by
	// the code in this file; presumably it is used elsewhere in the package.
	connectionSemaphoreSize = math.MaxInt64
)

var (
	// ErrServerClosed occurs when an attempt to get a connection is made
	// after the server has been closed.
	ErrServerClosed = errors.New("server is closed")

	// ErrServerConnected occurs when an attempt to connect is made after a
	// server has already been connected.
	ErrServerConnected = errors.New("server is connected")
)
36
37// SelectedServer represents a specific server that was selected during server selection.
38// It contains the kind of the typology it was selected from.
39type SelectedServer struct {
40 *Server
41
42 Kind description.TopologyKind
43}
44
45// Description returns a description of the server as of the last heartbeat.
46func (ss *SelectedServer) Description() description.SelectedServer {
47 sdesc := ss.Server.Description()
48 return description.SelectedServer{
49 Server: sdesc,
50 Kind: ss.Kind,
51 }
52}
53
// These constants represent the connection states of a server. The zero value
// is disconnected, so a freshly allocated Server starts out disconnected.
const (
	disconnected int32 = iota
	disconnecting
	connected
	connecting
)

// connectionStateString returns the human-readable name of a connection
// state. It returns the empty string for any value that is not one of the
// connection state constants.
func connectionStateString(state int32) string {
	// Switch on the named constants rather than raw integers so the names
	// cannot drift out of sync with the declaration order above.
	switch state {
	case disconnected:
		return "Disconnected"
	case disconnecting:
		return "Disconnecting"
	case connected:
		return "Connected"
	case connecting:
		return "Connecting"
	}

	return ""
}
76
77// Server is a single server within a topology.
78type Server struct {
79 cfg *serverConfig
80 address address.Address
81
82 connectionstate int32
83 done chan struct{}
84 checkNow chan struct{}
85 closewg sync.WaitGroup
86 pool connection.Pool
87
88 desc atomic.Value // holds a description.Server
89
90 averageRTTSet bool
91 averageRTT time.Duration
92
93 subLock sync.Mutex
94 subscribers map[uint64]chan description.Server
95 currentSubscriberID uint64
96
97 subscriptionsClosed bool
98}
99
100// ConnectServer creates a new Server and then initializes it using the
101// Connect method.
102func ConnectServer(ctx context.Context, addr address.Address, opts ...ServerOption) (*Server, error) {
103 srvr, err := NewServer(addr, opts...)
104 if err != nil {
105 return nil, err
106 }
107 err = srvr.Connect(ctx)
108 if err != nil {
109 return nil, err
110 }
111 return srvr, nil
112}
113
114// NewServer creates a new server. The mongodb server at the address will be monitored
115// on an internal monitoring goroutine.
116func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
117 cfg, err := newServerConfig(opts...)
118 if err != nil {
119 return nil, err
120 }
121
122 s := &Server{
123 cfg: cfg,
124 address: addr,
125
126 done: make(chan struct{}),
127 checkNow: make(chan struct{}, 1),
128
129 subscribers: make(map[uint64]chan description.Server),
130 }
131 s.desc.Store(description.Server{Addr: addr})
132
133 var maxConns uint64
134 if cfg.maxConns == 0 {
135 maxConns = math.MaxInt64
136 } else {
137 maxConns = uint64(cfg.maxConns)
138 }
139
140 s.pool, err = connection.NewPool(addr, uint64(cfg.maxIdleConns), maxConns, cfg.connectionOpts...)
141 if err != nil {
142 return nil, err
143 }
144
145 return s, nil
146}
147
148// Connect initialzies the Server by starting background monitoring goroutines.
149// This method must be called before a Server can be used.
150func (s *Server) Connect(ctx context.Context) error {
151 if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
152 return ErrServerConnected
153 }
154 s.desc.Store(description.Server{Addr: s.address})
155 go s.update()
156 s.closewg.Add(1)
157 return s.pool.Connect(ctx)
158}
159
160// Disconnect closes sockets to the server referenced by this Server.
161// Subscriptions to this Server will be closed. Disconnect will shutdown
162// any monitoring goroutines, close the idle connection pool, and will
163// wait until all the in use connections have been returned to the connection
164// pool and are closed before returning. If the context expires via
165// cancellation, deadline, or timeout before the in use connections have been
166// returned, the in use connections will be closed, resulting in the failure of
167// any in flight read or write operations. If this method returns with no
168// errors, all connections associated with this Server have been closed.
169func (s *Server) Disconnect(ctx context.Context) error {
170 if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) {
171 return ErrServerClosed
172 }
173
174 // For every call to Connect there must be at least 1 goroutine that is
175 // waiting on the done channel.
176 s.done <- struct{}{}
177 err := s.pool.Disconnect(ctx)
178 if err != nil {
179 return err
180 }
181
182 s.closewg.Wait()
183 atomic.StoreInt32(&s.connectionstate, disconnected)
184
185 return nil
186}
187
188// Connection gets a connection to the server.
189func (s *Server) Connection(ctx context.Context) (connection.Connection, error) {
190 if atomic.LoadInt32(&s.connectionstate) != connected {
191 return nil, ErrServerClosed
192 }
193 conn, desc, err := s.pool.Get(ctx)
194 if err != nil {
195 if _, ok := err.(*auth.Error); ok {
196 // authentication error --> drain connection
197 _ = s.pool.Drain()
198 }
199 if _, ok := err.(*connection.NetworkError); ok {
200 // update description to unknown and clears the connection pool
201 if desc != nil {
202 desc.Kind = description.Unknown
203 desc.LastError = err
204 s.updateDescription(*desc, false)
205 } else {
206 _ = s.pool.Drain()
207 }
208 }
209 return nil, err
210 }
211 if desc != nil {
212 go s.updateDescription(*desc, false)
213 }
214 sc := &sconn{Connection: conn, s: s}
215 return sc, nil
216}
217
218// Description returns a description of the server as of the last heartbeat.
219func (s *Server) Description() description.Server {
220 return s.desc.Load().(description.Server)
221}
222
223// SelectedDescription returns a description.SelectedServer with a Kind of
224// Single. This can be used when performing tasks like monitoring a batch
225// of servers and you want to run one off commands against those servers.
226func (s *Server) SelectedDescription() description.SelectedServer {
227 sdesc := s.Description()
228 return description.SelectedServer{
229 Server: sdesc,
230 Kind: description.Single,
231 }
232}
233
234// Subscribe returns a ServerSubscription which has a channel on which all
235// updated server descriptions will be sent. The channel will have a buffer
236// size of one, and will be pre-populated with the current description.
237func (s *Server) Subscribe() (*ServerSubscription, error) {
238 if atomic.LoadInt32(&s.connectionstate) != connected {
239 return nil, ErrSubscribeAfterClosed
240 }
241 ch := make(chan description.Server, 1)
242 ch <- s.desc.Load().(description.Server)
243
244 s.subLock.Lock()
245 defer s.subLock.Unlock()
246 if s.subscriptionsClosed {
247 return nil, ErrSubscribeAfterClosed
248 }
249 id := s.currentSubscriberID
250 s.subscribers[id] = ch
251 s.currentSubscriberID++
252
253 ss := &ServerSubscription{
254 C: ch,
255 s: s,
256 id: id,
257 }
258
259 return ss, nil
260}
261
262// RequestImmediateCheck will cause the server to send a heartbeat immediately
263// instead of waiting for the heartbeat timeout.
264func (s *Server) RequestImmediateCheck() {
265 select {
266 case s.checkNow <- struct{}{}:
267 default:
268 }
269}
270
271// update handles performing heartbeats and updating any subscribers of the
272// newest description.Server retrieved.
273func (s *Server) update() {
274 defer s.closewg.Done()
275 heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
276 rateLimiter := time.NewTicker(minHeartbeatInterval)
277 defer heartbeatTicker.Stop()
278 defer rateLimiter.Stop()
279 checkNow := s.checkNow
280 done := s.done
281
282 var doneOnce bool
283 defer func() {
284 if r := recover(); r != nil {
285 if doneOnce {
286 return
287 }
288 // We keep this goroutine alive attempting to read from the done channel.
289 <-done
290 }
291 }()
292
293 var conn connection.Connection
294 var desc description.Server
295
296 desc, conn = s.heartbeat(nil)
297 s.updateDescription(desc, true)
298
299 closeServer := func() {
300 doneOnce = true
301 s.subLock.Lock()
302 for id, c := range s.subscribers {
303 close(c)
304 delete(s.subscribers, id)
305 }
306 s.subscriptionsClosed = true
307 s.subLock.Unlock()
308 if conn == nil {
309 return
310 }
311 conn.Close()
312 }
313 for {
314 select {
315 case <-heartbeatTicker.C:
316 case <-checkNow:
317 case <-done:
318 closeServer()
319 return
320 }
321
322 select {
323 case <-rateLimiter.C:
324 case <-done:
325 closeServer()
326 return
327 }
328
329 desc, conn = s.heartbeat(conn)
330 s.updateDescription(desc, false)
331 }
332}
333
334// updateDescription handles updating the description on the Server, notifying
335// subscribers, and potentially draining the connection pool. The initial
336// parameter is used to determine if this is the first description from the
337// server.
338func (s *Server) updateDescription(desc description.Server, initial bool) {
339 defer func() {
340 // ¯\_(ツ)_/¯
341 _ = recover()
342 }()
343 s.desc.Store(desc)
344
345 s.subLock.Lock()
346 for _, c := range s.subscribers {
347 select {
348 // drain the channel if it isn't empty
349 case <-c:
350 default:
351 }
352 c <- desc
353 }
354 s.subLock.Unlock()
355
356 if initial {
357 // We don't clear the pool on the first update on the description.
358 return
359 }
360
361 switch desc.Kind {
362 case description.Unknown:
363 _ = s.pool.Drain()
364 }
365}
366
367// heartbeat sends a heartbeat to the server using the given connection. The connection can be nil.
368func (s *Server) heartbeat(conn connection.Connection) (description.Server, connection.Connection) {
369 const maxRetry = 2
370 var saved error
371 var desc description.Server
372 var set bool
373 var err error
374 ctx := context.Background()
375
376 for i := 1; i <= maxRetry; i++ {
377 if conn != nil && conn.Expired() {
378 conn.Close()
379 conn = nil
380 }
381
382 if conn == nil {
383 opts := []connection.Option{
384 connection.WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
385 connection.WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
386 connection.WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
387 }
388 opts = append(opts, s.cfg.connectionOpts...)
389 // We override whatever handshaker is currently attached to the options with an empty
390 // one because need to make sure we don't do auth.
391 opts = append(opts, connection.WithHandshaker(func(h connection.Handshaker) connection.Handshaker {
392 return nil
393 }))
394
395 // Override any command monitors specified in options with nil to avoid monitoring heartbeats.
396 opts = append(opts, connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
397 return nil
398 }))
399 conn, _, err = connection.New(ctx, s.address, opts...)
400 if err != nil {
401 saved = err
402 if conn != nil {
403 conn.Close()
404 }
405 conn = nil
406 continue
407 }
408 }
409
410 now := time.Now()
411
412 isMasterCmd := &command.IsMaster{Compressors: s.cfg.compressionOpts}
413 isMaster, err := isMasterCmd.RoundTrip(ctx, conn)
414 if err != nil {
415 saved = err
416 conn.Close()
417 conn = nil
418 continue
419 }
420
421 clusterTime := isMaster.ClusterTime
422 if s.cfg.clock != nil {
423 s.cfg.clock.AdvanceClusterTime(clusterTime)
424 }
425
426 delay := time.Since(now)
427 desc = description.NewServer(s.address, isMaster).SetAverageRTT(s.updateAverageRTT(delay))
428 desc.HeartbeatInterval = s.cfg.heartbeatInterval
429 set = true
430
431 break
432 }
433
434 if !set {
435 desc = description.Server{
436 Addr: s.address,
437 LastError: saved,
438 }
439 }
440
441 return desc, conn
442}
443
444func (s *Server) updateAverageRTT(delay time.Duration) time.Duration {
445 if !s.averageRTTSet {
446 s.averageRTT = delay
447 } else {
448 alpha := 0.2
449 s.averageRTT = time.Duration(alpha*float64(delay) + (1-alpha)*float64(s.averageRTT))
450 }
451 return s.averageRTT
452}
453
454// Drain will drain the connection pool of this server. This is mainly here so the
455// pool for the server doesn't need to be directly exposed and so that when an error
456// is returned from reading or writing, a client can drain the pool for this server.
457// This is exposed here so we don't have to wrap the Connection type and sniff responses
458// for errors that would cause the pool to be drained, which can in turn centralize the
459// logic for handling errors in the Client type.
460func (s *Server) Drain() error { return s.pool.Drain() }
461
462// String implements the Stringer interface.
463func (s *Server) String() string {
464 desc := s.Description()
465 str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
466 s.address, desc.Kind, connectionStateString(s.connectionstate))
467 if len(desc.Tags) != 0 {
468 str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
469 }
470 if s.connectionstate == connected {
471 str += fmt.Sprintf(", Avergage RTT: %d", s.averageRTT)
472 }
473 if desc.LastError != nil {
474 str += fmt.Sprintf(", Last error: %s", desc.LastError)
475 }
476
477 return str
478}
479
480// ServerSubscription represents a subscription to the description.Server updates for
481// a specific server.
482type ServerSubscription struct {
483 C <-chan description.Server
484 s *Server
485 id uint64
486}
487
488// Unsubscribe unsubscribes this ServerSubscription from updates and closes the
489// subscription channel.
490func (ss *ServerSubscription) Unsubscribe() error {
491 ss.s.subLock.Lock()
492 defer ss.s.subLock.Unlock()
493 if ss.s.subscriptionsClosed {
494 return nil
495 }
496
497 ch, ok := ss.s.subscribers[ss.id]
498 if !ok {
499 return nil
500 }
501
502 close(ch)
503 delete(ss.s.subscribers, ss.id)
504
505 return nil
506}