blob: c3e7e7c0458e5c916021acaeb436e863e7858865 [file] [log] [blame]
Joey Armstrong5f51f2e2023-01-17 17:06:26 -05001package pool
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "sync/atomic"
8)
9
// Lifecycle states of a StickyConnPool, kept in its state field and always
// accessed through sync/atomic. Get moves stateDefault -> stateInited when it
// pins a connection, Reset moves stateInited -> stateDefault, and Close moves
// any state to stateClosed.
const (
	stateDefault = iota // no connection pinned yet
	stateInited         // a connection from the wrapped pool is pinned
	stateClosed         // Close has run; the pool is unusable
)
15
// BadConnError is what StickyConnPool records when its pinned connection is
// handed back through Remove with a reason. It optionally wraps that reason,
// which Unwrap exposes to errors.Is / errors.As.
type BadConnError struct {
	wrapped error
}

// Compile-time check that BadConnError satisfies error.
var _ error = BadConnError{}

// Error returns the fixed "bad state" message, suffixed with the wrapped
// cause when there is one.
func (e BadConnError) Error() string {
	const prefix = "redis: Conn is in a bad state"
	if e.wrapped == nil {
		return prefix
	}
	return prefix + ": " + e.wrapped.Error()
}

// Unwrap returns the wrapped cause, or nil when none was recorded.
func (e BadConnError) Unwrap() error {
	return e.wrapped
}
33
34//------------------------------------------------------------------------------
35
36type StickyConnPool struct {
37 pool Pooler
38 shared int32 // atomic
39
40 state uint32 // atomic
41 ch chan *Conn
42
43 _badConnError atomic.Value
44}
45
46var _ Pooler = (*StickyConnPool)(nil)
47
48func NewStickyConnPool(pool Pooler) *StickyConnPool {
49 p, ok := pool.(*StickyConnPool)
50 if !ok {
51 p = &StickyConnPool{
52 pool: pool,
53 ch: make(chan *Conn, 1),
54 }
55 }
56 atomic.AddInt32(&p.shared, 1)
57 return p
58}
59
60func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
61 return p.pool.NewConn(ctx)
62}
63
64func (p *StickyConnPool) CloseConn(cn *Conn) error {
65 return p.pool.CloseConn(cn)
66}
67
68func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
69 // In worst case this races with Close which is not a very common operation.
70 for i := 0; i < 1000; i++ {
71 switch atomic.LoadUint32(&p.state) {
72 case stateDefault:
73 cn, err := p.pool.Get(ctx)
74 if err != nil {
75 return nil, err
76 }
77 if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
78 return cn, nil
79 }
80 p.pool.Remove(ctx, cn, ErrClosed)
81 case stateInited:
82 if err := p.badConnError(); err != nil {
83 return nil, err
84 }
85 cn, ok := <-p.ch
86 if !ok {
87 return nil, ErrClosed
88 }
89 return cn, nil
90 case stateClosed:
91 return nil, ErrClosed
92 default:
93 panic("not reached")
94 }
95 }
96 return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
97}
98
99func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
100 defer func() {
101 if recover() != nil {
102 p.freeConn(ctx, cn)
103 }
104 }()
105 p.ch <- cn
106}
107
108func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
109 if err := p.badConnError(); err != nil {
110 p.pool.Remove(ctx, cn, err)
111 } else {
112 p.pool.Put(ctx, cn)
113 }
114}
115
116func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
117 defer func() {
118 if recover() != nil {
119 p.pool.Remove(ctx, cn, ErrClosed)
120 }
121 }()
122 p._badConnError.Store(BadConnError{wrapped: reason})
123 p.ch <- cn
124}
125
126func (p *StickyConnPool) Close() error {
127 if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
128 return nil
129 }
130
131 for i := 0; i < 1000; i++ {
132 state := atomic.LoadUint32(&p.state)
133 if state == stateClosed {
134 return ErrClosed
135 }
136 if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
137 close(p.ch)
138 cn, ok := <-p.ch
139 if ok {
140 p.freeConn(context.TODO(), cn)
141 }
142 return nil
143 }
144 }
145
146 return errors.New("redis: StickyConnPool.Close: infinite loop")
147}
148
149func (p *StickyConnPool) Reset(ctx context.Context) error {
150 if p.badConnError() == nil {
151 return nil
152 }
153
154 select {
155 case cn, ok := <-p.ch:
156 if !ok {
157 return ErrClosed
158 }
159 p.pool.Remove(ctx, cn, ErrClosed)
160 p._badConnError.Store(BadConnError{wrapped: nil})
161 default:
162 return errors.New("redis: StickyConnPool does not have a Conn")
163 }
164
165 if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
166 state := atomic.LoadUint32(&p.state)
167 return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
168 }
169
170 return nil
171}
172
173func (p *StickyConnPool) badConnError() error {
174 if v := p._badConnError.Load(); v != nil {
175 err := v.(BadConnError)
176 if err.wrapped != nil {
177 return err
178 }
179 }
180 return nil
181}
182
183func (p *StickyConnPool) Len() int {
184 switch atomic.LoadUint32(&p.state) {
185 case stateDefault:
186 return 0
187 case stateInited:
188 return 1
189 case stateClosed:
190 return 0
191 default:
192 panic("not reached")
193 }
194}
195
196func (p *StickyConnPool) IdleLen() int {
197 return len(p.ch)
198}
199
200func (p *StickyConnPool) Stats() *Stats {
201 return &Stats{}
202}