blob: 3a25ad379c295a70f3ff53c81bb5cba95c3e45e4 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package connection
8
9import (
10 "context"
11 "sync"
12 "sync/atomic"
13
14 "github.com/mongodb/mongo-go-driver/x/network/address"
15 "github.com/mongodb/mongo-go-driver/x/network/description"
16 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
17 "golang.org/x/sync/semaphore"
18)
19
20// ErrPoolClosed is returned from an attempt to use a closed pool.
21var ErrPoolClosed = PoolError("pool is closed")
22
23// ErrSizeLargerThanCapacity is returned from an attempt to create a pool with a size
24// larger than the capacity.
25var ErrSizeLargerThanCapacity = PoolError("size is larger than capacity")
26
27// ErrPoolConnected is returned from an attempt to connect an already connected pool
28var ErrPoolConnected = PoolError("pool is connected")
29
30// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected
31// or disconnecting pool.
32var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting")
33
34// ErrConnectionClosed is returned from an attempt to use an already closed connection.
35var ErrConnectionClosed = Error{ConnectionID: "<closed>", message: "connection is closed"}
36
// Connection states of a pool. They are stored in pool.connected and are read
// and updated with the sync/atomic int32 functions.
const (
	disconnected  int32 = 0
	disconnecting int32 = 1
	connected     int32 = 2
)
43
44// Pool is used to pool Connections to a server.
45type Pool interface {
46 // Get must return a nil *description.Server if the returned connection is
47 // not a newly dialed connection.
48 Get(context.Context) (Connection, *description.Server, error)
49 // Connect handles the initialization of a Pool and allow Connections to be
50 // retrieved and pooled. Implementations must return an error if Connect is
51 // called more than once before calling Disconnect.
52 Connect(context.Context) error
53 // Disconnect closest connections managed by this Pool. Implementations must
54 // either wait until all of the connections in use have been returned and
55 // closed or the context expires before returning. If the context expires
56 // via cancellation, deadline, timeout, or some other manner, implementations
57 // must close the in use connections. If this method returns with no errors,
58 // all connections managed by this pool must be closed. Calling Disconnect
59 // multiple times after a single Connect call must result in an error.
60 Disconnect(context.Context) error
61 Drain() error
62}
63
64type pool struct {
65 address address.Address
66 opts []Option
67 conns chan *pooledConnection
68 generation uint64
69 sem *semaphore.Weighted
70 connected int32
71 nextid uint64
72 capacity uint64
73 inflight map[uint64]*pooledConnection
74
75 sync.Mutex
76}
77
78// NewPool creates a new pool that will hold size number of idle connections
79// and will create a max of capacity connections. It will use the provided
80// options.
81func NewPool(addr address.Address, size, capacity uint64, opts ...Option) (Pool, error) {
82 if size > capacity {
83 return nil, ErrSizeLargerThanCapacity
84 }
85 p := &pool{
86 address: addr,
87 conns: make(chan *pooledConnection, size),
88 generation: 0,
89 sem: semaphore.NewWeighted(int64(capacity)),
90 connected: disconnected,
91 capacity: capacity,
92 inflight: make(map[uint64]*pooledConnection),
93 opts: opts,
94 }
95 return p, nil
96}
97
98func (p *pool) Drain() error {
99 atomic.AddUint64(&p.generation, 1)
100 return nil
101}
102
103func (p *pool) Connect(ctx context.Context) error {
104 if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
105 return ErrPoolConnected
106 }
107 atomic.AddUint64(&p.generation, 1)
108 return nil
109}
110
111func (p *pool) Disconnect(ctx context.Context) error {
112 if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
113 return ErrPoolDisconnected
114 }
115
116 // We first clear out the idle connections, then we attempt to acquire the entire capacity
117 // semaphore. If the context is either cancelled, the deadline expires, or there is a timeout
118 // the semaphore acquire method will return an error. If that happens, we will aggressively
119 // close the remaining open connections. If we were able to successfully acquire the semaphore,
120 // then all of the in flight connections have been closed and we release the semaphore.
121loop:
122 for {
123 select {
124 case pc := <-p.conns:
125 // This error would be overwritten by the semaphore
126 _ = p.closeConnection(pc)
127 default:
128 break loop
129 }
130 }
131 err := p.sem.Acquire(ctx, int64(p.capacity))
132 if err != nil {
133 p.Lock()
134 // We copy the remaining connections to close into a slice, then
135 // iterate the slice to do the closing. This allows us to use a single
136 // function to actually clean up and close connections at the expense of
137 // a double iteration in the worst case.
138 toClose := make([]*pooledConnection, 0, len(p.inflight))
139 for _, pc := range p.inflight {
140 toClose = append(toClose, pc)
141 }
142 p.Unlock()
143 for _, pc := range toClose {
144 _ = pc.Close()
145 }
146 } else {
147 p.sem.Release(int64(p.capacity))
148 }
149 atomic.StoreInt32(&p.connected, disconnected)
150 return nil
151}
152
153func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error) {
154 if atomic.LoadInt32(&p.connected) != connected {
155 return nil, nil, ErrPoolClosed
156 }
157
158 err := p.sem.Acquire(ctx, 1)
159 if err != nil {
160 return nil, nil, err
161 }
162
163 return p.get(ctx)
164}
165
166func (p *pool) get(ctx context.Context) (Connection, *description.Server, error) {
167 g := atomic.LoadUint64(&p.generation)
168 select {
169 case c := <-p.conns:
170 if c.Expired() {
171 go p.closeConnection(c)
172 return p.get(ctx)
173 }
174
175 return &acquired{Connection: c, sem: p.sem}, nil, nil
176 case <-ctx.Done():
177 p.sem.Release(1)
178 return nil, nil, ctx.Err()
179 default:
180 c, desc, err := New(ctx, p.address, p.opts...)
181 if err != nil {
182 p.sem.Release(1)
183 return nil, nil, err
184 }
185
186 pc := &pooledConnection{
187 Connection: c,
188 p: p,
189 generation: g,
190 id: atomic.AddUint64(&p.nextid, 1),
191 }
192 p.Lock()
193 if atomic.LoadInt32(&p.connected) != connected {
194 p.Unlock()
195 p.sem.Release(1)
196 p.closeConnection(pc)
197 return nil, nil, ErrPoolClosed
198 }
199 defer p.Unlock()
200 p.inflight[pc.id] = pc
201 return &acquired{Connection: pc, sem: p.sem}, desc, nil
202 }
203}
204
205func (p *pool) closeConnection(pc *pooledConnection) error {
206 if !atomic.CompareAndSwapInt32(&pc.closed, 0, 1) {
207 return nil
208 }
209 p.Lock()
210 delete(p.inflight, pc.id)
211 p.Unlock()
212 return pc.Connection.Close()
213}
214
215func (p *pool) returnConnection(pc *pooledConnection) error {
216 if atomic.LoadInt32(&p.connected) != connected || pc.Expired() {
217 return p.closeConnection(pc)
218 }
219
220 select {
221 case p.conns <- pc:
222 return nil
223 default:
224 return p.closeConnection(pc)
225 }
226}
227
228func (p *pool) isExpired(generation uint64) bool {
229 return generation < atomic.LoadUint64(&p.generation)
230}
231
232type pooledConnection struct {
233 Connection
234 p *pool
235 generation uint64
236 id uint64
237 closed int32
238}
239
240func (pc *pooledConnection) Close() error {
241 return pc.p.returnConnection(pc)
242}
243
244func (pc *pooledConnection) Expired() bool {
245 return pc.Connection.Expired() || pc.p.isExpired(pc.generation)
246}
247
248type acquired struct {
249 Connection
250
251 sem *semaphore.Weighted
252 sync.Mutex
253}
254
255func (a *acquired) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
256 a.Lock()
257 defer a.Unlock()
258 if a.Connection == nil {
259 return ErrConnectionClosed
260 }
261 return a.Connection.WriteWireMessage(ctx, wm)
262}
263
264func (a *acquired) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
265 a.Lock()
266 defer a.Unlock()
267 if a.Connection == nil {
268 return nil, ErrConnectionClosed
269 }
270 return a.Connection.ReadWireMessage(ctx)
271}
272
273func (a *acquired) Close() error {
274 a.Lock()
275 defer a.Unlock()
276 if a.Connection == nil {
277 return nil
278 }
279 err := a.Connection.Close()
280 a.sem.Release(1)
281 a.Connection = nil
282 return err
283}
284
285func (a *acquired) Expired() bool {
286 a.Lock()
287 defer a.Unlock()
288 if a.Connection == nil {
289 return true
290 }
291 return a.Connection.Expired()
292}
293
294func (a *acquired) Alive() bool {
295 a.Lock()
296 defer a.Unlock()
297 if a.Connection == nil {
298 return false
299 }
300 return a.Connection.Alive()
301}
302
303func (a *acquired) ID() string {
304 a.Lock()
305 defer a.Unlock()
306 if a.Connection == nil {
307 return "<closed>"
308 }
309 return a.Connection.ID()
310}