| package pool |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "sync/atomic" |
| ) |
| |
// Lifecycle states of a StickyConnPool. They are kept in the pool's state
// field and accessed with sync/atomic.
const (
	stateDefault = iota // no connection pinned yet; the next Get acquires one
	stateInited         // a connection has been acquired and pinned
	stateClosed         // Close has run; Get returns ErrClosed
)
| |
// BadConnError reports that the connection pinned by a StickyConnPool was
// marked bad through Remove. The wrapped error is the reason given to Remove;
// a BadConnError with no wrapped cause is treated by the pool as "not bad".
type BadConnError struct {
	wrapped error
}

// The pool stores and type-asserts BadConnError by value, so assert the
// value type satisfies error.
var _ error = BadConnError{}

// Error implements the error interface, appending the wrapped cause (if any)
// to a fixed prefix.
func (e BadConnError) Error() string {
	const prefix = "redis: Conn is in a bad state"
	if e.wrapped == nil {
		return prefix
	}
	return prefix + ": " + e.wrapped.Error()
}

// Unwrap exposes the wrapped cause to errors.Is and errors.As.
func (e BadConnError) Unwrap() error { return e.wrapped }
| |
| //------------------------------------------------------------------------------ |
| |
| type StickyConnPool struct { |
| pool Pooler |
| shared int32 // atomic |
| |
| state uint32 // atomic |
| ch chan *Conn |
| |
| _badConnError atomic.Value |
| } |
| |
| var _ Pooler = (*StickyConnPool)(nil) |
| |
| func NewStickyConnPool(pool Pooler) *StickyConnPool { |
| p, ok := pool.(*StickyConnPool) |
| if !ok { |
| p = &StickyConnPool{ |
| pool: pool, |
| ch: make(chan *Conn, 1), |
| } |
| } |
| atomic.AddInt32(&p.shared, 1) |
| return p |
| } |
| |
| func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { |
| return p.pool.NewConn(ctx) |
| } |
| |
| func (p *StickyConnPool) CloseConn(cn *Conn) error { |
| return p.pool.CloseConn(cn) |
| } |
| |
| func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { |
| // In worst case this races with Close which is not a very common operation. |
| for i := 0; i < 1000; i++ { |
| switch atomic.LoadUint32(&p.state) { |
| case stateDefault: |
| cn, err := p.pool.Get(ctx) |
| if err != nil { |
| return nil, err |
| } |
| if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { |
| return cn, nil |
| } |
| p.pool.Remove(ctx, cn, ErrClosed) |
| case stateInited: |
| if err := p.badConnError(); err != nil { |
| return nil, err |
| } |
| cn, ok := <-p.ch |
| if !ok { |
| return nil, ErrClosed |
| } |
| return cn, nil |
| case stateClosed: |
| return nil, ErrClosed |
| default: |
| panic("not reached") |
| } |
| } |
| return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop") |
| } |
| |
| func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { |
| defer func() { |
| if recover() != nil { |
| p.freeConn(ctx, cn) |
| } |
| }() |
| p.ch <- cn |
| } |
| |
| func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { |
| if err := p.badConnError(); err != nil { |
| p.pool.Remove(ctx, cn, err) |
| } else { |
| p.pool.Put(ctx, cn) |
| } |
| } |
| |
| func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { |
| defer func() { |
| if recover() != nil { |
| p.pool.Remove(ctx, cn, ErrClosed) |
| } |
| }() |
| p._badConnError.Store(BadConnError{wrapped: reason}) |
| p.ch <- cn |
| } |
| |
| func (p *StickyConnPool) Close() error { |
| if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { |
| return nil |
| } |
| |
| for i := 0; i < 1000; i++ { |
| state := atomic.LoadUint32(&p.state) |
| if state == stateClosed { |
| return ErrClosed |
| } |
| if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { |
| close(p.ch) |
| cn, ok := <-p.ch |
| if ok { |
| p.freeConn(context.TODO(), cn) |
| } |
| return nil |
| } |
| } |
| |
| return errors.New("redis: StickyConnPool.Close: infinite loop") |
| } |
| |
| func (p *StickyConnPool) Reset(ctx context.Context) error { |
| if p.badConnError() == nil { |
| return nil |
| } |
| |
| select { |
| case cn, ok := <-p.ch: |
| if !ok { |
| return ErrClosed |
| } |
| p.pool.Remove(ctx, cn, ErrClosed) |
| p._badConnError.Store(BadConnError{wrapped: nil}) |
| default: |
| return errors.New("redis: StickyConnPool does not have a Conn") |
| } |
| |
| if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { |
| state := atomic.LoadUint32(&p.state) |
| return fmt.Errorf("redis: invalid StickyConnPool state: %d", state) |
| } |
| |
| return nil |
| } |
| |
| func (p *StickyConnPool) badConnError() error { |
| if v := p._badConnError.Load(); v != nil { |
| err := v.(BadConnError) |
| if err.wrapped != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| func (p *StickyConnPool) Len() int { |
| switch atomic.LoadUint32(&p.state) { |
| case stateDefault: |
| return 0 |
| case stateInited: |
| return 1 |
| case stateClosed: |
| return 0 |
| default: |
| panic("not reached") |
| } |
| } |
| |
| func (p *StickyConnPool) IdleLen() int { |
| return len(p.ch) |
| } |
| |
| func (p *StickyConnPool) Stats() *Stats { |
| return &Stats{} |
| } |