blob: efad7f13881fda6fbc3c91bcdbc9592092fe95c7 [file] [log] [blame]
Joey Armstrong5f51f2e2023-01-17 17:06:26 -05001package redis
2
3import (
4 "context"
5 "fmt"
6 "time"
7
8 "github.com/go-redis/redis/v8/internal"
9 "github.com/go-redis/redis/v8/internal/pool"
10 "github.com/go-redis/redis/v8/internal/proto"
11 "go.opentelemetry.io/otel/api/trace"
12 "go.opentelemetry.io/otel/label"
13)
14
15// Nil reply returned by Redis when key does not exist.
16const Nil = proto.Nil
17
18func SetLogger(logger internal.Logging) {
19 internal.Logger = logger
20}
21
22//------------------------------------------------------------------------------
23
24type Hook interface {
25 BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
26 AfterProcess(ctx context.Context, cmd Cmder) error
27
28 BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
29 AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
30}
31
32type hooks struct {
33 hooks []Hook
34}
35
36func (hs *hooks) lock() {
37 hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
38}
39
40func (hs hooks) clone() hooks {
41 clone := hs
42 clone.lock()
43 return clone
44}
45
46func (hs *hooks) AddHook(hook Hook) {
47 hs.hooks = append(hs.hooks, hook)
48}
49
50func (hs hooks) process(
51 ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
52) error {
53 if len(hs.hooks) == 0 {
54 err := hs.withContext(ctx, func() error {
55 return fn(ctx, cmd)
56 })
57 cmd.SetErr(err)
58 return err
59 }
60
61 var hookIndex int
62 var retErr error
63
64 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
65 ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
66 if retErr != nil {
67 cmd.SetErr(retErr)
68 }
69 }
70
71 if retErr == nil {
72 retErr = hs.withContext(ctx, func() error {
73 return fn(ctx, cmd)
74 })
75 cmd.SetErr(retErr)
76 }
77
78 for hookIndex--; hookIndex >= 0; hookIndex-- {
79 if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
80 retErr = err
81 cmd.SetErr(retErr)
82 }
83 }
84
85 return retErr
86}
87
88func (hs hooks) processPipeline(
89 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
90) error {
91 if len(hs.hooks) == 0 {
92 err := hs.withContext(ctx, func() error {
93 return fn(ctx, cmds)
94 })
95 return err
96 }
97
98 var hookIndex int
99 var retErr error
100
101 for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
102 ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
103 if retErr != nil {
104 setCmdsErr(cmds, retErr)
105 }
106 }
107
108 if retErr == nil {
109 retErr = hs.withContext(ctx, func() error {
110 return fn(ctx, cmds)
111 })
112 }
113
114 for hookIndex--; hookIndex >= 0; hookIndex-- {
115 if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
116 retErr = err
117 setCmdsErr(cmds, retErr)
118 }
119 }
120
121 return retErr
122}
123
124func (hs hooks) processTxPipeline(
125 ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
126) error {
127 cmds = wrapMultiExec(ctx, cmds)
128 return hs.processPipeline(ctx, cmds, fn)
129}
130
131func (hs hooks) withContext(ctx context.Context, fn func() error) error {
132 done := ctx.Done()
133 if done == nil {
134 return fn()
135 }
136
137 errc := make(chan error, 1)
138 go func() { errc <- fn() }()
139
140 select {
141 case <-done:
142 return ctx.Err()
143 case err := <-errc:
144 return err
145 }
146}
147
148//------------------------------------------------------------------------------
149
150type baseClient struct {
151 opt *Options
152 connPool pool.Pooler
153
154 onClose func() error // hook called when client is closed
155}
156
157func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
158 return &baseClient{
159 opt: opt,
160 connPool: connPool,
161 }
162}
163
164func (c *baseClient) clone() *baseClient {
165 clone := *c
166 return &clone
167}
168
169func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
170 opt := c.opt.clone()
171 opt.ReadTimeout = timeout
172 opt.WriteTimeout = timeout
173
174 clone := c.clone()
175 clone.opt = opt
176
177 return clone
178}
179
180func (c *baseClient) String() string {
181 return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
182}
183
184func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
185 cn, err := c.connPool.NewConn(ctx)
186 if err != nil {
187 return nil, err
188 }
189
190 err = c.initConn(ctx, cn)
191 if err != nil {
192 _ = c.connPool.CloseConn(cn)
193 return nil, err
194 }
195
196 return cn, nil
197}
198
199func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
200 if c.opt.Limiter != nil {
201 err := c.opt.Limiter.Allow()
202 if err != nil {
203 return nil, err
204 }
205 }
206
207 cn, err := c._getConn(ctx)
208 if err != nil {
209 if c.opt.Limiter != nil {
210 c.opt.Limiter.ReportResult(err)
211 }
212 return nil, err
213 }
214
215 return cn, nil
216}
217
218func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
219 cn, err := c.connPool.Get(ctx)
220 if err != nil {
221 return nil, err
222 }
223
224 if cn.Inited {
225 return cn, nil
226 }
227
228 err = internal.WithSpan(ctx, "redis.init_conn", func(ctx context.Context, span trace.Span) error {
229 return c.initConn(ctx, cn)
230 })
231 if err != nil {
232 c.connPool.Remove(ctx, cn, err)
233 if err := internal.Unwrap(err); err != nil {
234 return nil, err
235 }
236 return nil, err
237 }
238
239 return cn, nil
240}
241
242func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
243 if cn.Inited {
244 return nil
245 }
246 cn.Inited = true
247
248 if c.opt.Password == "" &&
249 c.opt.DB == 0 &&
250 !c.opt.readOnly &&
251 c.opt.OnConnect == nil {
252 return nil
253 }
254
255 connPool := pool.NewSingleConnPool(c.connPool, cn)
256 conn := newConn(ctx, c.opt, connPool)
257
258 _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
259 if c.opt.Password != "" {
260 if c.opt.Username != "" {
261 pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
262 } else {
263 pipe.Auth(ctx, c.opt.Password)
264 }
265 }
266
267 if c.opt.DB > 0 {
268 pipe.Select(ctx, c.opt.DB)
269 }
270
271 if c.opt.readOnly {
272 pipe.ReadOnly(ctx)
273 }
274
275 return nil
276 })
277 if err != nil {
278 return err
279 }
280
281 if c.opt.OnConnect != nil {
282 return c.opt.OnConnect(ctx, conn)
283 }
284 return nil
285}
286
287func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
288 if c.opt.Limiter != nil {
289 c.opt.Limiter.ReportResult(err)
290 }
291
292 if isBadConn(err, false) {
293 c.connPool.Remove(ctx, cn, err)
294 } else {
295 c.connPool.Put(ctx, cn)
296 }
297}
298
299func (c *baseClient) withConn(
300 ctx context.Context, fn func(context.Context, *pool.Conn) error,
301) error {
302 return internal.WithSpan(ctx, "redis.with_conn", func(ctx context.Context, span trace.Span) error {
303 cn, err := c.getConn(ctx)
304 if err != nil {
305 return err
306 }
307
308 if span.IsRecording() {
309 if remoteAddr := cn.RemoteAddr(); remoteAddr != nil {
310 span.SetAttributes(label.String("net.peer.ip", remoteAddr.String()))
311 }
312 }
313
314 defer func() {
315 c.releaseConn(ctx, cn, err)
316 }()
317
318 err = fn(ctx, cn)
319 return err
320 })
321}
322
323func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
324 var lastErr error
325 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
326 attempt := attempt
327
328 var retry bool
329 err := internal.WithSpan(ctx, "redis.process", func(ctx context.Context, span trace.Span) error {
330 if attempt > 0 {
331 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
332 return err
333 }
334 }
335
336 retryTimeout := true
337 err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
338 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
339 return writeCmd(wr, cmd)
340 })
341 if err != nil {
342 return err
343 }
344
345 err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
346 if err != nil {
347 retryTimeout = cmd.readTimeout() == nil
348 return err
349 }
350
351 return nil
352 })
353 if err == nil {
354 return nil
355 }
356 retry = shouldRetry(err, retryTimeout)
357 return err
358 })
359 if err == nil || !retry {
360 return err
361 }
362 lastErr = err
363 }
364 return lastErr
365}
366
367func (c *baseClient) retryBackoff(attempt int) time.Duration {
368 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
369}
370
371func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
372 if timeout := cmd.readTimeout(); timeout != nil {
373 t := *timeout
374 if t == 0 {
375 return 0
376 }
377 return t + 10*time.Second
378 }
379 return c.opt.ReadTimeout
380}
381
382// Close closes the client, releasing any open resources.
383//
384// It is rare to Close a Client, as the Client is meant to be
385// long-lived and shared between many goroutines.
386func (c *baseClient) Close() error {
387 var firstErr error
388 if c.onClose != nil {
389 if err := c.onClose(); err != nil {
390 firstErr = err
391 }
392 }
393 if err := c.connPool.Close(); err != nil && firstErr == nil {
394 firstErr = err
395 }
396 return firstErr
397}
398
399func (c *baseClient) getAddr() string {
400 return c.opt.Addr
401}
402
403func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
404 return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
405}
406
407func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
408 return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
409}
410
411type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
412
413func (c *baseClient) generalProcessPipeline(
414 ctx context.Context, cmds []Cmder, p pipelineProcessor,
415) error {
416 err := c._generalProcessPipeline(ctx, cmds, p)
417 if err != nil {
418 setCmdsErr(cmds, err)
419 return err
420 }
421 return cmdsFirstErr(cmds)
422}
423
424func (c *baseClient) _generalProcessPipeline(
425 ctx context.Context, cmds []Cmder, p pipelineProcessor,
426) error {
427 var lastErr error
428 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
429 if attempt > 0 {
430 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
431 return err
432 }
433 }
434
435 var canRetry bool
436 lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
437 var err error
438 canRetry, err = p(ctx, cn, cmds)
439 return err
440 })
441 if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
442 return lastErr
443 }
444 }
445 return lastErr
446}
447
448func (c *baseClient) pipelineProcessCmds(
449 ctx context.Context, cn *pool.Conn, cmds []Cmder,
450) (bool, error) {
451 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
452 return writeCmds(wr, cmds)
453 })
454 if err != nil {
455 return true, err
456 }
457
458 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
459 return pipelineReadCmds(rd, cmds)
460 })
461 return true, err
462}
463
464func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
465 for _, cmd := range cmds {
466 err := cmd.readReply(rd)
467 cmd.SetErr(err)
468 if err != nil && !isRedisError(err) {
469 return err
470 }
471 }
472 return nil
473}
474
475func (c *baseClient) txPipelineProcessCmds(
476 ctx context.Context, cn *pool.Conn, cmds []Cmder,
477) (bool, error) {
478 err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
479 return writeCmds(wr, cmds)
480 })
481 if err != nil {
482 return true, err
483 }
484
485 err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
486 statusCmd := cmds[0].(*StatusCmd)
487 // Trim multi and exec.
488 cmds = cmds[1 : len(cmds)-1]
489
490 err := txPipelineReadQueued(rd, statusCmd, cmds)
491 if err != nil {
492 return err
493 }
494
495 return pipelineReadCmds(rd, cmds)
496 })
497 return false, err
498}
499
500func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
501 if len(cmds) == 0 {
502 panic("not reached")
503 }
504 cmdCopy := make([]Cmder, len(cmds)+2)
505 cmdCopy[0] = NewStatusCmd(ctx, "multi")
506 copy(cmdCopy[1:], cmds)
507 cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
508 return cmdCopy
509}
510
511func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
512 // Parse queued replies.
513 if err := statusCmd.readReply(rd); err != nil {
514 return err
515 }
516
517 for range cmds {
518 if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
519 return err
520 }
521 }
522
523 // Parse number of replies.
524 line, err := rd.ReadLine()
525 if err != nil {
526 if err == Nil {
527 err = TxFailedErr
528 }
529 return err
530 }
531
532 switch line[0] {
533 case proto.ErrorReply:
534 return proto.ParseErrorReply(line)
535 case proto.ArrayReply:
536 // ok
537 default:
538 err := fmt.Errorf("redis: expected '*', but got line %q", line)
539 return err
540 }
541
542 return nil
543}
544
545//------------------------------------------------------------------------------
546
547// Client is a Redis client representing a pool of zero or more
548// underlying connections. It's safe for concurrent use by multiple
549// goroutines.
550type Client struct {
551 *baseClient
552 cmdable
553 hooks
554 ctx context.Context
555}
556
557// NewClient returns a client to the Redis Server specified by Options.
558func NewClient(opt *Options) *Client {
559 opt.init()
560
561 c := Client{
562 baseClient: newBaseClient(opt, newConnPool(opt)),
563 ctx: context.Background(),
564 }
565 c.cmdable = c.Process
566
567 return &c
568}
569
570func (c *Client) clone() *Client {
571 clone := *c
572 clone.cmdable = clone.Process
573 clone.hooks.lock()
574 return &clone
575}
576
577func (c *Client) WithTimeout(timeout time.Duration) *Client {
578 clone := c.clone()
579 clone.baseClient = c.baseClient.withTimeout(timeout)
580 return clone
581}
582
583func (c *Client) Context() context.Context {
584 return c.ctx
585}
586
587func (c *Client) WithContext(ctx context.Context) *Client {
588 if ctx == nil {
589 panic("nil context")
590 }
591 clone := c.clone()
592 clone.ctx = ctx
593 return clone
594}
595
596func (c *Client) Conn(ctx context.Context) *Conn {
597 return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
598}
599
600// Do creates a Cmd from the args and processes the cmd.
601func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
602 cmd := NewCmd(ctx, args...)
603 _ = c.Process(ctx, cmd)
604 return cmd
605}
606
607func (c *Client) Process(ctx context.Context, cmd Cmder) error {
608 return c.hooks.process(ctx, cmd, c.baseClient.process)
609}
610
611func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
612 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
613}
614
615func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
616 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
617}
618
619// Options returns read-only Options that were used to create the client.
620func (c *Client) Options() *Options {
621 return c.opt
622}
623
624type PoolStats pool.Stats
625
626// PoolStats returns connection pool stats.
627func (c *Client) PoolStats() *PoolStats {
628 stats := c.connPool.Stats()
629 return (*PoolStats)(stats)
630}
631
632func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
633 return c.Pipeline().Pipelined(ctx, fn)
634}
635
636func (c *Client) Pipeline() Pipeliner {
637 pipe := Pipeline{
638 ctx: c.ctx,
639 exec: c.processPipeline,
640 }
641 pipe.init()
642 return &pipe
643}
644
645func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
646 return c.TxPipeline().Pipelined(ctx, fn)
647}
648
649// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
650func (c *Client) TxPipeline() Pipeliner {
651 pipe := Pipeline{
652 ctx: c.ctx,
653 exec: c.processTxPipeline,
654 }
655 pipe.init()
656 return &pipe
657}
658
659func (c *Client) pubSub() *PubSub {
660 pubsub := &PubSub{
661 opt: c.opt,
662
663 newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
664 return c.newConn(ctx)
665 },
666 closeConn: c.connPool.CloseConn,
667 }
668 pubsub.init()
669 return pubsub
670}
671
672// Subscribe subscribes the client to the specified channels.
673// Channels can be omitted to create empty subscription.
674// Note that this method does not wait on a response from Redis, so the
675// subscription may not be active immediately. To force the connection to wait,
676// you may call the Receive() method on the returned *PubSub like so:
677//
678// sub := client.Subscribe(queryResp)
679// iface, err := sub.Receive()
680// if err != nil {
681// // handle error
682// }
683//
684// // Should be *Subscription, but others are possible if other actions have been
685// // taken on sub since it was created.
686// switch iface.(type) {
687// case *Subscription:
688// // subscribe succeeded
689// case *Message:
690// // received first message
691// case *Pong:
692// // pong received
693// default:
694// // handle error
695// }
696//
697// ch := sub.Channel()
698func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
699 pubsub := c.pubSub()
700 if len(channels) > 0 {
701 _ = pubsub.Subscribe(ctx, channels...)
702 }
703 return pubsub
704}
705
706// PSubscribe subscribes the client to the given patterns.
707// Patterns can be omitted to create empty subscription.
708func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
709 pubsub := c.pubSub()
710 if len(channels) > 0 {
711 _ = pubsub.PSubscribe(ctx, channels...)
712 }
713 return pubsub
714}
715
716//------------------------------------------------------------------------------
717
718type conn struct {
719 baseClient
720 cmdable
721 statefulCmdable
722 hooks // TODO: inherit hooks
723}
724
725// Conn is like Client, but its pool contains single connection.
726type Conn struct {
727 *conn
728 ctx context.Context
729}
730
731func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
732 c := Conn{
733 conn: &conn{
734 baseClient: baseClient{
735 opt: opt,
736 connPool: connPool,
737 },
738 },
739 ctx: ctx,
740 }
741 c.cmdable = c.Process
742 c.statefulCmdable = c.Process
743 return &c
744}
745
746func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
747 return c.hooks.process(ctx, cmd, c.baseClient.process)
748}
749
750func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
751 return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
752}
753
754func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
755 return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
756}
757
758func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
759 return c.Pipeline().Pipelined(ctx, fn)
760}
761
762func (c *Conn) Pipeline() Pipeliner {
763 pipe := Pipeline{
764 ctx: c.ctx,
765 exec: c.processPipeline,
766 }
767 pipe.init()
768 return &pipe
769}
770
771func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
772 return c.TxPipeline().Pipelined(ctx, fn)
773}
774
775// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
776func (c *Conn) TxPipeline() Pipeliner {
777 pipe := Pipeline{
778 ctx: c.ctx,
779 exec: c.processTxPipeline,
780 }
781 pipe.init()
782 return &pipe
783}