Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package connection |
| 8 | |
| 9 | import ( |
| 10 | "context" |
| 11 | "sync" |
| 12 | "sync/atomic" |
| 13 | |
| 14 | "github.com/mongodb/mongo-go-driver/x/network/address" |
| 15 | "github.com/mongodb/mongo-go-driver/x/network/description" |
| 16 | "github.com/mongodb/mongo-go-driver/x/network/wiremessage" |
| 17 | "golang.org/x/sync/semaphore" |
| 18 | ) |
| 19 | |
| 20 | // ErrPoolClosed is returned from an attempt to use a closed pool. |
| 21 | var ErrPoolClosed = PoolError("pool is closed") |
| 22 | |
| 23 | // ErrSizeLargerThanCapacity is returned from an attempt to create a pool with a size |
| 24 | // larger than the capacity. |
| 25 | var ErrSizeLargerThanCapacity = PoolError("size is larger than capacity") |
| 26 | |
| 27 | // ErrPoolConnected is returned from an attempt to connect an already connected pool |
| 28 | var ErrPoolConnected = PoolError("pool is connected") |
| 29 | |
| 30 | // ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected |
| 31 | // or disconnecting pool. |
| 32 | var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting") |
| 33 | |
| 34 | // ErrConnectionClosed is returned from an attempt to use an already closed connection. |
| 35 | var ErrConnectionClosed = Error{ConnectionID: "<closed>", message: "connection is closed"} |
| 36 | |
// Connection states of a pool, stored in pool.connected and accessed atomically.
// Connect moves disconnected -> connected; Disconnect moves connected -> disconnecting
// while it tears down connections, then stores disconnected.
const (
	disconnected  int32 = 0
	disconnecting int32 = 1
	connected     int32 = 2
)
| 43 | |
| 44 | // Pool is used to pool Connections to a server. |
| 45 | type Pool interface { |
| 46 | // Get must return a nil *description.Server if the returned connection is |
| 47 | // not a newly dialed connection. |
| 48 | Get(context.Context) (Connection, *description.Server, error) |
| 49 | // Connect handles the initialization of a Pool and allow Connections to be |
| 50 | // retrieved and pooled. Implementations must return an error if Connect is |
| 51 | // called more than once before calling Disconnect. |
| 52 | Connect(context.Context) error |
| 53 | // Disconnect closest connections managed by this Pool. Implementations must |
| 54 | // either wait until all of the connections in use have been returned and |
| 55 | // closed or the context expires before returning. If the context expires |
| 56 | // via cancellation, deadline, timeout, or some other manner, implementations |
| 57 | // must close the in use connections. If this method returns with no errors, |
| 58 | // all connections managed by this pool must be closed. Calling Disconnect |
| 59 | // multiple times after a single Connect call must result in an error. |
| 60 | Disconnect(context.Context) error |
| 61 | Drain() error |
| 62 | } |
| 63 | |
| 64 | type pool struct { |
| 65 | address address.Address |
| 66 | opts []Option |
| 67 | conns chan *pooledConnection |
| 68 | generation uint64 |
| 69 | sem *semaphore.Weighted |
| 70 | connected int32 |
| 71 | nextid uint64 |
| 72 | capacity uint64 |
| 73 | inflight map[uint64]*pooledConnection |
| 74 | |
| 75 | sync.Mutex |
| 76 | } |
| 77 | |
| 78 | // NewPool creates a new pool that will hold size number of idle connections |
| 79 | // and will create a max of capacity connections. It will use the provided |
| 80 | // options. |
| 81 | func NewPool(addr address.Address, size, capacity uint64, opts ...Option) (Pool, error) { |
| 82 | if size > capacity { |
| 83 | return nil, ErrSizeLargerThanCapacity |
| 84 | } |
| 85 | p := &pool{ |
| 86 | address: addr, |
| 87 | conns: make(chan *pooledConnection, size), |
| 88 | generation: 0, |
| 89 | sem: semaphore.NewWeighted(int64(capacity)), |
| 90 | connected: disconnected, |
| 91 | capacity: capacity, |
| 92 | inflight: make(map[uint64]*pooledConnection), |
| 93 | opts: opts, |
| 94 | } |
| 95 | return p, nil |
| 96 | } |
| 97 | |
| 98 | func (p *pool) Drain() error { |
| 99 | atomic.AddUint64(&p.generation, 1) |
| 100 | return nil |
| 101 | } |
| 102 | |
| 103 | func (p *pool) Connect(ctx context.Context) error { |
| 104 | if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) { |
| 105 | return ErrPoolConnected |
| 106 | } |
| 107 | atomic.AddUint64(&p.generation, 1) |
| 108 | return nil |
| 109 | } |
| 110 | |
| 111 | func (p *pool) Disconnect(ctx context.Context) error { |
| 112 | if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) { |
| 113 | return ErrPoolDisconnected |
| 114 | } |
| 115 | |
| 116 | // We first clear out the idle connections, then we attempt to acquire the entire capacity |
| 117 | // semaphore. If the context is either cancelled, the deadline expires, or there is a timeout |
| 118 | // the semaphore acquire method will return an error. If that happens, we will aggressively |
| 119 | // close the remaining open connections. If we were able to successfully acquire the semaphore, |
| 120 | // then all of the in flight connections have been closed and we release the semaphore. |
| 121 | loop: |
| 122 | for { |
| 123 | select { |
| 124 | case pc := <-p.conns: |
| 125 | // This error would be overwritten by the semaphore |
| 126 | _ = p.closeConnection(pc) |
| 127 | default: |
| 128 | break loop |
| 129 | } |
| 130 | } |
| 131 | err := p.sem.Acquire(ctx, int64(p.capacity)) |
| 132 | if err != nil { |
| 133 | p.Lock() |
| 134 | // We copy the remaining connections to close into a slice, then |
| 135 | // iterate the slice to do the closing. This allows us to use a single |
| 136 | // function to actually clean up and close connections at the expense of |
| 137 | // a double iteration in the worst case. |
| 138 | toClose := make([]*pooledConnection, 0, len(p.inflight)) |
| 139 | for _, pc := range p.inflight { |
| 140 | toClose = append(toClose, pc) |
| 141 | } |
| 142 | p.Unlock() |
| 143 | for _, pc := range toClose { |
| 144 | _ = pc.Close() |
| 145 | } |
| 146 | } else { |
| 147 | p.sem.Release(int64(p.capacity)) |
| 148 | } |
| 149 | atomic.StoreInt32(&p.connected, disconnected) |
| 150 | return nil |
| 151 | } |
| 152 | |
| 153 | func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error) { |
| 154 | if atomic.LoadInt32(&p.connected) != connected { |
| 155 | return nil, nil, ErrPoolClosed |
| 156 | } |
| 157 | |
| 158 | err := p.sem.Acquire(ctx, 1) |
| 159 | if err != nil { |
| 160 | return nil, nil, err |
| 161 | } |
| 162 | |
| 163 | return p.get(ctx) |
| 164 | } |
| 165 | |
| 166 | func (p *pool) get(ctx context.Context) (Connection, *description.Server, error) { |
| 167 | g := atomic.LoadUint64(&p.generation) |
| 168 | select { |
| 169 | case c := <-p.conns: |
| 170 | if c.Expired() { |
| 171 | go p.closeConnection(c) |
| 172 | return p.get(ctx) |
| 173 | } |
| 174 | |
| 175 | return &acquired{Connection: c, sem: p.sem}, nil, nil |
| 176 | case <-ctx.Done(): |
| 177 | p.sem.Release(1) |
| 178 | return nil, nil, ctx.Err() |
| 179 | default: |
| 180 | c, desc, err := New(ctx, p.address, p.opts...) |
| 181 | if err != nil { |
| 182 | p.sem.Release(1) |
| 183 | return nil, nil, err |
| 184 | } |
| 185 | |
| 186 | pc := &pooledConnection{ |
| 187 | Connection: c, |
| 188 | p: p, |
| 189 | generation: g, |
| 190 | id: atomic.AddUint64(&p.nextid, 1), |
| 191 | } |
| 192 | p.Lock() |
| 193 | if atomic.LoadInt32(&p.connected) != connected { |
| 194 | p.Unlock() |
| 195 | p.sem.Release(1) |
| 196 | p.closeConnection(pc) |
| 197 | return nil, nil, ErrPoolClosed |
| 198 | } |
| 199 | defer p.Unlock() |
| 200 | p.inflight[pc.id] = pc |
| 201 | return &acquired{Connection: pc, sem: p.sem}, desc, nil |
| 202 | } |
| 203 | } |
| 204 | |
| 205 | func (p *pool) closeConnection(pc *pooledConnection) error { |
| 206 | if !atomic.CompareAndSwapInt32(&pc.closed, 0, 1) { |
| 207 | return nil |
| 208 | } |
| 209 | p.Lock() |
| 210 | delete(p.inflight, pc.id) |
| 211 | p.Unlock() |
| 212 | return pc.Connection.Close() |
| 213 | } |
| 214 | |
| 215 | func (p *pool) returnConnection(pc *pooledConnection) error { |
| 216 | if atomic.LoadInt32(&p.connected) != connected || pc.Expired() { |
| 217 | return p.closeConnection(pc) |
| 218 | } |
| 219 | |
| 220 | select { |
| 221 | case p.conns <- pc: |
| 222 | return nil |
| 223 | default: |
| 224 | return p.closeConnection(pc) |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | func (p *pool) isExpired(generation uint64) bool { |
| 229 | return generation < atomic.LoadUint64(&p.generation) |
| 230 | } |
| 231 | |
| 232 | type pooledConnection struct { |
| 233 | Connection |
| 234 | p *pool |
| 235 | generation uint64 |
| 236 | id uint64 |
| 237 | closed int32 |
| 238 | } |
| 239 | |
| 240 | func (pc *pooledConnection) Close() error { |
| 241 | return pc.p.returnConnection(pc) |
| 242 | } |
| 243 | |
| 244 | func (pc *pooledConnection) Expired() bool { |
| 245 | return pc.Connection.Expired() || pc.p.isExpired(pc.generation) |
| 246 | } |
| 247 | |
| 248 | type acquired struct { |
| 249 | Connection |
| 250 | |
| 251 | sem *semaphore.Weighted |
| 252 | sync.Mutex |
| 253 | } |
| 254 | |
| 255 | func (a *acquired) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error { |
| 256 | a.Lock() |
| 257 | defer a.Unlock() |
| 258 | if a.Connection == nil { |
| 259 | return ErrConnectionClosed |
| 260 | } |
| 261 | return a.Connection.WriteWireMessage(ctx, wm) |
| 262 | } |
| 263 | |
| 264 | func (a *acquired) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) { |
| 265 | a.Lock() |
| 266 | defer a.Unlock() |
| 267 | if a.Connection == nil { |
| 268 | return nil, ErrConnectionClosed |
| 269 | } |
| 270 | return a.Connection.ReadWireMessage(ctx) |
| 271 | } |
| 272 | |
| 273 | func (a *acquired) Close() error { |
| 274 | a.Lock() |
| 275 | defer a.Unlock() |
| 276 | if a.Connection == nil { |
| 277 | return nil |
| 278 | } |
| 279 | err := a.Connection.Close() |
| 280 | a.sem.Release(1) |
| 281 | a.Connection = nil |
| 282 | return err |
| 283 | } |
| 284 | |
| 285 | func (a *acquired) Expired() bool { |
| 286 | a.Lock() |
| 287 | defer a.Unlock() |
| 288 | if a.Connection == nil { |
| 289 | return true |
| 290 | } |
| 291 | return a.Connection.Expired() |
| 292 | } |
| 293 | |
| 294 | func (a *acquired) Alive() bool { |
| 295 | a.Lock() |
| 296 | defer a.Unlock() |
| 297 | if a.Connection == nil { |
| 298 | return false |
| 299 | } |
| 300 | return a.Connection.Alive() |
| 301 | } |
| 302 | |
| 303 | func (a *acquired) ID() string { |
| 304 | a.Lock() |
| 305 | defer a.Unlock() |
| 306 | if a.Connection == nil { |
| 307 | return "<closed>" |
| 308 | } |
| 309 | return a.Connection.ID() |
| 310 | } |