blob: 44a4e779dfae4531d91d1e0632091f6d607fab41 [file] [log] [blame]
Naveen Sampath04696f72022-06-13 15:19:14 +05301package pool
2
3import (
4 "context"
5 "errors"
6 "net"
7 "sync"
8 "sync/atomic"
9 "time"
10
11 "github.com/go-redis/redis/v8/internal"
12)
13
// ErrClosed is returned by pool operations attempted after the pool (or the
// client that owns it) has been closed.
var ErrClosed = errors.New("redis: client is closed")

// ErrPoolTimeout is returned when no slot in the connection pool became
// available before the configured pool timeout elapsed.
var ErrPoolTimeout = errors.New("redis: connection pool timeout")
21
// timers recycles *time.Timer values used by waitTurn so that waiting for a
// pool slot does not allocate a fresh timer every time. Timers produced by
// New are created and then immediately stopped, so they are handed out idle.
var timers = sync.Pool{
	New: func() interface{} {
		idle := time.NewTimer(time.Hour)
		idle.Stop()
		return idle
	},
}
29
// Stats is a snapshot of pool state combined with counters accumulated over
// the lifetime of the pool.
type Stats struct {
	Hits     uint32 // Get calls satisfied by an idle pooled connection
	Misses   uint32 // Get calls that found no idle connection and dialed instead
	Timeouts uint32 // waits for a pool slot that ended in a timeout

	TotalConns uint32 // connections currently tracked by the pool
	IdleConns  uint32 // connections currently counted as idle
	StaleConns uint32 // stale connections evicted by the reaper so far
}
40
41type Pooler interface {
42 NewConn(context.Context) (*Conn, error)
43 CloseConn(*Conn) error
44
45 Get(context.Context) (*Conn, error)
46 Put(context.Context, *Conn)
47 Remove(context.Context, *Conn, error)
48
49 Len() int
50 IdleLen() int
51 Stats() *Stats
52
53 Close() error
54}
55
56type Options struct {
57 Dialer func(context.Context) (net.Conn, error)
58 OnClose func(*Conn) error
59
60 PoolFIFO bool
61 PoolSize int
62 MinIdleConns int
63 MaxConnAge time.Duration
64 PoolTimeout time.Duration
65 IdleTimeout time.Duration
66 IdleCheckFrequency time.Duration
67}
68
// lastDialErrorWrap boxes an error so that it can live in an atomic.Value:
// atomic.Value requires every stored value to share one concrete type, while
// dial errors come in many concrete types.
type lastDialErrorWrap struct {
	err error // the most recent dial error
}
72
73type ConnPool struct {
74 opt *Options
75
76 dialErrorsNum uint32 // atomic
77
78 lastDialError atomic.Value
79
80 queue chan struct{}
81
82 connsMu sync.Mutex
83 conns []*Conn
84 idleConns []*Conn
85 poolSize int
86 idleConnsLen int
87
88 stats Stats
89
90 _closed uint32 // atomic
91 closedCh chan struct{}
92}
93
94var _ Pooler = (*ConnPool)(nil)
95
96func NewConnPool(opt *Options) *ConnPool {
97 p := &ConnPool{
98 opt: opt,
99
100 queue: make(chan struct{}, opt.PoolSize),
101 conns: make([]*Conn, 0, opt.PoolSize),
102 idleConns: make([]*Conn, 0, opt.PoolSize),
103 closedCh: make(chan struct{}),
104 }
105
106 p.connsMu.Lock()
107 p.checkMinIdleConns()
108 p.connsMu.Unlock()
109
110 if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
111 go p.reaper(opt.IdleCheckFrequency)
112 }
113
114 return p
115}
116
117func (p *ConnPool) checkMinIdleConns() {
118 if p.opt.MinIdleConns == 0 {
119 return
120 }
121 for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
122 p.poolSize++
123 p.idleConnsLen++
124
125 go func() {
126 err := p.addIdleConn()
127 if err != nil && err != ErrClosed {
128 p.connsMu.Lock()
129 p.poolSize--
130 p.idleConnsLen--
131 p.connsMu.Unlock()
132 }
133 }()
134 }
135}
136
137func (p *ConnPool) addIdleConn() error {
138 cn, err := p.dialConn(context.TODO(), true)
139 if err != nil {
140 return err
141 }
142
143 p.connsMu.Lock()
144 defer p.connsMu.Unlock()
145
146 // It is not allowed to add new connections to the closed connection pool.
147 if p.closed() {
148 _ = cn.Close()
149 return ErrClosed
150 }
151
152 p.conns = append(p.conns, cn)
153 p.idleConns = append(p.idleConns, cn)
154 return nil
155}
156
157func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
158 return p.newConn(ctx, false)
159}
160
161func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
162 cn, err := p.dialConn(ctx, pooled)
163 if err != nil {
164 return nil, err
165 }
166
167 p.connsMu.Lock()
168 defer p.connsMu.Unlock()
169
170 // It is not allowed to add new connections to the closed connection pool.
171 if p.closed() {
172 _ = cn.Close()
173 return nil, ErrClosed
174 }
175
176 p.conns = append(p.conns, cn)
177 if pooled {
178 // If pool is full remove the cn on next Put.
179 if p.poolSize >= p.opt.PoolSize {
180 cn.pooled = false
181 } else {
182 p.poolSize++
183 }
184 }
185
186 return cn, nil
187}
188
189func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
190 if p.closed() {
191 return nil, ErrClosed
192 }
193
194 if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
195 return nil, p.getLastDialError()
196 }
197
198 netConn, err := p.opt.Dialer(ctx)
199 if err != nil {
200 p.setLastDialError(err)
201 if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
202 go p.tryDial()
203 }
204 return nil, err
205 }
206
207 cn := NewConn(netConn)
208 cn.pooled = pooled
209 return cn, nil
210}
211
212func (p *ConnPool) tryDial() {
213 for {
214 if p.closed() {
215 return
216 }
217
218 conn, err := p.opt.Dialer(context.Background())
219 if err != nil {
220 p.setLastDialError(err)
221 time.Sleep(time.Second)
222 continue
223 }
224
225 atomic.StoreUint32(&p.dialErrorsNum, 0)
226 _ = conn.Close()
227 return
228 }
229}
230
231func (p *ConnPool) setLastDialError(err error) {
232 p.lastDialError.Store(&lastDialErrorWrap{err: err})
233}
234
235func (p *ConnPool) getLastDialError() error {
236 err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
237 if err != nil {
238 return err.err
239 }
240 return nil
241}
242
243// Get returns existed connection from the pool or creates a new one.
244func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
245 if p.closed() {
246 return nil, ErrClosed
247 }
248
249 if err := p.waitTurn(ctx); err != nil {
250 return nil, err
251 }
252
253 for {
254 p.connsMu.Lock()
255 cn, err := p.popIdle()
256 p.connsMu.Unlock()
257
258 if err != nil {
259 return nil, err
260 }
261
262 if cn == nil {
263 break
264 }
265
266 if p.isStaleConn(cn) {
267 _ = p.CloseConn(cn)
268 continue
269 }
270
271 atomic.AddUint32(&p.stats.Hits, 1)
272 return cn, nil
273 }
274
275 atomic.AddUint32(&p.stats.Misses, 1)
276
277 newcn, err := p.newConn(ctx, true)
278 if err != nil {
279 p.freeTurn()
280 return nil, err
281 }
282
283 return newcn, nil
284}
285
286func (p *ConnPool) getTurn() {
287 p.queue <- struct{}{}
288}
289
290func (p *ConnPool) waitTurn(ctx context.Context) error {
291 select {
292 case <-ctx.Done():
293 return ctx.Err()
294 default:
295 }
296
297 select {
298 case p.queue <- struct{}{}:
299 return nil
300 default:
301 }
302
303 timer := timers.Get().(*time.Timer)
304 timer.Reset(p.opt.PoolTimeout)
305
306 select {
307 case <-ctx.Done():
308 if !timer.Stop() {
309 <-timer.C
310 }
311 timers.Put(timer)
312 return ctx.Err()
313 case p.queue <- struct{}{}:
314 if !timer.Stop() {
315 <-timer.C
316 }
317 timers.Put(timer)
318 return nil
319 case <-timer.C:
320 timers.Put(timer)
321 atomic.AddUint32(&p.stats.Timeouts, 1)
322 return ErrPoolTimeout
323 }
324}
325
326func (p *ConnPool) freeTurn() {
327 <-p.queue
328}
329
330func (p *ConnPool) popIdle() (*Conn, error) {
331 if p.closed() {
332 return nil, ErrClosed
333 }
334 n := len(p.idleConns)
335 if n == 0 {
336 return nil, nil
337 }
338
339 var cn *Conn
340 if p.opt.PoolFIFO {
341 cn = p.idleConns[0]
342 copy(p.idleConns, p.idleConns[1:])
343 p.idleConns = p.idleConns[:n-1]
344 } else {
345 idx := n - 1
346 cn = p.idleConns[idx]
347 p.idleConns = p.idleConns[:idx]
348 }
349 p.idleConnsLen--
350 p.checkMinIdleConns()
351 return cn, nil
352}
353
354func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
355 if cn.rd.Buffered() > 0 {
356 internal.Logger.Printf(ctx, "Conn has unread data")
357 p.Remove(ctx, cn, BadConnError{})
358 return
359 }
360
361 if !cn.pooled {
362 p.Remove(ctx, cn, nil)
363 return
364 }
365
366 p.connsMu.Lock()
367 p.idleConns = append(p.idleConns, cn)
368 p.idleConnsLen++
369 p.connsMu.Unlock()
370 p.freeTurn()
371}
372
373func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
374 p.removeConnWithLock(cn)
375 p.freeTurn()
376 _ = p.closeConn(cn)
377}
378
379func (p *ConnPool) CloseConn(cn *Conn) error {
380 p.removeConnWithLock(cn)
381 return p.closeConn(cn)
382}
383
384func (p *ConnPool) removeConnWithLock(cn *Conn) {
385 p.connsMu.Lock()
386 p.removeConn(cn)
387 p.connsMu.Unlock()
388}
389
390func (p *ConnPool) removeConn(cn *Conn) {
391 for i, c := range p.conns {
392 if c == cn {
393 p.conns = append(p.conns[:i], p.conns[i+1:]...)
394 if cn.pooled {
395 p.poolSize--
396 p.checkMinIdleConns()
397 }
398 return
399 }
400 }
401}
402
403func (p *ConnPool) closeConn(cn *Conn) error {
404 if p.opt.OnClose != nil {
405 _ = p.opt.OnClose(cn)
406 }
407 return cn.Close()
408}
409
410// Len returns total number of connections.
411func (p *ConnPool) Len() int {
412 p.connsMu.Lock()
413 n := len(p.conns)
414 p.connsMu.Unlock()
415 return n
416}
417
418// IdleLen returns number of idle connections.
419func (p *ConnPool) IdleLen() int {
420 p.connsMu.Lock()
421 n := p.idleConnsLen
422 p.connsMu.Unlock()
423 return n
424}
425
426func (p *ConnPool) Stats() *Stats {
427 idleLen := p.IdleLen()
428 return &Stats{
429 Hits: atomic.LoadUint32(&p.stats.Hits),
430 Misses: atomic.LoadUint32(&p.stats.Misses),
431 Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
432
433 TotalConns: uint32(p.Len()),
434 IdleConns: uint32(idleLen),
435 StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
436 }
437}
438
439func (p *ConnPool) closed() bool {
440 return atomic.LoadUint32(&p._closed) == 1
441}
442
443func (p *ConnPool) Filter(fn func(*Conn) bool) error {
444 p.connsMu.Lock()
445 defer p.connsMu.Unlock()
446
447 var firstErr error
448 for _, cn := range p.conns {
449 if fn(cn) {
450 if err := p.closeConn(cn); err != nil && firstErr == nil {
451 firstErr = err
452 }
453 }
454 }
455 return firstErr
456}
457
458func (p *ConnPool) Close() error {
459 if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
460 return ErrClosed
461 }
462 close(p.closedCh)
463
464 var firstErr error
465 p.connsMu.Lock()
466 for _, cn := range p.conns {
467 if err := p.closeConn(cn); err != nil && firstErr == nil {
468 firstErr = err
469 }
470 }
471 p.conns = nil
472 p.poolSize = 0
473 p.idleConns = nil
474 p.idleConnsLen = 0
475 p.connsMu.Unlock()
476
477 return firstErr
478}
479
480func (p *ConnPool) reaper(frequency time.Duration) {
481 ticker := time.NewTicker(frequency)
482 defer ticker.Stop()
483
484 for {
485 select {
486 case <-ticker.C:
487 // It is possible that ticker and closedCh arrive together,
488 // and select pseudo-randomly pick ticker case, we double
489 // check here to prevent being executed after closed.
490 if p.closed() {
491 return
492 }
493 _, err := p.ReapStaleConns()
494 if err != nil {
495 internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
496 continue
497 }
498 case <-p.closedCh:
499 return
500 }
501 }
502}
503
504func (p *ConnPool) ReapStaleConns() (int, error) {
505 var n int
506 for {
507 p.getTurn()
508
509 p.connsMu.Lock()
510 cn := p.reapStaleConn()
511 p.connsMu.Unlock()
512
513 p.freeTurn()
514
515 if cn != nil {
516 _ = p.closeConn(cn)
517 n++
518 } else {
519 break
520 }
521 }
522 atomic.AddUint32(&p.stats.StaleConns, uint32(n))
523 return n, nil
524}
525
526func (p *ConnPool) reapStaleConn() *Conn {
527 if len(p.idleConns) == 0 {
528 return nil
529 }
530
531 cn := p.idleConns[0]
532 if !p.isStaleConn(cn) {
533 return nil
534 }
535
536 p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
537 p.idleConnsLen--
538 p.removeConn(cn)
539
540 return cn
541}
542
543func (p *ConnPool) isStaleConn(cn *Conn) bool {
544 if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
545 return false
546 }
547
548 now := time.Now()
549 if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
550 return true
551 }
552 if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
553 return true
554 }
555
556 return false
557}