blob: 355742bf35824b4248d46b94d18131f2d2f076ec [file] [log] [blame]
Joey Armstronga6af1522023-01-17 16:06:16 -05001package pool
2
3import (
4 "context"
5 "errors"
6 "net"
7 "sync"
8 "sync/atomic"
9 "time"
10
11 "github.com/go-redis/redis/v8/internal"
12)
13
// ErrClosed is returned by pool operations attempted after Close has been
// called on the pool.
var ErrClosed = errors.New("redis: client is closed")

// ErrPoolTimeout is returned when no connection slot became available within
// the configured PoolTimeout.
var ErrPoolTimeout = errors.New("redis: connection pool timeout")
18
// timers recycles *time.Timer values used by waitTurn, so a Get that has to
// wait does not allocate a fresh timer every time. Every timer handed out by
// the pool is in the stopped state, and callers must return it stopped.
var timers = sync.Pool{New: newStoppedTimer}

// newStoppedTimer creates a timer and stops it at once, so that its channel
// never receives a value until the caller Resets it.
func newStoppedTimer() interface{} {
	timer := time.NewTimer(time.Hour)
	timer.Stop()
	return timer
}
26
// Stats is a snapshot of pool state and accumulated counters, as returned by
// ConnPool.Stats. Hits, Misses, Timeouts and StaleConns accumulate over the
// pool's lifetime. TotalConns and IdleConns describe the pool at the moment
// the snapshot was taken.
type Stats struct {
	Hits     uint32 // times Get reused a free, non-stale idle connection
	Misses   uint32 // times Get found no idle connection and had to dial one
	Timeouts uint32 // times waiting for a free turn ran out of PoolTimeout

	TotalConns uint32 // connections currently tracked by the pool
	IdleConns  uint32 // idle connections currently in the pool
	StaleConns uint32 // stale connections removed by ReapStaleConns
}
37
38type Pooler interface {
39 NewConn(context.Context) (*Conn, error)
40 CloseConn(*Conn) error
41
42 Get(context.Context) (*Conn, error)
43 Put(context.Context, *Conn)
44 Remove(context.Context, *Conn, error)
45
46 Len() int
47 IdleLen() int
48 Stats() *Stats
49
50 Close() error
51}
52
53type Options struct {
54 Dialer func(context.Context) (net.Conn, error)
55 OnClose func(*Conn) error
56
57 PoolSize int
58 MinIdleConns int
59 MaxConnAge time.Duration
60 PoolTimeout time.Duration
61 IdleTimeout time.Duration
62 IdleCheckFrequency time.Duration
63}
64
// lastDialErrorWrap boxes a dial error before it is stored in an atomic.Value.
// atomic.Value requires every Store to use the same concrete type, and
// different error values generally have different concrete types.
type lastDialErrorWrap struct {
	err error // the most recent dial error
}
68
69type ConnPool struct {
70 opt *Options
71
72 dialErrorsNum uint32 // atomic
73
74 lastDialError atomic.Value
75
76 queue chan struct{}
77
78 connsMu sync.Mutex
79 conns []*Conn
80 idleConns []*Conn
81 poolSize int
82 idleConnsLen int
83
84 stats Stats
85
86 _closed uint32 // atomic
87 closedCh chan struct{}
88}
89
90var _ Pooler = (*ConnPool)(nil)
91
92func NewConnPool(opt *Options) *ConnPool {
93 p := &ConnPool{
94 opt: opt,
95
96 queue: make(chan struct{}, opt.PoolSize),
97 conns: make([]*Conn, 0, opt.PoolSize),
98 idleConns: make([]*Conn, 0, opt.PoolSize),
99 closedCh: make(chan struct{}),
100 }
101
102 p.connsMu.Lock()
103 p.checkMinIdleConns()
104 p.connsMu.Unlock()
105
106 if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
107 go p.reaper(opt.IdleCheckFrequency)
108 }
109
110 return p
111}
112
113func (p *ConnPool) checkMinIdleConns() {
114 if p.opt.MinIdleConns == 0 {
115 return
116 }
117 for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
118 p.poolSize++
119 p.idleConnsLen++
120 go func() {
121 err := p.addIdleConn()
122 if err != nil {
123 p.connsMu.Lock()
124 p.poolSize--
125 p.idleConnsLen--
126 p.connsMu.Unlock()
127 }
128 }()
129 }
130}
131
132func (p *ConnPool) addIdleConn() error {
133 cn, err := p.dialConn(context.TODO(), true)
134 if err != nil {
135 return err
136 }
137
138 p.connsMu.Lock()
139 p.conns = append(p.conns, cn)
140 p.idleConns = append(p.idleConns, cn)
141 p.connsMu.Unlock()
142 return nil
143}
144
145func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
146 return p.newConn(ctx, false)
147}
148
149func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
150 cn, err := p.dialConn(ctx, pooled)
151 if err != nil {
152 return nil, err
153 }
154
155 p.connsMu.Lock()
156 p.conns = append(p.conns, cn)
157 if pooled {
158 // If pool is full remove the cn on next Put.
159 if p.poolSize >= p.opt.PoolSize {
160 cn.pooled = false
161 } else {
162 p.poolSize++
163 }
164 }
165 p.connsMu.Unlock()
166
167 return cn, nil
168}
169
170func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
171 if p.closed() {
172 return nil, ErrClosed
173 }
174
175 if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
176 return nil, p.getLastDialError()
177 }
178
179 netConn, err := p.opt.Dialer(ctx)
180 if err != nil {
181 p.setLastDialError(err)
182 if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
183 go p.tryDial()
184 }
185 return nil, err
186 }
187
188 internal.NewConnectionsCounter.Add(ctx, 1)
189 cn := NewConn(netConn)
190 cn.pooled = pooled
191 return cn, nil
192}
193
194func (p *ConnPool) tryDial() {
195 for {
196 if p.closed() {
197 return
198 }
199
200 conn, err := p.opt.Dialer(context.Background())
201 if err != nil {
202 p.setLastDialError(err)
203 time.Sleep(time.Second)
204 continue
205 }
206
207 atomic.StoreUint32(&p.dialErrorsNum, 0)
208 _ = conn.Close()
209 return
210 }
211}
212
213func (p *ConnPool) setLastDialError(err error) {
214 p.lastDialError.Store(&lastDialErrorWrap{err: err})
215}
216
217func (p *ConnPool) getLastDialError() error {
218 err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
219 if err != nil {
220 return err.err
221 }
222 return nil
223}
224
225// Get returns existed connection from the pool or creates a new one.
226func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
227 if p.closed() {
228 return nil, ErrClosed
229 }
230
231 err := p.waitTurn(ctx)
232 if err != nil {
233 return nil, err
234 }
235
236 for {
237 p.connsMu.Lock()
238 cn := p.popIdle()
239 p.connsMu.Unlock()
240
241 if cn == nil {
242 break
243 }
244
245 if p.isStaleConn(cn) {
246 _ = p.CloseConn(cn)
247 continue
248 }
249
250 atomic.AddUint32(&p.stats.Hits, 1)
251 return cn, nil
252 }
253
254 atomic.AddUint32(&p.stats.Misses, 1)
255
256 newcn, err := p.newConn(ctx, true)
257 if err != nil {
258 p.freeTurn()
259 return nil, err
260 }
261
262 return newcn, nil
263}
264
265func (p *ConnPool) getTurn() {
266 p.queue <- struct{}{}
267}
268
269func (p *ConnPool) waitTurn(ctx context.Context) error {
270 select {
271 case <-ctx.Done():
272 return ctx.Err()
273 default:
274 }
275
276 select {
277 case p.queue <- struct{}{}:
278 return nil
279 default:
280 }
281
282 timer := timers.Get().(*time.Timer)
283 timer.Reset(p.opt.PoolTimeout)
284
285 select {
286 case <-ctx.Done():
287 if !timer.Stop() {
288 <-timer.C
289 }
290 timers.Put(timer)
291 return ctx.Err()
292 case p.queue <- struct{}{}:
293 if !timer.Stop() {
294 <-timer.C
295 }
296 timers.Put(timer)
297 return nil
298 case <-timer.C:
299 timers.Put(timer)
300 atomic.AddUint32(&p.stats.Timeouts, 1)
301 return ErrPoolTimeout
302 }
303}
304
305func (p *ConnPool) freeTurn() {
306 <-p.queue
307}
308
309func (p *ConnPool) popIdle() *Conn {
310 if len(p.idleConns) == 0 {
311 return nil
312 }
313
314 idx := len(p.idleConns) - 1
315 cn := p.idleConns[idx]
316 p.idleConns = p.idleConns[:idx]
317 p.idleConnsLen--
318 p.checkMinIdleConns()
319 return cn
320}
321
322func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
323 if cn.rd.Buffered() > 0 {
324 internal.Logger.Printf(ctx, "Conn has unread data")
325 p.Remove(ctx, cn, BadConnError{})
326 return
327 }
328
329 if !cn.pooled {
330 p.Remove(ctx, cn, nil)
331 return
332 }
333
334 p.connsMu.Lock()
335 p.idleConns = append(p.idleConns, cn)
336 p.idleConnsLen++
337 p.connsMu.Unlock()
338 p.freeTurn()
339}
340
341func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
342 p.removeConnWithLock(cn)
343 p.freeTurn()
344 _ = p.closeConn(cn)
345}
346
347func (p *ConnPool) CloseConn(cn *Conn) error {
348 p.removeConnWithLock(cn)
349 return p.closeConn(cn)
350}
351
352func (p *ConnPool) removeConnWithLock(cn *Conn) {
353 p.connsMu.Lock()
354 p.removeConn(cn)
355 p.connsMu.Unlock()
356}
357
358func (p *ConnPool) removeConn(cn *Conn) {
359 for i, c := range p.conns {
360 if c == cn {
361 p.conns = append(p.conns[:i], p.conns[i+1:]...)
362 if cn.pooled {
363 p.poolSize--
364 p.checkMinIdleConns()
365 }
366 return
367 }
368 }
369}
370
371func (p *ConnPool) closeConn(cn *Conn) error {
372 if p.opt.OnClose != nil {
373 _ = p.opt.OnClose(cn)
374 }
375 return cn.Close()
376}
377
378// Len returns total number of connections.
379func (p *ConnPool) Len() int {
380 p.connsMu.Lock()
381 n := len(p.conns)
382 p.connsMu.Unlock()
383 return n
384}
385
386// IdleLen returns number of idle connections.
387func (p *ConnPool) IdleLen() int {
388 p.connsMu.Lock()
389 n := p.idleConnsLen
390 p.connsMu.Unlock()
391 return n
392}
393
394func (p *ConnPool) Stats() *Stats {
395 idleLen := p.IdleLen()
396 return &Stats{
397 Hits: atomic.LoadUint32(&p.stats.Hits),
398 Misses: atomic.LoadUint32(&p.stats.Misses),
399 Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
400
401 TotalConns: uint32(p.Len()),
402 IdleConns: uint32(idleLen),
403 StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
404 }
405}
406
407func (p *ConnPool) closed() bool {
408 return atomic.LoadUint32(&p._closed) == 1
409}
410
411func (p *ConnPool) Filter(fn func(*Conn) bool) error {
412 p.connsMu.Lock()
413 defer p.connsMu.Unlock()
414
415 var firstErr error
416 for _, cn := range p.conns {
417 if fn(cn) {
418 if err := p.closeConn(cn); err != nil && firstErr == nil {
419 firstErr = err
420 }
421 }
422 }
423 return firstErr
424}
425
426func (p *ConnPool) Close() error {
427 if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
428 return ErrClosed
429 }
430 close(p.closedCh)
431
432 var firstErr error
433 p.connsMu.Lock()
434 for _, cn := range p.conns {
435 if err := p.closeConn(cn); err != nil && firstErr == nil {
436 firstErr = err
437 }
438 }
439 p.conns = nil
440 p.poolSize = 0
441 p.idleConns = nil
442 p.idleConnsLen = 0
443 p.connsMu.Unlock()
444
445 return firstErr
446}
447
448func (p *ConnPool) reaper(frequency time.Duration) {
449 ticker := time.NewTicker(frequency)
450 defer ticker.Stop()
451
452 for {
453 select {
454 case <-ticker.C:
455 // It is possible that ticker and closedCh arrive together,
456 // and select pseudo-randomly pick ticker case, we double
457 // check here to prevent being executed after closed.
458 if p.closed() {
459 return
460 }
461 _, err := p.ReapStaleConns()
462 if err != nil {
463 internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
464 continue
465 }
466 case <-p.closedCh:
467 return
468 }
469 }
470}
471
472func (p *ConnPool) ReapStaleConns() (int, error) {
473 var n int
474 for {
475 p.getTurn()
476
477 p.connsMu.Lock()
478 cn := p.reapStaleConn()
479 p.connsMu.Unlock()
480 p.freeTurn()
481
482 if cn != nil {
483 _ = p.closeConn(cn)
484 n++
485 } else {
486 break
487 }
488 }
489 atomic.AddUint32(&p.stats.StaleConns, uint32(n))
490 return n, nil
491}
492
493func (p *ConnPool) reapStaleConn() *Conn {
494 if len(p.idleConns) == 0 {
495 return nil
496 }
497
498 cn := p.idleConns[0]
499 if !p.isStaleConn(cn) {
500 return nil
501 }
502
503 p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
504 p.idleConnsLen--
505 p.removeConn(cn)
506
507 return cn
508}
509
510func (p *ConnPool) isStaleConn(cn *Conn) bool {
511 if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
512 return false
513 }
514
515 now := time.Now()
516 if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
517 return true
518 }
519 if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
520 return true
521 }
522
523 return false
524}