| package pool |
| |
| import ( |
| "context" |
| "errors" |
| "net" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "github.com/go-redis/redis/v8/internal" |
| ) |
| |
var (
	// ErrClosed is returned when any operation is performed on a closed client.
	ErrClosed = errors.New("redis: client is closed")

	// ErrPoolTimeout is returned when waiting for a free connection from the
	// connection pool takes longer than Options.PoolTimeout.
	ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
| |
// timers is a pool of stopped *time.Timer values recycled by waitTurn so
// that waiting for a pool slot does not allocate a new timer each time.
var timers = sync.Pool{
	New: func() interface{} {
		// The timer is created already stopped; users Reset it to the
		// desired timeout before selecting on timer.C.
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}
| |
// Stats contains pool state information and accumulated stats.
// The first three counters are cumulative; the remaining fields are a
// point-in-time snapshot taken by ConnPool.Stats.
type Stats struct {
	Hits     uint32 // number of times free connection was found in the pool
	Misses   uint32 // number of times free connection was NOT found in the pool
	Timeouts uint32 // number of times a wait timeout occurred

	TotalConns uint32 // number of total connections in the pool
	IdleConns  uint32 // number of idle connections in the pool
	StaleConns uint32 // number of stale connections removed from the pool
}
| |
// Pooler is the interface implemented by connection pools.
type Pooler interface {
	// NewConn creates a connection that is tracked by the pool but does not
	// count against the pool size.
	NewConn(context.Context) (*Conn, error)
	// CloseConn removes the connection from the pool and closes it.
	CloseConn(*Conn) error

	// Get returns an existing connection or creates a new one.
	Get(context.Context) (*Conn, error)
	// Put returns a connection to the pool after use.
	Put(context.Context, *Conn)
	// Remove discards a connection for the given reason.
	Remove(context.Context, *Conn, error)

	// Len returns the total number of connections.
	Len() int
	// IdleLen returns the number of idle connections.
	IdleLen() int
	// Stats returns a snapshot of pool statistics.
	Stats() *Stats

	Close() error
}
| |
// Options configures a ConnPool.
type Options struct {
	// Dialer creates a new network connection.
	Dialer func(context.Context) (net.Conn, error)
	// OnClose, if non-nil, is invoked before a connection is closed.
	OnClose func(*Conn) error

	// PoolFIFO makes the pool reuse the oldest idle connection first
	// instead of the most recently returned one.
	PoolFIFO bool
	// PoolSize is the maximum number of pooled connections and the
	// capacity of the wait queue.
	PoolSize int
	// MinIdleConns is the number of idle connections the pool tries to
	// keep ready; 0 disables pre-dialing.
	MinIdleConns int
	// MaxConnAge is the maximum lifetime of a connection before it is
	// considered stale; 0 disables the age check.
	MaxConnAge time.Duration
	// PoolTimeout is how long Get waits for a free slot before returning
	// ErrPoolTimeout.
	PoolTimeout time.Duration
	// IdleTimeout is how long a connection may stay unused before it is
	// considered stale; 0 disables the idle check.
	IdleTimeout time.Duration
	// IdleCheckFrequency is the reaper interval; the reaper runs only when
	// both this and IdleTimeout are positive.
	IdleCheckFrequency time.Duration
}
| |
// lastDialErrorWrap wraps the last dial error so that atomic.Value always
// stores one consistent concrete type, regardless of the underlying error's
// dynamic type.
type lastDialErrorWrap struct {
	err error
}
| |
// ConnPool is the default Pooler implementation: a fixed-capacity pool of
// connections guarded by a semaphore channel (queue) plus a mutex-protected
// connection list.
type ConnPool struct {
	opt *Options

	dialErrorsNum uint32 // atomic; consecutive dial failures

	// lastDialError holds a *lastDialErrorWrap with the most recent dial
	// failure.
	lastDialError atomic.Value

	// queue is a counting semaphore limiting concurrent connection use to
	// opt.PoolSize.
	queue chan struct{}

	connsMu      sync.Mutex
	conns        []*Conn // all connections tracked by the pool
	idleConns    []*Conn // subset of conns currently idle
	poolSize     int     // number of pooled connections (bounded by opt.PoolSize)
	idleConnsLen int     // idle count, including min-idle dials still in flight

	stats Stats

	_closed  uint32        // atomic; set to 1 once Close has been called
	closedCh chan struct{} // closed by Close to stop the reaper goroutine
}
| |
// Compile-time assertion that *ConnPool implements the Pooler interface.
var _ Pooler = (*ConnPool)(nil)
| |
| func NewConnPool(opt *Options) *ConnPool { |
| p := &ConnPool{ |
| opt: opt, |
| |
| queue: make(chan struct{}, opt.PoolSize), |
| conns: make([]*Conn, 0, opt.PoolSize), |
| idleConns: make([]*Conn, 0, opt.PoolSize), |
| closedCh: make(chan struct{}), |
| } |
| |
| p.connsMu.Lock() |
| p.checkMinIdleConns() |
| p.connsMu.Unlock() |
| |
| if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { |
| go p.reaper(opt.IdleCheckFrequency) |
| } |
| |
| return p |
| } |
| |
// checkMinIdleConns tops up the idle connection list until either
// MinIdleConns idle connections exist or the pool is full.
//
// The caller must hold p.connsMu. The counters are incremented
// optimistically before each asynchronous dial and rolled back by the
// goroutine when the dial fails (except on ErrClosed, where the pool is
// being torn down anyway).
func (p *ConnPool) checkMinIdleConns() {
	if p.opt.MinIdleConns == 0 {
		return
	}
	for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
		p.poolSize++
		p.idleConnsLen++

		go func() {
			err := p.addIdleConn()
			if err != nil && err != ErrClosed {
				p.connsMu.Lock()
				p.poolSize--
				p.idleConnsLen--
				p.connsMu.Unlock()
			}
		}()
	}
}
| |
| func (p *ConnPool) addIdleConn() error { |
| cn, err := p.dialConn(context.TODO(), true) |
| if err != nil { |
| return err |
| } |
| |
| p.connsMu.Lock() |
| defer p.connsMu.Unlock() |
| |
| // It is not allowed to add new connections to the closed connection pool. |
| if p.closed() { |
| _ = cn.Close() |
| return ErrClosed |
| } |
| |
| p.conns = append(p.conns, cn) |
| p.idleConns = append(p.idleConns, cn) |
| return nil |
| } |
| |
// NewConn dials and returns a standalone connection that is tracked by the
// pool but does not count against PoolSize (pooled=false).
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
	return p.newConn(ctx, false)
}
| |
// newConn dials a connection and registers it in the pool.
//
// When pooled is true the connection counts against PoolSize; if the pool
// is already full the connection is instead marked pooled=false so that the
// next Put removes it rather than returning it to the idle list.
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
	cn, err := p.dialConn(ctx, pooled)
	if err != nil {
		return nil, err
	}

	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	// It is not allowed to add new connections to the closed connection pool.
	if p.closed() {
		_ = cn.Close()
		return nil, ErrClosed
	}

	p.conns = append(p.conns, cn)
	if pooled {
		// If pool is full remove the cn on next Put.
		if p.poolSize >= p.opt.PoolSize {
			cn.pooled = false
		} else {
			p.poolSize++
		}
	}

	return cn, nil
}
| |
| func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { |
| if p.closed() { |
| return nil, ErrClosed |
| } |
| |
| if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { |
| return nil, p.getLastDialError() |
| } |
| |
| netConn, err := p.opt.Dialer(ctx) |
| if err != nil { |
| p.setLastDialError(err) |
| if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { |
| go p.tryDial() |
| } |
| return nil, err |
| } |
| |
| cn := NewConn(netConn) |
| cn.pooled = pooled |
| return cn, nil |
| } |
| |
| func (p *ConnPool) tryDial() { |
| for { |
| if p.closed() { |
| return |
| } |
| |
| conn, err := p.opt.Dialer(context.Background()) |
| if err != nil { |
| p.setLastDialError(err) |
| time.Sleep(time.Second) |
| continue |
| } |
| |
| atomic.StoreUint32(&p.dialErrorsNum, 0) |
| _ = conn.Close() |
| return |
| } |
| } |
| |
// setLastDialError records the most recent dial failure. The error is
// wrapped so that atomic.Value always stores the same concrete type.
func (p *ConnPool) setLastDialError(err error) {
	p.lastDialError.Store(&lastDialErrorWrap{err: err})
}
| |
| func (p *ConnPool) getLastDialError() error { |
| err, _ := p.lastDialError.Load().(*lastDialErrorWrap) |
| if err != nil { |
| return err.err |
| } |
| return nil |
| } |
| |
| // Get returns existed connection from the pool or creates a new one. |
| func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { |
| if p.closed() { |
| return nil, ErrClosed |
| } |
| |
| if err := p.waitTurn(ctx); err != nil { |
| return nil, err |
| } |
| |
| for { |
| p.connsMu.Lock() |
| cn, err := p.popIdle() |
| p.connsMu.Unlock() |
| |
| if err != nil { |
| return nil, err |
| } |
| |
| if cn == nil { |
| break |
| } |
| |
| if p.isStaleConn(cn) { |
| _ = p.CloseConn(cn) |
| continue |
| } |
| |
| atomic.AddUint32(&p.stats.Hits, 1) |
| return cn, nil |
| } |
| |
| atomic.AddUint32(&p.stats.Misses, 1) |
| |
| newcn, err := p.newConn(ctx, true) |
| if err != nil { |
| p.freeTurn() |
| return nil, err |
| } |
| |
| return newcn, nil |
| } |
| |
// getTurn blocks until a slot on the pool queue is available.
func (p *ConnPool) getTurn() {
	p.queue <- struct{}{}
}
| |
// waitTurn reserves a slot on the pool queue, blocking for at most
// Options.PoolTimeout. It returns ctx.Err() when the context is done and
// ErrPoolTimeout when the wait times out. Timers are recycled through the
// package-level timers pool to avoid an allocation per wait.
func (p *ConnPool) waitTurn(ctx context.Context) error {
	// Fail fast if the context is already done.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	// Fast path: a slot is immediately available, no timer needed.
	select {
	case p.queue <- struct{}{}:
		return nil
	default:
	}

	timer := timers.Get().(*time.Timer)
	timer.Reset(p.opt.PoolTimeout)

	select {
	case <-ctx.Done():
		// Drain timer.C if the timer already fired so the recycled timer
		// is clean for the next Reset.
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return ctx.Err()
	case p.queue <- struct{}{}:
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return nil
	case <-timer.C:
		// The timer has fired, so its channel is already drained.
		timers.Put(timer)
		atomic.AddUint32(&p.stats.Timeouts, 1)
		return ErrPoolTimeout
	}
}
| |
// freeTurn releases a slot previously reserved via getTurn or waitTurn.
func (p *ConnPool) freeTurn() {
	<-p.queue
}
| |
// popIdle removes and returns one idle connection, or (nil, nil) when the
// idle list is empty. With PoolFIFO the oldest (front) connection is taken,
// otherwise the most recently returned (back) one.
//
// The caller must hold p.connsMu.
func (p *ConnPool) popIdle() (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}
	n := len(p.idleConns)
	if n == 0 {
		return nil, nil
	}

	var cn *Conn
	if p.opt.PoolFIFO {
		// Shift the remaining connections left to drop the front element.
		cn = p.idleConns[0]
		copy(p.idleConns, p.idleConns[1:])
		p.idleConns = p.idleConns[:n-1]
	} else {
		idx := n - 1
		cn = p.idleConns[idx]
		p.idleConns = p.idleConns[:idx]
	}
	p.idleConnsLen--
	// Replenish idle connections up to MinIdleConns.
	p.checkMinIdleConns()
	return cn, nil
}
| |
// Put returns cn to the pool after use.
//
// A connection with unread buffered data indicates a reply was not fully
// consumed, so it is removed as broken rather than reused; connections
// marked pooled=false are removed as well. Healthy pooled connections go
// back on the idle list and the caller's queue turn is released (Remove
// releases the turn on the other paths).
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
	if cn.rd.Buffered() > 0 {
		internal.Logger.Printf(ctx, "Conn has unread data")
		p.Remove(ctx, cn, BadConnError{})
		return
	}

	if !cn.pooled {
		p.Remove(ctx, cn, nil)
		return
	}

	p.connsMu.Lock()
	p.idleConns = append(p.idleConns, cn)
	p.idleConnsLen++
	p.connsMu.Unlock()
	p.freeTurn()
}
| |
// Remove detaches cn from the pool, releases the caller's queue turn, and
// closes the connection. The reason argument is required by the Pooler
// interface but is not used by this implementation.
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	p.removeConnWithLock(cn)
	p.freeTurn()
	_ = p.closeConn(cn)
}
| |
// CloseConn detaches cn from the pool and closes it. Unlike Remove it does
// not release a queue turn; callers holding one keep it.
func (p *ConnPool) CloseConn(cn *Conn) error {
	p.removeConnWithLock(cn)
	return p.closeConn(cn)
}
| |
| func (p *ConnPool) removeConnWithLock(cn *Conn) { |
| p.connsMu.Lock() |
| p.removeConn(cn) |
| p.connsMu.Unlock() |
| } |
| |
| func (p *ConnPool) removeConn(cn *Conn) { |
| for i, c := range p.conns { |
| if c == cn { |
| p.conns = append(p.conns[:i], p.conns[i+1:]...) |
| if cn.pooled { |
| p.poolSize-- |
| p.checkMinIdleConns() |
| } |
| return |
| } |
| } |
| } |
| |
// closeConn runs the optional OnClose hook (its error is deliberately
// ignored) and then closes the underlying connection.
func (p *ConnPool) closeConn(cn *Conn) error {
	if p.opt.OnClose != nil {
		_ = p.opt.OnClose(cn)
	}
	return cn.Close()
}
| |
| // Len returns total number of connections. |
| func (p *ConnPool) Len() int { |
| p.connsMu.Lock() |
| n := len(p.conns) |
| p.connsMu.Unlock() |
| return n |
| } |
| |
| // IdleLen returns number of idle connections. |
| func (p *ConnPool) IdleLen() int { |
| p.connsMu.Lock() |
| n := p.idleConnsLen |
| p.connsMu.Unlock() |
| return n |
| } |
| |
| func (p *ConnPool) Stats() *Stats { |
| idleLen := p.IdleLen() |
| return &Stats{ |
| Hits: atomic.LoadUint32(&p.stats.Hits), |
| Misses: atomic.LoadUint32(&p.stats.Misses), |
| Timeouts: atomic.LoadUint32(&p.stats.Timeouts), |
| |
| TotalConns: uint32(p.Len()), |
| IdleConns: uint32(idleLen), |
| StaleConns: atomic.LoadUint32(&p.stats.StaleConns), |
| } |
| } |
| |
// closed reports whether Close has been called on the pool.
func (p *ConnPool) closed() bool {
	return atomic.LoadUint32(&p._closed) == 1
}
| |
// Filter closes every connection for which fn returns true and reports the
// first close error encountered.
//
// NOTE(review): closed connections are not removed from p.conns or the idle
// list here — presumably callers are responsible for not reusing them;
// confirm against call sites.
func (p *ConnPool) Filter(fn func(*Conn) bool) error {
	p.connsMu.Lock()
	defer p.connsMu.Unlock()

	var firstErr error
	for _, cn := range p.conns {
		if fn(cn) {
			if err := p.closeConn(cn); err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
| |
// Close marks the pool closed, signals background goroutines via closedCh,
// and closes every tracked connection. It returns ErrClosed when called
// more than once, otherwise the first connection-close error (if any).
func (p *ConnPool) Close() error {
	// CAS guarantees the teardown below runs exactly once.
	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
		return ErrClosed
	}
	close(p.closedCh)

	var firstErr error
	p.connsMu.Lock()
	for _, cn := range p.conns {
		if err := p.closeConn(cn); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	p.conns = nil
	p.poolSize = 0
	p.idleConns = nil
	p.idleConnsLen = 0
	p.connsMu.Unlock()

	return firstErr
}
| |
// reaper periodically removes stale idle connections. It runs in its own
// goroutine (started by NewConnPool) and exits when the pool is closed.
func (p *ConnPool) reaper(frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// It is possible that ticker and closedCh arrive together,
			// and select pseudo-randomly pick ticker case, we double
			// check here to prevent being executed after closed.
			if p.closed() {
				return
			}
			_, err := p.ReapStaleConns()
			if err != nil {
				internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
				continue
			}
		case <-p.closedCh:
			return
		}
	}
}
| |
// ReapStaleConns removes and closes stale idle connections, taking one queue
// turn per connection so that reaping never starves Get. It returns the
// number of connections reaped; the error is always nil in the current
// implementation.
func (p *ConnPool) ReapStaleConns() (int, error) {
	var n int
	for {
		p.getTurn()

		p.connsMu.Lock()
		cn := p.reapStaleConn()
		p.connsMu.Unlock()

		p.freeTurn()

		// Close outside the lock; stop at the first non-stale connection.
		if cn != nil {
			_ = p.closeConn(cn)
			n++
		} else {
			break
		}
	}
	atomic.AddUint32(&p.stats.StaleConns, uint32(n))
	return n, nil
}
| |
// reapStaleConn checks only the front of the idle list (the oldest idle
// connection); if it is stale, the connection is detached from the pool and
// returned for the caller to close. It returns nil when the idle list is
// empty or the oldest connection is still fresh.
//
// The caller must hold p.connsMu.
func (p *ConnPool) reapStaleConn() *Conn {
	if len(p.idleConns) == 0 {
		return nil
	}

	cn := p.idleConns[0]
	if !p.isStaleConn(cn) {
		return nil
	}

	// Drop the front element in place, reusing the backing array.
	p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
	p.idleConnsLen--
	p.removeConn(cn)

	return cn
}
| |
| func (p *ConnPool) isStaleConn(cn *Conn) bool { |
| if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { |
| return false |
| } |
| |
| now := time.Now() |
| if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { |
| return true |
| } |
| if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { |
| return true |
| } |
| |
| return false |
| } |