blob: bcf8a2a94bd064ae331a0a8dd2ede097752b03a8 [file] [log] [blame]
Naveen Sampath04696f72022-06-13 15:19:14 +05301package redis
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "sync/atomic"
8 "time"
9
10 "github.com/go-redis/redis/v8/internal"
11 "github.com/go-redis/redis/v8/internal/pool"
12 "github.com/go-redis/redis/v8/internal/proto"
13)
14
15// Nil reply returned by Redis when key does not exist.
16const Nil = proto.Nil
17
18func SetLogger(logger internal.Logging) {
19 internal.Logger = logger
20}
21
22//------------------------------------------------------------------------------
23
24type Hook interface {
25 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
26 AfterProcess(ctx context.Context, cmd Cmder) error
27
28 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
29 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
30}
31
32type hooks struct {
33 hooks []Hook
34}
35
36func (hs *hooks) lock() {
37 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
38}
39
40func (hs hooks) clone() hooks {
41 clone := hs
42 clone.lock()
43 return clone
44}
45
46func (hs *hooks) AddHook(hook Hook) {
47 hs.hooks = append(hs.hooks, hook)
48}
49
50func (hs hooks) process(
51 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
52) error {
53 if len(hs.hooks) == 0 {
54 err := fn(ctx, cmd)
55 cmd.SetErr(err)
56 return err
57 }
58
59 var hookIndex int
60 var retErr error
61
62 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
63 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
64 if retErr != nil {
65 cmd.SetErr(retErr)
66 }
67 }
68
69 if retErr == nil {
70 retErr = fn(ctx, cmd)
71 cmd.SetErr(retErr)
72 }
73
74 for hookIndex--; hookIndex >= 0; hookIndex-- {
75 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
76 retErr = err
77 cmd.SetErr(retErr)
78 }
79 }
80
81 return retErr
82}
83
84func (hs hooks) processPipeline(
85 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
86) error {
87 if len(hs.hooks) == 0 {
88 err := fn(ctx, cmds)
89 return err
90 }
91
92 var hookIndex int
93 var retErr error
94
95 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
96 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
97 if retErr != nil {
98 setCmdsErr(cmds, retErr)
99 }
100 }
101
102 if retErr == nil {
103 retErr = fn(ctx, cmds)
104 }
105
106 for hookIndex--; hookIndex >= 0; hookIndex-- {
107 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
108 retErr = err
109 setCmdsErr(cmds, retErr)
110 }
111 }
112
113 return retErr
114}
115
116func (hs hooks) processTxPipeline(
117 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
118) error {
119 cmds = wrapMultiExec(ctx, cmds)
120 return hs.processPipeline(ctx, cmds, fn)
121}
122
123//------------------------------------------------------------------------------
124
125type baseClient struct {
126 opt *Options
127 connPool pool.Pooler
128
129 onClose func() error // hook called when client is closed
130}
131
132func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
133 return &baseClient{
134 opt: opt,
135 connPool: connPool,
136 }
137}
138
139func (c *baseClient) clone() *baseClient {
140 clone := *c
141 return &clone
142}
143
144func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
145 opt := c.opt.clone()
146 opt.ReadTimeout = timeout
147 opt.WriteTimeout = timeout
148
149 clone := c.clone()
150 clone.opt = opt
151
152 return clone
153}
154
155func (c *baseClient) String() string {
156 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
157}
158
159func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
160 cn, err := c.connPool.NewConn(ctx)
161 if err != nil {
162 return nil, err
163 }
164
165 err = c.initConn(ctx, cn)
166 if err != nil {
167 _ = c.connPool.CloseConn(cn)
168 return nil, err
169 }
170
171 return cn, nil
172}
173
174func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
175 if c.opt.Limiter != nil {
176 err := c.opt.Limiter.Allow()
177 if err != nil {
178 return nil, err
179 }
180 }
181
182 cn, err := c._getConn(ctx)
183 if err != nil {
184 if c.opt.Limiter != nil {
185 c.opt.Limiter.ReportResult(err)
186 }
187 return nil, err
188 }
189
190 return cn, nil
191}
192
193func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
194 cn, err := c.connPool.Get(ctx)
195 if err != nil {
196 return nil, err
197 }
198
199 if cn.Inited {
200 return cn, nil
201 }
202
203 if err := c.initConn(ctx, cn); err != nil {
204 c.connPool.Remove(ctx, cn, err)
205 if err := errors.Unwrap(err); err != nil {
206 return nil, err
207 }
208 return nil, err
209 }
210
211 return cn, nil
212}
213
214func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
215 if cn.Inited {
216 return nil
217 }
218 cn.Inited = true
219
220 if c.opt.Password == "" &&
221 c.opt.DB == 0 &&
222 !c.opt.readOnly &&
223 c.opt.OnConnect == nil {
224 return nil
225 }
226
227 connPool := pool.NewSingleConnPool(c.connPool, cn)
228 conn := newConn(ctx, c.opt, connPool)
229
230 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
231 if c.opt.Password != "" {
232 if c.opt.Username != "" {
233 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
234 } else {
235 pipe.Auth(ctx, c.opt.Password)
236 }
237 }
238
239 if c.opt.DB > 0 {
240 pipe.Select(ctx, c.opt.DB)
241 }
242
243 if c.opt.readOnly {
244 pipe.ReadOnly(ctx)
245 }
246
247 return nil
248 })
249 if err != nil {
250 return err
251 }
252
253 if c.opt.OnConnect != nil {
254 return c.opt.OnConnect(ctx, conn)
255 }
256 return nil
257}
258
259func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
260 if c.opt.Limiter != nil {
261 c.opt.Limiter.ReportResult(err)
262 }
263
264 if isBadConn(err, false, c.opt.Addr) {
265 c.connPool.Remove(ctx, cn, err)
266 } else {
267 c.connPool.Put(ctx, cn)
268 }
269}
270
271func (c *baseClient) withConn(
272 ctx context.Context, fn func(context.Context, *pool.Conn) error,
273) error {
274 cn, err := c.getConn(ctx)
275 if err != nil {
276 return err
277 }
278
279 defer func() {
280 c.releaseConn(ctx, cn, err)
281 }()
282
283 done := ctx.Done() //nolint:ifshort
284
285 if done == nil {
286 err = fn(ctx, cn)
287 return err
288 }
289
290 errc := make(chan error, 1)
291 go func() { errc <- fn(ctx, cn) }()
292
293 select {
294 case <-done:
295 _ = cn.Close()
296 // Wait for the goroutine to finish and send something.
297 <-errc
298
299 err = ctx.Err()
300 return err
301 case err = <-errc:
302 return err
303 }
304}
305
306func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
307 var lastErr error
308 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
309 attempt := attempt
310
311 retry, err := c._process(ctx, cmd, attempt)
312 if err == nil || !retry {
313 return err
314 }
315
316 lastErr = err
317 }
318 return lastErr
319}
320
321func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
322 if attempt > 0 {
323 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
324 return false, err
325 }
326 }
327
328 retryTimeout := uint32(1)
329 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
330 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
331 return writeCmd(wr, cmd)
332 })
333 if err != nil {
334 return err
335 }
336
337 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
338 if err != nil {
339 if cmd.readTimeout() == nil {
340 atomic.StoreUint32(&retryTimeout, 1)
341 }
342 return err
343 }
344
345 return nil
346 })
347 if err == nil {
348 return false, nil
349 }
350
351 retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
352 return retry, err
353}
354
355func (c *baseClient) retryBackoff(attempt int) time.Duration {
356 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
357}
358
359func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
360 if timeout := cmd.readTimeout(); timeout != nil {
361 t := *timeout
362 if t == 0 {
363 return 0
364 }
365 return t + 10*time.Second
366 }
367 return c.opt.ReadTimeout
368}
369
370// Close closes the client, releasing any open resources.
371//
372// It is rare to Close a Client, as the Client is meant to be
373// long-lived and shared between many goroutines.
374func (c *baseClient) Close() error {
375 var firstErr error
376 if c.onClose != nil {
377 if err := c.onClose(); err != nil {
378 firstErr = err
379 }
380 }
381 if err := c.connPool.Close(); err != nil && firstErr == nil {
382 firstErr = err
383 }
384 return firstErr
385}
386
387func (c *baseClient) getAddr() string {
388 return c.opt.Addr
389}
390
391func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
392 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
393}
394
395func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
396 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
397}
398
399type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
400
401func (c *baseClient) generalProcessPipeline(
402 ctx context.Context, cmds []Cmder, p pipelineProcessor,
403) error {
404 err := c._generalProcessPipeline(ctx, cmds, p)
405 if err != nil {
406 setCmdsErr(cmds, err)
407 return err
408 }
409 return cmdsFirstErr(cmds)
410}
411
412func (c *baseClient) _generalProcessPipeline(
413 ctx context.Context, cmds []Cmder, p pipelineProcessor,
414) error {
415 var lastErr error
416 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
417 if attempt > 0 {
418 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
419 return err
420 }
421 }
422
423 var canRetry bool
424 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
425 var err error
426 canRetry, err = p(ctx, cn, cmds)
427 return err
428 })
429 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
430 return lastErr
431 }
432 }
433 return lastErr
434}
435
436func (c *baseClient) pipelineProcessCmds(
437 ctx context.Context, cn *pool.Conn, cmds []Cmder,
438) (bool, error) {
439 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
440 return writeCmds(wr, cmds)
441 })
442 if err != nil {
443 return true, err
444 }
445
446 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
447 return pipelineReadCmds(rd, cmds)
448 })
449 return true, err
450}
451
452func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
453 for _, cmd := range cmds {
454 err := cmd.readReply(rd)
455 cmd.SetErr(err)
456 if err != nil && !isRedisError(err) {
457 return err
458 }
459 }
460 return nil
461}
462
463func (c *baseClient) txPipelineProcessCmds(
464 ctx context.Context, cn *pool.Conn, cmds []Cmder,
465) (bool, error) {
466 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
467 return writeCmds(wr, cmds)
468 })
469 if err != nil {
470 return true, err
471 }
472
473 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
474 statusCmd := cmds[0].(*StatusCmd)
475 // Trim multi and exec.
476 cmds = cmds[1 : len(cmds)-1]
477
478 err := txPipelineReadQueued(rd, statusCmd, cmds)
479 if err != nil {
480 return err
481 }
482
483 return pipelineReadCmds(rd, cmds)
484 })
485 return false, err
486}
487
488func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
489 if len(cmds) == 0 {
490 panic("not reached")
491 }
492 cmdCopy := make([]Cmder, len(cmds)+2)
493 cmdCopy[0] = NewStatusCmd(ctx, "multi")
494 copy(cmdCopy[1:], cmds)
495 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
496 return cmdCopy
497}
498
499func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
500 // Parse queued replies.
501 if err := statusCmd.readReply(rd); err != nil {
502 return err
503 }
504
505 for range cmds {
506 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
507 return err
508 }
509 }
510
511 // Parse number of replies.
512 line, err := rd.ReadLine()
513 if err != nil {
514 if err == Nil {
515 err = TxFailedErr
516 }
517 return err
518 }
519
520 switch line[0] {
521 case proto.ErrorReply:
522 return proto.ParseErrorReply(line)
523 case proto.ArrayReply:
524 // ok
525 default:
526 err := fmt.Errorf("redis: expected '*', but got line %q", line)
527 return err
528 }
529
530 return nil
531}
532
533//------------------------------------------------------------------------------
534
535// Client is a Redis client representing a pool of zero or more
536// underlying connections. It's safe for concurrent use by multiple
537// goroutines.
538type Client struct {
539 *baseClient
540 cmdable
541 hooks
542 ctx context.Context
543}
544
545// NewClient returns a client to the Redis Server specified by Options.
546func NewClient(opt *Options) *Client {
547 opt.init()
548
549 c := Client{
550 baseClient: newBaseClient(opt, newConnPool(opt)),
551 ctx: context.Background(),
552 }
553 c.cmdable = c.Process
554
555 return &c
556}
557
558func (c *Client) clone() *Client {
559 clone := *c
560 clone.cmdable = clone.Process
561 clone.hooks.lock()
562 return &clone
563}
564
565func (c *Client) WithTimeout(timeout time.Duration) *Client {
566 clone := c.clone()
567 clone.baseClient = c.baseClient.withTimeout(timeout)
568 return clone
569}
570
571func (c *Client) Context() context.Context {
572 return c.ctx
573}
574
575func (c *Client) WithContext(ctx context.Context) *Client {
576 if ctx == nil {
577 panic("nil context")
578 }
579 clone := c.clone()
580 clone.ctx = ctx
581 return clone
582}
583
584func (c *Client) Conn(ctx context.Context) *Conn {
585 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
586}
587
588// Do creates a Cmd from the args and processes the cmd.
589func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
590 cmd := NewCmd(ctx, args...)
591 _ = c.Process(ctx, cmd)
592 return cmd
593}
594
595func (c *Client) Process(ctx context.Context, cmd Cmder) error {
596 return c.hooks.process(ctx, cmd, c.baseClient.process)
597}
598
599func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
600 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
601}
602
603func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
604 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
605}
606
607// Options returns read-only Options that were used to create the client.
608func (c *Client) Options() *Options {
609 return c.opt
610}
611
612type PoolStats pool.Stats
613
614// PoolStats returns connection pool stats.
615func (c *Client) PoolStats() *PoolStats {
616 stats := c.connPool.Stats()
617 return (*PoolStats)(stats)
618}
619
620func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
621 return c.Pipeline().Pipelined(ctx, fn)
622}
623
624func (c *Client) Pipeline() Pipeliner {
625 pipe := Pipeline{
626 ctx: c.ctx,
627 exec: c.processPipeline,
628 }
629 pipe.init()
630 return &pipe
631}
632
633func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
634 return c.TxPipeline().Pipelined(ctx, fn)
635}
636
637// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
638func (c *Client) TxPipeline() Pipeliner {
639 pipe := Pipeline{
640 ctx: c.ctx,
641 exec: c.processTxPipeline,
642 }
643 pipe.init()
644 return &pipe
645}
646
647func (c *Client) pubSub() *PubSub {
648 pubsub := &PubSub{
649 opt: c.opt,
650
651 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
652 return c.newConn(ctx)
653 },
654 closeConn: c.connPool.CloseConn,
655 }
656 pubsub.init()
657 return pubsub
658}
659
660// Subscribe subscribes the client to the specified channels.
661// Channels can be omitted to create empty subscription.
662// Note that this method does not wait on a response from Redis, so the
663// subscription may not be active immediately. To force the connection to wait,
664// you may call the Receive() method on the returned *PubSub like so:
665//
666// sub := client.Subscribe(queryResp)
667// iface, err := sub.Receive()
668// if err != nil {
669// // handle error
670// }
671//
672// // Should be *Subscription, but others are possible if other actions have been
673// // taken on sub since it was created.
674// switch iface.(type) {
675// case *Subscription:
676// // subscribe succeeded
677// case *Message:
678// // received first message
679// case *Pong:
680// // pong received
681// default:
682// // handle error
683// }
684//
685// ch := sub.Channel()
686func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
687 pubsub := c.pubSub()
688 if len(channels) > 0 {
689 _ = pubsub.Subscribe(ctx, channels...)
690 }
691 return pubsub
692}
693
694// PSubscribe subscribes the client to the given patterns.
695// Patterns can be omitted to create empty subscription.
696func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
697 pubsub := c.pubSub()
698 if len(channels) > 0 {
699 _ = pubsub.PSubscribe(ctx, channels...)
700 }
701 return pubsub
702}
703
704//------------------------------------------------------------------------------
705
706type conn struct {
707 baseClient
708 cmdable
709 statefulCmdable
710 hooks // TODO: inherit hooks
711}
712
713// Conn represents a single Redis connection rather than a pool of connections.
714// Prefer running commands from Client unless there is a specific need
715// for a continuous single Redis connection.
716type Conn struct {
717 *conn
718 ctx context.Context
719}
720
721func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
722 c := Conn{
723 conn: &conn{
724 baseClient: baseClient{
725 opt: opt,
726 connPool: connPool,
727 },
728 },
729 ctx: ctx,
730 }
731 c.cmdable = c.Process
732 c.statefulCmdable = c.Process
733 return &c
734}
735
736func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
737 return c.hooks.process(ctx, cmd, c.baseClient.process)
738}
739
740func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
741 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
742}
743
744func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
745 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
746}
747
748func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
749 return c.Pipeline().Pipelined(ctx, fn)
750}
751
752func (c *Conn) Pipeline() Pipeliner {
753 pipe := Pipeline{
754 ctx: c.ctx,
755 exec: c.processPipeline,
756 }
757 pipe.init()
758 return &pipe
759}
760
761func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
762 return c.TxPipeline().Pipelined(ctx, fn)
763}
764
765// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
766func (c *Conn) TxPipeline() Pipeliner {
767 pipe := Pipeline{
768 ctx: c.ctx,
769 exec: c.processTxPipeline,
770 }
771 pipe.init()
772 return &pipe
773}