blob: efc2354af0c547b99cc7a96f5c595f34609e9675 [file] [log] [blame]
Naveen Sampath04696f72022-06-13 15:19:14 +05301package redis
2
3import (
4 "context"
5 "fmt"
6 "strings"
7 "sync"
8 "time"
9
10 "github.com/go-redis/redis/v8/internal"
11 "github.com/go-redis/redis/v8/internal/pool"
12 "github.com/go-redis/redis/v8/internal/proto"
13)
14
15// PubSub implements Pub/Sub commands as described in
16// http://redis.io/topics/pubsub. Message receiving is NOT safe
17// for concurrent use by multiple goroutines.
18//
19// PubSub automatically reconnects to Redis Server and resubscribes
20// to the channels in case of network errors.
21type PubSub struct {
22 opt *Options
23
24 newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
25 closeConn func(*pool.Conn) error
26
27 mu sync.Mutex
28 cn *pool.Conn
29 channels map[string]struct{}
30 patterns map[string]struct{}
31
32 closed bool
33 exit chan struct{}
34
35 cmd *Cmd
36
37 chOnce sync.Once
38 msgCh *channel
39 allCh *channel
40}
41
42func (c *PubSub) init() {
43 c.exit = make(chan struct{})
44}
45
46func (c *PubSub) String() string {
47 channels := mapKeys(c.channels)
48 channels = append(channels, mapKeys(c.patterns)...)
49 return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
50}
51
52func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
53 c.mu.Lock()
54 cn, err := c.conn(ctx, nil)
55 c.mu.Unlock()
56 return cn, err
57}
58
59func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
60 if c.closed {
61 return nil, pool.ErrClosed
62 }
63 if c.cn != nil {
64 return c.cn, nil
65 }
66
67 channels := mapKeys(c.channels)
68 channels = append(channels, newChannels...)
69
70 cn, err := c.newConn(ctx, channels)
71 if err != nil {
72 return nil, err
73 }
74
75 if err := c.resubscribe(ctx, cn); err != nil {
76 _ = c.closeConn(cn)
77 return nil, err
78 }
79
80 c.cn = cn
81 return cn, nil
82}
83
84func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
85 return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
86 return writeCmd(wr, cmd)
87 })
88}
89
90func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
91 var firstErr error
92
93 if len(c.channels) > 0 {
94 firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
95 }
96
97 if len(c.patterns) > 0 {
98 err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
99 if err != nil && firstErr == nil {
100 firstErr = err
101 }
102 }
103
104 return firstErr
105}
106
// mapKeys returns the keys of m in unspecified (map iteration) order.
// The result is never nil, even for an empty or nil map.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}
116
117func (c *PubSub) _subscribe(
118 ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
119) error {
120 args := make([]interface{}, 0, 1+len(channels))
121 args = append(args, redisCmd)
122 for _, channel := range channels {
123 args = append(args, channel)
124 }
125 cmd := NewSliceCmd(ctx, args...)
126 return c.writeCmd(ctx, cn, cmd)
127}
128
129func (c *PubSub) releaseConnWithLock(
130 ctx context.Context,
131 cn *pool.Conn,
132 err error,
133 allowTimeout bool,
134) {
135 c.mu.Lock()
136 c.releaseConn(ctx, cn, err, allowTimeout)
137 c.mu.Unlock()
138}
139
140func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
141 if c.cn != cn {
142 return
143 }
144 if isBadConn(err, allowTimeout, c.opt.Addr) {
145 c.reconnect(ctx, err)
146 }
147}
148
149func (c *PubSub) reconnect(ctx context.Context, reason error) {
150 _ = c.closeTheCn(reason)
151 _, _ = c.conn(ctx, nil)
152}
153
154func (c *PubSub) closeTheCn(reason error) error {
155 if c.cn == nil {
156 return nil
157 }
158 if !c.closed {
159 internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
160 }
161 err := c.closeConn(c.cn)
162 c.cn = nil
163 return err
164}
165
166func (c *PubSub) Close() error {
167 c.mu.Lock()
168 defer c.mu.Unlock()
169
170 if c.closed {
171 return pool.ErrClosed
172 }
173 c.closed = true
174 close(c.exit)
175
176 return c.closeTheCn(pool.ErrClosed)
177}
178
179// Subscribe the client to the specified channels. It returns
180// empty subscription if there are no channels.
181func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
182 c.mu.Lock()
183 defer c.mu.Unlock()
184
185 err := c.subscribe(ctx, "subscribe", channels...)
186 if c.channels == nil {
187 c.channels = make(map[string]struct{})
188 }
189 for _, s := range channels {
190 c.channels[s] = struct{}{}
191 }
192 return err
193}
194
195// PSubscribe the client to the given patterns. It returns
196// empty subscription if there are no patterns.
197func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
198 c.mu.Lock()
199 defer c.mu.Unlock()
200
201 err := c.subscribe(ctx, "psubscribe", patterns...)
202 if c.patterns == nil {
203 c.patterns = make(map[string]struct{})
204 }
205 for _, s := range patterns {
206 c.patterns[s] = struct{}{}
207 }
208 return err
209}
210
211// Unsubscribe the client from the given channels, or from all of
212// them if none is given.
213func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
214 c.mu.Lock()
215 defer c.mu.Unlock()
216
217 for _, channel := range channels {
218 delete(c.channels, channel)
219 }
220 err := c.subscribe(ctx, "unsubscribe", channels...)
221 return err
222}
223
224// PUnsubscribe the client from the given patterns, or from all of
225// them if none is given.
226func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
227 c.mu.Lock()
228 defer c.mu.Unlock()
229
230 for _, pattern := range patterns {
231 delete(c.patterns, pattern)
232 }
233 err := c.subscribe(ctx, "punsubscribe", patterns...)
234 return err
235}
236
237func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
238 cn, err := c.conn(ctx, channels)
239 if err != nil {
240 return err
241 }
242
243 err = c._subscribe(ctx, cn, redisCmd, channels)
244 c.releaseConn(ctx, cn, err, false)
245 return err
246}
247
248func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
249 args := []interface{}{"ping"}
250 if len(payload) == 1 {
251 args = append(args, payload[0])
252 }
253 cmd := NewCmd(ctx, args...)
254
255 c.mu.Lock()
256 defer c.mu.Unlock()
257
258 cn, err := c.conn(ctx, nil)
259 if err != nil {
260 return err
261 }
262
263 err = c.writeCmd(ctx, cn, cmd)
264 c.releaseConn(ctx, cn, err, false)
265 return err
266}
267
// Subscription is received after a successful subscription to (or
// unsubscription from) a channel or pattern.
type Subscription struct {
	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel name (or pattern) the reply refers to; it may
	// be empty, e.g. for an "unsubscribe" reply carrying nil.
	Channel string
	// Count is the number of channels we are currently subscribed to.
	Count int
}

// String formats the subscription as "<kind>: <channel>".
func (m *Subscription) String() string {
	kind, name := m.Kind, m.Channel
	return fmt.Sprintf("%s: %s", kind, name)
}
281
// Message received as result of a PUBLISH command issued by another client.
type Message struct {
	// Channel is the channel the message was published to.
	Channel string
	// Pattern is the matching pattern for messages received via PSUBSCRIBE.
	Pattern string
	// Payload holds a single-string payload.
	Payload string
	// PayloadSlice holds a multi-element payload; Payload is empty then.
	PayloadSlice []string
}

// String formats the message as "Message<channel: payload>".
func (m *Message) String() string {
	ch, body := m.Channel, m.Payload
	return fmt.Sprintf("Message<%s: %s>", ch, body)
}
293
// Pong received as result of a PING command issued by another client.
type Pong struct {
	// Payload echoes the PING argument; empty when none was sent.
	Payload string
}

// String returns "Pong", or "Pong<payload>" when a payload is present.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
305
306func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
307 switch reply := reply.(type) {
308 case string:
309 return &Pong{
310 Payload: reply,
311 }, nil
312 case []interface{}:
313 switch kind := reply[0].(string); kind {
314 case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
315 // Can be nil in case of "unsubscribe".
316 channel, _ := reply[1].(string)
317 return &Subscription{
318 Kind: kind,
319 Channel: channel,
320 Count: int(reply[2].(int64)),
321 }, nil
322 case "message":
323 switch payload := reply[2].(type) {
324 case string:
325 return &Message{
326 Channel: reply[1].(string),
327 Payload: payload,
328 }, nil
329 case []interface{}:
330 ss := make([]string, len(payload))
331 for i, s := range payload {
332 ss[i] = s.(string)
333 }
334 return &Message{
335 Channel: reply[1].(string),
336 PayloadSlice: ss,
337 }, nil
338 default:
339 return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
340 }
341 case "pmessage":
342 return &Message{
343 Pattern: reply[1].(string),
344 Channel: reply[2].(string),
345 Payload: reply[3].(string),
346 }, nil
347 case "pong":
348 return &Pong{
349 Payload: reply[1].(string),
350 }, nil
351 default:
352 return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
353 }
354 default:
355 return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
356 }
357}
358
359// ReceiveTimeout acts like Receive but returns an error if message
360// is not received in time. This is low-level API and in most cases
361// Channel should be used instead.
362func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
363 if c.cmd == nil {
364 c.cmd = NewCmd(ctx)
365 }
366
367 // Don't hold the lock to allow subscriptions and pings.
368
369 cn, err := c.connWithLock(ctx)
370 if err != nil {
371 return nil, err
372 }
373
374 err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
375 return c.cmd.readReply(rd)
376 })
377
378 c.releaseConnWithLock(ctx, cn, err, timeout > 0)
379
380 if err != nil {
381 return nil, err
382 }
383
384 return c.newMessage(c.cmd.Val())
385}
386
387// Receive returns a message as a Subscription, Message, Pong or error.
388// See PubSub example for details. This is low-level API and in most cases
389// Channel should be used instead.
390func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
391 return c.ReceiveTimeout(ctx, 0)
392}
393
394// ReceiveMessage returns a Message or error ignoring Subscription and Pong
395// messages. This is low-level API and in most cases Channel should be used
396// instead.
397func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
398 for {
399 msg, err := c.Receive(ctx)
400 if err != nil {
401 return nil, err
402 }
403
404 switch msg := msg.(type) {
405 case *Subscription:
406 // Ignore.
407 case *Pong:
408 // Ignore.
409 case *Message:
410 return msg, nil
411 default:
412 err := fmt.Errorf("redis: unknown message: %T", msg)
413 return nil, err
414 }
415 }
416}
417
418func (c *PubSub) getContext() context.Context {
419 if c.cmd != nil {
420 return c.cmd.ctx
421 }
422 return context.Background()
423}
424
425//------------------------------------------------------------------------------
426
427// Channel returns a Go channel for concurrently receiving messages.
428// The channel is closed together with the PubSub. If the Go channel
429// is blocked full for 30 seconds the message is dropped.
430// Receive* APIs can not be used after channel is created.
431//
432// go-redis periodically sends ping messages to test connection health
433// and re-subscribes if ping can not not received for 30 seconds.
434func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
435 c.chOnce.Do(func() {
436 c.msgCh = newChannel(c, opts...)
437 c.msgCh.initMsgChan()
438 })
439 if c.msgCh == nil {
440 err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
441 panic(err)
442 }
443 return c.msgCh.msgCh
444}
445
446// ChannelSize is like Channel, but creates a Go channel
447// with specified buffer size.
448//
449// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
450func (c *PubSub) ChannelSize(size int) <-chan *Message {
451 return c.Channel(WithChannelSize(size))
452}
453
454// ChannelWithSubscriptions is like Channel, but message type can be either
455// *Subscription or *Message. Subscription messages can be used to detect
456// reconnections.
457//
458// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
459func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
460 c.chOnce.Do(func() {
461 c.allCh = newChannel(c, WithChannelSize(size))
462 c.allCh.initAllChan()
463 })
464 if c.allCh == nil {
465 err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
466 panic(err)
467 }
468 return c.allCh.allCh
469}
470
471type ChannelOption func(c *channel)
472
473// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
474//
475// The default is 100 messages.
476func WithChannelSize(size int) ChannelOption {
477 return func(c *channel) {
478 c.chanSize = size
479 }
480}
481
482// WithChannelHealthCheckInterval specifies the health check interval.
483// PubSub will ping Redis Server if it does not receive any messages within the interval.
484// To disable health check, use zero interval.
485//
486// The default is 3 seconds.
487func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
488 return func(c *channel) {
489 c.checkInterval = d
490 }
491}
492
493// WithChannelSendTimeout specifies the channel send timeout after which
494// the message is dropped.
495//
496// The default is 60 seconds.
497func WithChannelSendTimeout(d time.Duration) ChannelOption {
498 return func(c *channel) {
499 c.chanSendTimeout = d
500 }
501}
502
503type channel struct {
504 pubSub *PubSub
505
506 msgCh chan *Message
507 allCh chan interface{}
508 ping chan struct{}
509
510 chanSize int
511 chanSendTimeout time.Duration
512 checkInterval time.Duration
513}
514
515func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
516 c := &channel{
517 pubSub: pubSub,
518
519 chanSize: 100,
520 chanSendTimeout: time.Minute,
521 checkInterval: 3 * time.Second,
522 }
523 for _, opt := range opts {
524 opt(c)
525 }
526 if c.checkInterval > 0 {
527 c.initHealthCheck()
528 }
529 return c
530}
531
532func (c *channel) initHealthCheck() {
533 ctx := context.TODO()
534 c.ping = make(chan struct{}, 1)
535
536 go func() {
537 timer := time.NewTimer(time.Minute)
538 timer.Stop()
539
540 for {
541 timer.Reset(c.checkInterval)
542 select {
543 case <-c.ping:
544 if !timer.Stop() {
545 <-timer.C
546 }
547 case <-timer.C:
548 if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
549 c.pubSub.mu.Lock()
550 c.pubSub.reconnect(ctx, pingErr)
551 c.pubSub.mu.Unlock()
552 }
553 case <-c.pubSub.exit:
554 return
555 }
556 }
557 }()
558}
559
560// initMsgChan must be in sync with initAllChan.
561func (c *channel) initMsgChan() {
562 ctx := context.TODO()
563 c.msgCh = make(chan *Message, c.chanSize)
564
565 go func() {
566 timer := time.NewTimer(time.Minute)
567 timer.Stop()
568
569 var errCount int
570 for {
571 msg, err := c.pubSub.Receive(ctx)
572 if err != nil {
573 if err == pool.ErrClosed {
574 close(c.msgCh)
575 return
576 }
577 if errCount > 0 {
578 time.Sleep(100 * time.Millisecond)
579 }
580 errCount++
581 continue
582 }
583
584 errCount = 0
585
586 // Any message is as good as a ping.
587 select {
588 case c.ping <- struct{}{}:
589 default:
590 }
591
592 switch msg := msg.(type) {
593 case *Subscription:
594 // Ignore.
595 case *Pong:
596 // Ignore.
597 case *Message:
598 timer.Reset(c.chanSendTimeout)
599 select {
600 case c.msgCh <- msg:
601 if !timer.Stop() {
602 <-timer.C
603 }
604 case <-timer.C:
605 internal.Logger.Printf(
606 ctx, "redis: %s channel is full for %s (message is dropped)",
607 c, c.chanSendTimeout)
608 }
609 default:
610 internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
611 }
612 }
613 }()
614}
615
616// initAllChan must be in sync with initMsgChan.
617func (c *channel) initAllChan() {
618 ctx := context.TODO()
619 c.allCh = make(chan interface{}, c.chanSize)
620
621 go func() {
622 timer := time.NewTimer(time.Minute)
623 timer.Stop()
624
625 var errCount int
626 for {
627 msg, err := c.pubSub.Receive(ctx)
628 if err != nil {
629 if err == pool.ErrClosed {
630 close(c.allCh)
631 return
632 }
633 if errCount > 0 {
634 time.Sleep(100 * time.Millisecond)
635 }
636 errCount++
637 continue
638 }
639
640 errCount = 0
641
642 // Any message is as good as a ping.
643 select {
644 case c.ping <- struct{}{}:
645 default:
646 }
647
648 switch msg := msg.(type) {
649 case *Pong:
650 // Ignore.
651 case *Subscription, *Message:
652 timer.Reset(c.chanSendTimeout)
653 select {
654 case c.allCh <- msg:
655 if !timer.Stop() {
656 <-timer.C
657 }
658 case <-timer.C:
659 internal.Logger.Printf(
660 ctx, "redis: %s channel is full for %s (message is dropped)",
661 c, c.chanSendTimeout)
662 }
663 default:
664 internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
665 }
666 }
667 }()
668}