package redis

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cespare/xxhash/v2"
	rendezvous "github.com/dgryski/go-rendezvous" //nolint

	"github.com/go-redis/redis/v8/internal"
	"github.com/go-redis/redis/v8/internal/hashtag"
	"github.com/go-redis/redis/v8/internal/pool"
	"github.com/go-redis/redis/v8/internal/rand"
)

var errRingShardsDown = errors.New("redis: all ring shards are down")

//------------------------------------------------------------------------------

type ConsistentHash interface {
	Get(string) string
}

type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}

func (w rendezvousWrapper) Get(key string) string {
	return w.Lookup(key)
}

func newRendezvous(shards []string) ConsistentHash {
	return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}
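
// A custom ConsistentHash can be plugged in through RingOptions.NewConsistentHash.
// The sketch below is illustrative only: modKeyHash is a hypothetical type that
// satisfies the interface with a plain modulo hash (not a true consistent hash),
// just to show the expected shape of Get (key in, shard name out).
//
//	type modKeyHash struct{ shards []string }
//
//	func (h modKeyHash) Get(key string) string {
//		if len(h.shards) == 0 {
//			return ""
//		}
//		return h.shards[xxhash.Sum64String(key)%uint64(len(h.shards))]
//	}
//
//	opt := &redis.RingOptions{
//		Addrs:             map[string]string{"shard1": ":7000", "shard2": ":7001"},
//		NewConsistentHash: func(shards []string) redis.ConsistentHash { return modKeyHash{shards} },
//	}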

//------------------------------------------------------------------------------

// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
	// Map of name => host:port addresses of ring shards.
	Addrs map[string]string

	// NewClient creates a shard client with provided name and options.
	NewClient func(name string, opt *Options) *Client

	// Frequency of PING commands sent to check shard availability.
	// A shard is considered down after 3 consecutive failed checks.
	HeartbeatFrequency time.Duration

	// NewConsistentHash returns a consistent hash that is used
	// to distribute keys across the shards.
	//
	// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
	// for consistent hashing algorithmic tradeoffs.
	NewConsistentHash func(shards []string) ConsistentHash

	// The following options are copied from the Options struct.

	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	Username string
	Password string
	DB       int

	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration

	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO bool

	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration

	TLSConfig *tls.Config
	Limiter   Limiter
}

func (opt *RingOptions) init() {
	if opt.NewClient == nil {
		opt.NewClient = func(name string, opt *Options) *Client {
			return NewClient(opt)
		}
	}

	if opt.HeartbeatFrequency == 0 {
		opt.HeartbeatFrequency = 500 * time.Millisecond
	}

	if opt.NewConsistentHash == nil {
		opt.NewConsistentHash = newRendezvous
	}

	if opt.MaxRetries == -1 {
		opt.MaxRetries = 0
	} else if opt.MaxRetries == 0 {
		opt.MaxRetries = 3
	}
	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}
}

func (opt *RingOptions) clientOptions() *Options {
	return &Options{
		Dialer:    opt.Dialer,
		OnConnect: opt.OnConnect,

		Username: opt.Username,
		Password: opt.Password,
		DB:       opt.DB,

		MaxRetries: -1,

		DialTimeout:  opt.DialTimeout,
		ReadTimeout:  opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,

		PoolFIFO:           opt.PoolFIFO,
		PoolSize:           opt.PoolSize,
		MinIdleConns:       opt.MinIdleConns,
		MaxConnAge:         opt.MaxConnAge,
		PoolTimeout:        opt.PoolTimeout,
		IdleTimeout:        opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,

		TLSConfig: opt.TLSConfig,
		Limiter:   opt.Limiter,
	}
}

//------------------------------------------------------------------------------

type ringShard struct {
	Client *Client
	down   int32
}

func newRingShard(opt *RingOptions, name, addr string) *ringShard {
	clopt := opt.clientOptions()
	clopt.Addr = addr

	return &ringShard{
		Client: opt.NewClient(name, clopt),
	}
}

func (shard *ringShard) String() string {
	var state string
	if shard.IsUp() {
		state = "up"
	} else {
		state = "down"
	}
	return fmt.Sprintf("%s is %s", shard.Client, state)
}

func (shard *ringShard) IsDown() bool {
	const threshold = 3
	return atomic.LoadInt32(&shard.down) >= threshold
}

func (shard *ringShard) IsUp() bool {
	return !shard.IsDown()
}

// Vote votes to set the shard state and returns true if the state was changed.
func (shard *ringShard) Vote(up bool) bool {
	if up {
		changed := shard.IsDown()
		atomic.StoreInt32(&shard.down, 0)
		return changed
	}

	if shard.IsDown() {
		return false
	}

	atomic.AddInt32(&shard.down, 1)
	return shard.IsDown()
}

//------------------------------------------------------------------------------

type ringShards struct {
	opt *RingOptions

	mu       sync.RWMutex
	hash     ConsistentHash
	shards   map[string]*ringShard // read only
	list     []*ringShard          // read only
	numShard int
	closed   bool
}

func newRingShards(opt *RingOptions) *ringShards {
	shards := make(map[string]*ringShard, len(opt.Addrs))
	list := make([]*ringShard, 0, len(opt.Addrs))

	for name, addr := range opt.Addrs {
		shard := newRingShard(opt, name, addr)
		shards[name] = shard

		list = append(list, shard)
	}

	c := &ringShards{
		opt: opt,

		shards: shards,
		list:   list,
	}
	c.rebalance()

	return c
}

func (c *ringShards) List() []*ringShard {
	var list []*ringShard

	c.mu.RLock()
	if !c.closed {
		list = c.list
	}
	c.mu.RUnlock()

	return list
}

func (c *ringShards) Hash(key string) string {
	key = hashtag.Key(key)

	var hash string

	c.mu.RLock()
	if c.numShard > 0 {
		hash = c.hash.Get(key)
	}
	c.mu.RUnlock()

	return hash
}

func (c *ringShards) GetByKey(key string) (*ringShard, error) {
	key = hashtag.Key(key)

	c.mu.RLock()

	if c.closed {
		c.mu.RUnlock()
		return nil, pool.ErrClosed
	}

	if c.numShard == 0 {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	hash := c.hash.Get(key)
	if hash == "" {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	shard := c.shards[hash]
	c.mu.RUnlock()

	return shard, nil
}

func (c *ringShards) GetByName(shardName string) (*ringShard, error) {
	if shardName == "" {
		return c.Random()
	}

	c.mu.RLock()
	shard := c.shards[shardName]
	c.mu.RUnlock()
	return shard, nil
}

func (c *ringShards) Random() (*ringShard, error) {
	return c.GetByKey(strconv.Itoa(rand.Int()))
}

// Heartbeat monitors the state of each shard in the ring.
func (c *ringShards) Heartbeat(frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()

	ctx := context.Background()
	for range ticker.C {
		var rebalance bool

		for _, shard := range c.List() {
			err := shard.Client.Ping(ctx).Err()
			isUp := err == nil || err == pool.ErrPoolTimeout
			if shard.Vote(isUp) {
				internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
				rebalance = true
			}
		}

		if rebalance {
			c.rebalance()
		}
	}
}

// rebalance removes dead shards from the Ring.
func (c *ringShards) rebalance() {
	c.mu.RLock()
	shards := c.shards
	c.mu.RUnlock()

	liveShards := make([]string, 0, len(shards))

	for name, shard := range shards {
		if shard.IsUp() {
			liveShards = append(liveShards, name)
		}
	}

	hash := c.opt.NewConsistentHash(liveShards)

	c.mu.Lock()
	c.hash = hash
	c.numShard = len(liveShards)
	c.mu.Unlock()
}

func (c *ringShards) Len() int {
	c.mu.RLock()
	l := c.numShard
	c.mu.RUnlock()
	return l
}

func (c *ringShards) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	var firstErr error
	for _, shard := range c.shards {
		if err := shard.Client.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	c.hash = nil
	c.shards = nil
	c.list = nil

	return firstErr
}

//------------------------------------------------------------------------------

type ring struct {
	opt           *RingOptions
	shards        *ringShards
	cmdsInfoCache *cmdsInfoCache //nolint:structcheck
}

// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses shards that are available to the client and does not do any
// coordination when shard state is changed.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
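//
// A minimal usage sketch (shard names and addresses are placeholders):
//
//	rdb := redis.NewRing(&redis.RingOptions{
//		Addrs: map[string]string{
//			"shard1": ":7000",
//			"shard2": ":7001",
//		},
//	})
//	defer rdb.Close()
//
//	ctx := context.Background()
//	if err := rdb.Set(ctx, "key", "value", 0).Err(); err != nil {
//		// handle the error
//	}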
type Ring struct {
	*ring
	cmdable
	hooks
	ctx context.Context
}

func NewRing(opt *RingOptions) *Ring {
	opt.init()

	ring := Ring{
		ring: &ring{
			opt:    opt,
			shards: newRingShards(opt),
		},
		ctx: context.Background(),
	}

	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
	ring.cmdable = ring.Process

	go ring.shards.Heartbeat(opt.HeartbeatFrequency)

	return &ring
}

func (c *Ring) Context() context.Context {
	return c.ctx
}

func (c *Ring) WithContext(ctx context.Context) *Ring {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	clone.ctx = ctx
	return &clone
}

// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}

func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.process)
}

// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
	return c.opt
}

func (c *Ring) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

// PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats {
	shards := c.shards.List()
	var acc PoolStats
	for _, shard := range shards {
		s := shard.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts
		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
	}
	return &acc
}

// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
	return c.shards.Len()
}

// Subscribe subscribes the client to the specified channels.
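// The subscription is created on the shard selected by the first channel name.
// A minimal usage sketch (the channel name is a placeholder):
//
//	pubsub := rdb.Subscribe(ctx, "mychannel")
//	defer pubsub.Close()
//
//	for msg := range pubsub.Channel() {
//		fmt.Println(msg.Channel, msg.Payload)
//	}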
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.Subscribe(ctx, channels...)
}

// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.PSubscribe(ctx, channels...)
}

// ForEachShard concurrently calls fn on each live shard in the ring.
// It returns the first error, if any.
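//
// A minimal usage sketch (FlushDB here is just an example command):
//
//	err := rdb.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
//		return shard.FlushDB(ctx).Err()
//	})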
func (c *Ring) ForEachShard(
	ctx context.Context,
	fn func(ctx context.Context, client *Client) error,
) error {
	shards := c.shards.List()
	var wg sync.WaitGroup
	errCh := make(chan error, 1)
	for _, shard := range shards {
		if shard.IsDown() {
			continue
		}

		wg.Add(1)
		go func(shard *ringShard) {
			defer wg.Done()
			err := fn(ctx, shard.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(shard)
	}
	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
	shards := c.shards.List()
	var firstErr error
	for _, shard := range shards {
		cmdsInfo, err := shard.Client.Command(ctx).Result()
		if err == nil {
			return cmdsInfo, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	if firstErr == nil {
		return nil, errRingShardsDown
	}
	return nil, firstErr
}

func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo {
	cmdsInfo, err := c.cmdsInfoCache.Get(ctx)
	if err != nil {
		return nil
	}
	info := cmdsInfo[name]
	if info == nil {
		internal.Logger.Printf(ctx, "info for cmd=%s not found", name)
	}
	return info
}

func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
	cmdInfo := c.cmdInfo(ctx, cmd.Name())
	pos := cmdFirstKeyPos(cmd, cmdInfo)
	if pos == 0 {
		// Commands without keys go to a random shard.
		return c.shards.Random()
	}
	firstKey := cmd.stringArg(pos)
	return c.shards.GetByKey(firstKey)
}

func (c *Ring) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		shard, err := c.cmdShard(ctx, cmd)
		if err != nil {
			return err
		}

		lastErr = shard.Client.Process(ctx, cmd)
		if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
			return lastErr
		}
	}
	return lastErr
}

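// Pipelined executes fn inside a pipeline: queued commands are grouped by shard
// and each group is sent to its shard in a single round trip. A minimal usage
// sketch (key names are placeholders):
//
//	_, err := rdb.Pipelined(ctx, func(pipe redis.Pipeliner) error {
//		pipe.Set(ctx, "key1", "value1", 0)
//		pipe.Incr(ctx, "counter")
//		return nil
//	})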
func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Ring) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, false)
	})
}

func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

func (c *Ring) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, true)
	})
}

func (c *Ring) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, tx bool,
) error {
	// Group commands by the shard that owns their first key. Commands without
	// a key end up under the empty hash and are routed to a random shard.
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(ctx, cmd.Name())
		hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
		if hash != "" {
			hash = c.shards.Hash(hash)
		}
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}

	// Send each group to its shard concurrently.
	var wg sync.WaitGroup
	for hash, cmds := range cmdsMap {
		wg.Add(1)
		go func(hash string, cmds []Cmder) {
			defer wg.Done()

			_ = c.processShardPipeline(ctx, hash, cmds, tx)
		}(hash, cmds)
	}

	wg.Wait()
	return cmdsFirstErr(cmds)
}

func (c *Ring) processShardPipeline(
	ctx context.Context, hash string, cmds []Cmder, tx bool,
) error {
	// TODO: retry?
	shard, err := c.shards.GetByName(hash)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	if tx {
		return shard.Client.processTxPipeline(ctx, cmds)
	}
	return shard.Client.processPipeline(ctx, cmds)
}

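// Watch runs an optimistic transaction against the shard that owns the keys.
// All keys must map to the same shard; hash tags can be used to keep related
// keys together. A minimal usage sketch (key names are placeholders):
//
//	err := rdb.Watch(ctx, func(tx *redis.Tx) error {
//		n, err := tx.Get(ctx, "{user1}.count").Int()
//		if err != nil && err != redis.Nil {
//			return err
//		}
//		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
//			pipe.Set(ctx, "{user1}.count", n+1, 0)
//			return nil
//		})
//		return err
//	}, "{user1}.count")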
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}

	var shards []*ringShard
	for _, key := range keys {
		if key != "" {
			shard, err := c.shards.GetByKey(hashtag.Key(key))
			if err != nil {
				return err
			}

			shards = append(shards, shard)
		}
	}

	if len(shards) == 0 {
		return fmt.Errorf("redis: Watch requires at least one shard")
	}

	if len(shards) > 1 {
		for _, shard := range shards[1:] {
			if shard.Client != shards[0].Client {
				err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
				return err
			}
		}
	}

	return shards[0].Client.Watch(ctx, fn, keys...)
}

// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
	return c.shards.Close()
}