package redis

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cespare/xxhash/v2"
	"github.com/dgryski/go-rendezvous"
	"github.com/go-redis/redis/v8/internal"
	"github.com/go-redis/redis/v8/internal/hashtag"
	"github.com/go-redis/redis/v8/internal/pool"
	"github.com/go-redis/redis/v8/internal/rand"
)

var errRingShardsDown = errors.New("redis: all ring shards are down")

//------------------------------------------------------------------------------

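// ConsistentHash maps a key to the name of the shard that owns it. The
// default implementation uses rendezvous hashing, which keeps most keys on
// the same shard when shards are added or removed.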
type ConsistentHash interface {
	Get(string) string
}

type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}

func (w rendezvousWrapper) Get(key string) string {
	return w.Lookup(key)
}

func newRendezvous(shards []string) ConsistentHash {
	return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}

//------------------------------------------------------------------------------

// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
	// Map of name => host:port addresses of ring shards.
	Addrs map[string]string

	// NewClient creates a shard client with provided name and options.
	NewClient func(name string, opt *Options) *Client

	// Frequency of PING commands sent to check shard availability.
	// A shard is considered down after 3 consecutive failed checks.
	HeartbeatFrequency time.Duration

	// NewConsistentHash returns a consistent hash that is used
	// to distribute keys across the shards.
	//
	// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
	// for consistent hashing algorithmic tradeoffs.
	NewConsistentHash func(shards []string) ConsistentHash

	// Following options are copied from Options struct.

	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	Username string
	Password string
	DB       int

	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration

	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration

	TLSConfig *tls.Config
	Limiter   Limiter
}

func (opt *RingOptions) init() {
	if opt.NewClient == nil {
		opt.NewClient = func(name string, opt *Options) *Client {
			return NewClient(opt)
		}
	}

	if opt.HeartbeatFrequency == 0 {
		opt.HeartbeatFrequency = 500 * time.Millisecond
	}

	if opt.NewConsistentHash == nil {
		opt.NewConsistentHash = newRendezvous
	}

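	// For the retry settings below, -1 means "disabled" and 0 selects the default.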
	if opt.MaxRetries == -1 {
		opt.MaxRetries = 0
	} else if opt.MaxRetries == 0 {
		opt.MaxRetries = 3
	}
	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}
}

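// clientOptions derives per-shard Options from the ring options. MaxRetries
// is set to -1 (disabled) because the ring retries failed commands itself.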
func (opt *RingOptions) clientOptions() *Options {
	return &Options{
		Dialer:    opt.Dialer,
		OnConnect: opt.OnConnect,

		Username: opt.Username,
		Password: opt.Password,
		DB:       opt.DB,

		MaxRetries: -1,

		DialTimeout:  opt.DialTimeout,
		ReadTimeout:  opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,

		PoolSize:           opt.PoolSize,
		MinIdleConns:       opt.MinIdleConns,
		MaxConnAge:         opt.MaxConnAge,
		PoolTimeout:        opt.PoolTimeout,
		IdleTimeout:        opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,

		TLSConfig: opt.TLSConfig,
		Limiter:   opt.Limiter,
	}
}

//------------------------------------------------------------------------------

type ringShard struct {
	Client *Client
	down   int32
}

func newRingShard(opt *RingOptions, name, addr string) *ringShard {
	clopt := opt.clientOptions()
	clopt.Addr = addr

	return &ringShard{
		Client: opt.NewClient(name, clopt),
	}
}

func (shard *ringShard) String() string {
	var state string
	if shard.IsUp() {
		state = "up"
	} else {
		state = "down"
	}
	return fmt.Sprintf("%s is %s", shard.Client, state)
}

func (shard *ringShard) IsDown() bool {
	const threshold = 3
	return atomic.LoadInt32(&shard.down) >= threshold
}

func (shard *ringShard) IsUp() bool {
	return !shard.IsDown()
}

// Vote votes to set the shard state and reports whether the state changed.
// An up vote resets the failure counter; a down vote increments it until
// the shard crosses the down threshold.
func (shard *ringShard) Vote(up bool) bool {
	if up {
		changed := shard.IsDown()
		atomic.StoreInt32(&shard.down, 0)
		return changed
	}

	if shard.IsDown() {
		return false
	}

	atomic.AddInt32(&shard.down, 1)
	return shard.IsDown()
}

//------------------------------------------------------------------------------

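// ringShards tracks the shard clients and maintains a consistent hash over
// the shards that are currently up.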
type ringShards struct {
	opt *RingOptions

	mu       sync.RWMutex
	hash     ConsistentHash
	shards   map[string]*ringShard // read only
	list     []*ringShard          // read only
	numShard int
	closed   bool
}

func newRingShards(opt *RingOptions) *ringShards {
	shards := make(map[string]*ringShard, len(opt.Addrs))
	list := make([]*ringShard, 0, len(opt.Addrs))

	for name, addr := range opt.Addrs {
		shard := newRingShard(opt, name, addr)
		shards[name] = shard

		list = append(list, shard)
	}

	c := &ringShards{
		opt: opt,

		shards: shards,
		list:   list,
	}
	c.rebalance()

	return c
}

func (c *ringShards) List() []*ringShard {
	var list []*ringShard

	c.mu.RLock()
	if !c.closed {
		list = c.list
	}
	c.mu.RUnlock()

	return list
}

func (c *ringShards) Hash(key string) string {
	key = hashtag.Key(key)

	var hash string

	c.mu.RLock()
	if c.numShard > 0 {
		hash = c.hash.Get(key)
	}
	c.mu.RUnlock()

	return hash
}

func (c *ringShards) GetByKey(key string) (*ringShard, error) {
	key = hashtag.Key(key)

	c.mu.RLock()

	if c.closed {
		c.mu.RUnlock()
		return nil, pool.ErrClosed
	}

	if c.numShard == 0 {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	hash := c.hash.Get(key)
	if hash == "" {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	shard := c.shards[hash]
	c.mu.RUnlock()

	return shard, nil
}

func (c *ringShards) GetByName(shardName string) (*ringShard, error) {
	if shardName == "" {
		return c.Random()
	}

	c.mu.RLock()
	shard := c.shards[shardName]
	c.mu.RUnlock()
	return shard, nil
}

func (c *ringShards) Random() (*ringShard, error) {
	return c.GetByKey(strconv.Itoa(rand.Int()))
}

// Heartbeat monitors the state of each shard in the ring and rebalances
// the ring when a shard's state changes.
func (c *ringShards) Heartbeat(frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()

	ctx := context.Background()
	for range ticker.C {
		var rebalance bool

		for _, shard := range c.List() {
			err := shard.Client.Ping(ctx).Err()
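			// A pool timeout does not count as a failed check: the shard
			// itself may be healthy even when the local pool is exhausted.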
			isUp := err == nil || err == pool.ErrPoolTimeout
			if shard.Vote(isUp) {
				internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
				rebalance = true
			}
		}

		if rebalance {
			c.rebalance()
		}
	}
}

// rebalance removes dead shards from the ring.
func (c *ringShards) rebalance() {
	c.mu.RLock()
	shards := c.shards
	c.mu.RUnlock()

	liveShards := make([]string, 0, len(shards))

	for name, shard := range shards {
		if shard.IsUp() {
			liveShards = append(liveShards, name)
		}
	}

	hash := c.opt.NewConsistentHash(liveShards)

	c.mu.Lock()
	c.hash = hash
	c.numShard = len(liveShards)
	c.mu.Unlock()
}

func (c *ringShards) Len() int {
	c.mu.RLock()
	l := c.numShard
	c.mu.RUnlock()
	return l
}

func (c *ringShards) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	var firstErr error
	for _, shard := range c.shards {
		if err := shard.Client.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	c.hash = nil
	c.shards = nil
	c.list = nil

	return firstErr
}

//------------------------------------------------------------------------------

type ring struct {
	opt           *RingOptions
	shards        *ringShards
	cmdsInfoCache *cmdsInfoCache //nolint:structcheck
}

// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses shards that are available to the client and does not do any
// coordination when shard state is changed.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
	*ring
	cmdable
	hooks
	ctx context.Context
}

func NewRing(opt *RingOptions) *Ring {
	opt.init()

	ring := Ring{
		ring: &ring{
			opt:    opt,
			shards: newRingShards(opt),
		},
		ctx: context.Background(),
	}

	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
	ring.cmdable = ring.Process

	go ring.shards.Heartbeat(opt.HeartbeatFrequency)

	return &ring
}
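
// A minimal usage sketch, as seen from an importing package (the shard names
// and addresses are illustrative):
//
//	rdb := redis.NewRing(&redis.RingOptions{
//		Addrs: map[string]string{
//			"shard1": "localhost:6379",
//			"shard2": "localhost:6380",
//		},
//	})
//	err := rdb.Set(ctx, "key", "value", 0).Err()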

func (c *Ring) Context() context.Context {
	return c.ctx
}

func (c *Ring) WithContext(ctx context.Context) *Ring {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	clone.ctx = ctx
	return &clone
}

// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}

func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.process)
}

// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
	return c.opt
}

func (c *Ring) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

// PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats {
	shards := c.shards.List()
	var acc PoolStats
	for _, shard := range shards {
		s := shard.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts
		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
	}
	return &acc
}

// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
	return c.shards.Len()
}

// Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.Subscribe(ctx, channels...)
}

// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.PSubscribe(ctx, channels...)
}

// ForEachShard concurrently calls fn on each live shard in the ring and
// returns the first error encountered, if any.
func (c *Ring) ForEachShard(
	ctx context.Context,
	fn func(ctx context.Context, client *Client) error,
) error {
	shards := c.shards.List()
	var wg sync.WaitGroup
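	// A buffered channel of one, combined with the non-blocking sends below,
	// keeps only the first error reported by any shard.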
	errCh := make(chan error, 1)
	for _, shard := range shards {
		if shard.IsDown() {
			continue
		}

		wg.Add(1)
		go func(shard *ringShard) {
			defer wg.Done()
			err := fn(ctx, shard.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(shard)
	}
	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
	shards := c.shards.List()
	var firstErr error
	for _, shard := range shards {
		cmdsInfo, err := shard.Client.Command(ctx).Result()
		if err == nil {
			return cmdsInfo, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	if firstErr == nil {
		return nil, errRingShardsDown
	}
	return nil, firstErr
}

func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo {
	cmdsInfo, err := c.cmdsInfoCache.Get(ctx)
	if err != nil {
		return nil
	}
	info := cmdsInfo[name]
	if info == nil {
		internal.Logger.Printf(c.Context(), "info for cmd=%s not found", name)
	}
	return info
}

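// cmdShard returns the shard that owns the command's first key. Commands
// without keys go to a random shard.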
func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
	cmdInfo := c.cmdInfo(ctx, cmd.Name())
	pos := cmdFirstKeyPos(cmd, cmdInfo)
	if pos == 0 {
		return c.shards.Random()
	}
	firstKey := cmd.stringArg(pos)
	return c.shards.GetByKey(firstKey)
}

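// process runs cmd on the shard that owns its first key, retrying with
// backoff for up to MaxRetries attempts when the error is retryable.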
func (c *Ring) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		shard, err := c.cmdShard(ctx, cmd)
		if err != nil {
			return err
		}

		lastErr = shard.Client.Process(ctx, cmd)
		if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
			return lastErr
		}
	}
	return lastErr
}

func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Ring) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, false)
	})
}

func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

func (c *Ring) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, true)
	})
}

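// generalProcessPipeline groups the commands by the shard that owns their
// first key (commands without a key get an empty hash, which maps to a
// random shard) and executes each group on its shard concurrently.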
func (c *Ring) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, tx bool,
) error {
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(ctx, cmd.Name())
		hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
		if hash != "" {
			hash = c.shards.Hash(hash)
		}
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}

	var wg sync.WaitGroup
	for hash, cmds := range cmdsMap {
		wg.Add(1)
		go func(hash string, cmds []Cmder) {
			defer wg.Done()

			_ = c.processShardPipeline(ctx, hash, cmds, tx)
		}(hash, cmds)
	}

	wg.Wait()
	return cmdsFirstErr(cmds)
}

func (c *Ring) processShardPipeline(
	ctx context.Context, hash string, cmds []Cmder, tx bool,
) error {
	// TODO: retry?
	shard, err := c.shards.GetByName(hash)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	if tx {
		return shard.Client.processTxPipeline(ctx, cmds)
	}
	return shard.Client.processPipeline(ctx, cmds)
}

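// Watch runs fn in a transaction that watches the given keys. All keys must
// hash to the same shard.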
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}

	var shards []*ringShard
	for _, key := range keys {
		if key != "" {
			shard, err := c.shards.GetByKey(hashtag.Key(key))
			if err != nil {
				return err
			}

			shards = append(shards, shard)
		}
	}

	if len(shards) == 0 {
		return fmt.Errorf("redis: Watch requires at least one shard")
	}

	if len(shards) > 1 {
		for _, shard := range shards[1:] {
			if shard.Client != shards[0].Client {
				return fmt.Errorf("redis: Watch requires all keys to be in the same shard")
			}
		}
	}

	return shards[0].Client.Watch(ctx, fn, keys...)
}

// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
	return c.shards.Close()
}