blob: 3a25ad379c295a70f3ff53c81bb5cba95c3e45e4 [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package connection
import (
"context"
"sync"
"sync/atomic"
"github.com/mongodb/mongo-go-driver/x/network/address"
"github.com/mongodb/mongo-go-driver/x/network/description"
"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
"golang.org/x/sync/semaphore"
)
// ErrPoolClosed is returned from an attempt to use a closed pool.
var ErrPoolClosed = PoolError("pool is closed")
// ErrSizeLargerThanCapacity is returned from an attempt to create a pool with a size
// larger than the capacity.
var ErrSizeLargerThanCapacity = PoolError("size is larger than capacity")
// ErrPoolConnected is returned from an attempt to connect an already connected pool
var ErrPoolConnected = PoolError("pool is connected")
// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected
// or disconnecting pool.
var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting")
// ErrConnectionClosed is returned from an attempt to use an already closed connection.
var ErrConnectionClosed = Error{ConnectionID: "<closed>", message: "connection is closed"}
// These constants represent the connection states of a pool.
const (
disconnected int32 = iota
disconnecting
connected
)
// Pool is used to pool Connections to a server.
type Pool interface {
// Get must return a nil *description.Server if the returned connection is
// not a newly dialed connection.
Get(context.Context) (Connection, *description.Server, error)
// Connect handles the initialization of a Pool and allow Connections to be
// retrieved and pooled. Implementations must return an error if Connect is
// called more than once before calling Disconnect.
Connect(context.Context) error
// Disconnect closest connections managed by this Pool. Implementations must
// either wait until all of the connections in use have been returned and
// closed or the context expires before returning. If the context expires
// via cancellation, deadline, timeout, or some other manner, implementations
// must close the in use connections. If this method returns with no errors,
// all connections managed by this pool must be closed. Calling Disconnect
// multiple times after a single Connect call must result in an error.
Disconnect(context.Context) error
Drain() error
}
type pool struct {
address address.Address
opts []Option
conns chan *pooledConnection
generation uint64
sem *semaphore.Weighted
connected int32
nextid uint64
capacity uint64
inflight map[uint64]*pooledConnection
sync.Mutex
}
// NewPool creates a new pool that will hold size number of idle connections
// and will create a max of capacity connections. It will use the provided
// options.
func NewPool(addr address.Address, size, capacity uint64, opts ...Option) (Pool, error) {
if size > capacity {
return nil, ErrSizeLargerThanCapacity
}
p := &pool{
address: addr,
conns: make(chan *pooledConnection, size),
generation: 0,
sem: semaphore.NewWeighted(int64(capacity)),
connected: disconnected,
capacity: capacity,
inflight: make(map[uint64]*pooledConnection),
opts: opts,
}
return p, nil
}
func (p *pool) Drain() error {
atomic.AddUint64(&p.generation, 1)
return nil
}
func (p *pool) Connect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
return ErrPoolConnected
}
atomic.AddUint64(&p.generation, 1)
return nil
}
func (p *pool) Disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
return ErrPoolDisconnected
}
// We first clear out the idle connections, then we attempt to acquire the entire capacity
// semaphore. If the context is either cancelled, the deadline expires, or there is a timeout
// the semaphore acquire method will return an error. If that happens, we will aggressively
// close the remaining open connections. If we were able to successfully acquire the semaphore,
// then all of the in flight connections have been closed and we release the semaphore.
loop:
for {
select {
case pc := <-p.conns:
// This error would be overwritten by the semaphore
_ = p.closeConnection(pc)
default:
break loop
}
}
err := p.sem.Acquire(ctx, int64(p.capacity))
if err != nil {
p.Lock()
// We copy the remaining connections to close into a slice, then
// iterate the slice to do the closing. This allows us to use a single
// function to actually clean up and close connections at the expense of
// a double iteration in the worst case.
toClose := make([]*pooledConnection, 0, len(p.inflight))
for _, pc := range p.inflight {
toClose = append(toClose, pc)
}
p.Unlock()
for _, pc := range toClose {
_ = pc.Close()
}
} else {
p.sem.Release(int64(p.capacity))
}
atomic.StoreInt32(&p.connected, disconnected)
return nil
}
func (p *pool) Get(ctx context.Context) (Connection, *description.Server, error) {
if atomic.LoadInt32(&p.connected) != connected {
return nil, nil, ErrPoolClosed
}
err := p.sem.Acquire(ctx, 1)
if err != nil {
return nil, nil, err
}
return p.get(ctx)
}
func (p *pool) get(ctx context.Context) (Connection, *description.Server, error) {
g := atomic.LoadUint64(&p.generation)
select {
case c := <-p.conns:
if c.Expired() {
go p.closeConnection(c)
return p.get(ctx)
}
return &acquired{Connection: c, sem: p.sem}, nil, nil
case <-ctx.Done():
p.sem.Release(1)
return nil, nil, ctx.Err()
default:
c, desc, err := New(ctx, p.address, p.opts...)
if err != nil {
p.sem.Release(1)
return nil, nil, err
}
pc := &pooledConnection{
Connection: c,
p: p,
generation: g,
id: atomic.AddUint64(&p.nextid, 1),
}
p.Lock()
if atomic.LoadInt32(&p.connected) != connected {
p.Unlock()
p.sem.Release(1)
p.closeConnection(pc)
return nil, nil, ErrPoolClosed
}
defer p.Unlock()
p.inflight[pc.id] = pc
return &acquired{Connection: pc, sem: p.sem}, desc, nil
}
}
func (p *pool) closeConnection(pc *pooledConnection) error {
if !atomic.CompareAndSwapInt32(&pc.closed, 0, 1) {
return nil
}
p.Lock()
delete(p.inflight, pc.id)
p.Unlock()
return pc.Connection.Close()
}
func (p *pool) returnConnection(pc *pooledConnection) error {
if atomic.LoadInt32(&p.connected) != connected || pc.Expired() {
return p.closeConnection(pc)
}
select {
case p.conns <- pc:
return nil
default:
return p.closeConnection(pc)
}
}
func (p *pool) isExpired(generation uint64) bool {
return generation < atomic.LoadUint64(&p.generation)
}
type pooledConnection struct {
Connection
p *pool
generation uint64
id uint64
closed int32
}
func (pc *pooledConnection) Close() error {
return pc.p.returnConnection(pc)
}
func (pc *pooledConnection) Expired() bool {
return pc.Connection.Expired() || pc.p.isExpired(pc.generation)
}
type acquired struct {
Connection
sem *semaphore.Weighted
sync.Mutex
}
func (a *acquired) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return ErrConnectionClosed
}
return a.Connection.WriteWireMessage(ctx, wm)
}
func (a *acquired) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return nil, ErrConnectionClosed
}
return a.Connection.ReadWireMessage(ctx)
}
func (a *acquired) Close() error {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return nil
}
err := a.Connection.Close()
a.sem.Release(1)
a.Connection = nil
return err
}
func (a *acquired) Expired() bool {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return true
}
return a.Connection.Expired()
}
func (a *acquired) Alive() bool {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return false
}
return a.Connection.Alive()
}
func (a *acquired) ID() string {
a.Lock()
defer a.Unlock()
if a.Connection == nil {
return "<closed>"
}
return a.Connection.ID()
}