blob: c56270b443e2082e356a34e55e49cf7f0de7e455 [file] [log] [blame]
Joey Armstronge8c091f2023-01-17 16:56:26 -05001package redis
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9 "time"
10
11 "github.com/go-redis/redis/v8/internal"
12 "github.com/go-redis/redis/v8/internal/pool"
13 "github.com/go-redis/redis/v8/internal/proto"
14)
15
// Timing limits used by the Go channel APIs (Channel, ChannelSize,
// ChannelWithSubscriptions) and their health check.
const (
	// pingTimeout is the idle period after which the health-check goroutine
	// sends a PING; a further idle period with no traffic at all makes it
	// drop and re-establish the connection.
	pingTimeout = 1 * time.Second

	// chanSendTimeout bounds how long the Channel delivery loop waits for
	// room in the Go channel before dropping a message.
	chanSendTimeout = 1 * time.Minute
)

// errPingTimeout is the reason handed to reconnect when the health check
// gives up on a connection whose PING write itself succeeded.
var errPingTimeout = errors.New("redis: ping timeout")
22
23// PubSub implements Pub/Sub commands as described in
24// http://redis.io/topics/pubsub. Message receiving is NOT safe
25// for concurrent use by multiple goroutines.
26//
27// PubSub automatically reconnects to Redis Server and resubscribes
28// to the channels in case of network errors.
29type PubSub struct {
30 opt *Options
31
32 newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
33 closeConn func(*pool.Conn) error
34
35 mu sync.Mutex
36 cn *pool.Conn
37 channels map[string]struct{}
38 patterns map[string]struct{}
39
40 closed bool
41 exit chan struct{}
42
43 cmd *Cmd
44
45 chOnce sync.Once
46 msgCh chan *Message
47 allCh chan interface{}
48 ping chan struct{}
49}
50
51func (c *PubSub) String() string {
52 channels := mapKeys(c.channels)
53 channels = append(channels, mapKeys(c.patterns)...)
54 return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
55}
56
57func (c *PubSub) init() {
58 c.exit = make(chan struct{})
59}
60
61func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
62 c.mu.Lock()
63 cn, err := c.conn(ctx, nil)
64 c.mu.Unlock()
65 return cn, err
66}
67
68func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
69 if c.closed {
70 return nil, pool.ErrClosed
71 }
72 if c.cn != nil {
73 return c.cn, nil
74 }
75
76 channels := mapKeys(c.channels)
77 channels = append(channels, newChannels...)
78
79 cn, err := c.newConn(ctx, channels)
80 if err != nil {
81 return nil, err
82 }
83
84 if err := c.resubscribe(ctx, cn); err != nil {
85 _ = c.closeConn(cn)
86 return nil, err
87 }
88
89 c.cn = cn
90 return cn, nil
91}
92
93func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
94 return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
95 return writeCmd(wr, cmd)
96 })
97}
98
99func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
100 var firstErr error
101
102 if len(c.channels) > 0 {
103 firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
104 }
105
106 if len(c.patterns) > 0 {
107 err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
108 if err != nil && firstErr == nil {
109 firstErr = err
110 }
111 }
112
113 return firstErr
114}
115
// mapKeys returns the keys of m in unspecified (map iteration) order. The
// result is never nil, even for a nil or empty map.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
125
126func (c *PubSub) _subscribe(
127 ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
128) error {
129 args := make([]interface{}, 0, 1+len(channels))
130 args = append(args, redisCmd)
131 for _, channel := range channels {
132 args = append(args, channel)
133 }
134 cmd := NewSliceCmd(ctx, args...)
135 return c.writeCmd(ctx, cn, cmd)
136}
137
138func (c *PubSub) releaseConnWithLock(
139 ctx context.Context,
140 cn *pool.Conn,
141 err error,
142 allowTimeout bool,
143) {
144 c.mu.Lock()
145 c.releaseConn(ctx, cn, err, allowTimeout)
146 c.mu.Unlock()
147}
148
149func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
150 if c.cn != cn {
151 return
152 }
153 if isBadConn(err, allowTimeout) {
154 c.reconnect(ctx, err)
155 }
156}
157
158func (c *PubSub) reconnect(ctx context.Context, reason error) {
159 _ = c.closeTheCn(reason)
160 _, _ = c.conn(ctx, nil)
161}
162
163func (c *PubSub) closeTheCn(reason error) error {
164 if c.cn == nil {
165 return nil
166 }
167 if !c.closed {
168 internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
169 }
170 err := c.closeConn(c.cn)
171 c.cn = nil
172 return err
173}
174
175func (c *PubSub) Close() error {
176 c.mu.Lock()
177 defer c.mu.Unlock()
178
179 if c.closed {
180 return pool.ErrClosed
181 }
182 c.closed = true
183 close(c.exit)
184
185 return c.closeTheCn(pool.ErrClosed)
186}
187
188// Subscribe the client to the specified channels. It returns
189// empty subscription if there are no channels.
190func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
191 c.mu.Lock()
192 defer c.mu.Unlock()
193
194 err := c.subscribe(ctx, "subscribe", channels...)
195 if c.channels == nil {
196 c.channels = make(map[string]struct{})
197 }
198 for _, s := range channels {
199 c.channels[s] = struct{}{}
200 }
201 return err
202}
203
204// PSubscribe the client to the given patterns. It returns
205// empty subscription if there are no patterns.
206func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
207 c.mu.Lock()
208 defer c.mu.Unlock()
209
210 err := c.subscribe(ctx, "psubscribe", patterns...)
211 if c.patterns == nil {
212 c.patterns = make(map[string]struct{})
213 }
214 for _, s := range patterns {
215 c.patterns[s] = struct{}{}
216 }
217 return err
218}
219
220// Unsubscribe the client from the given channels, or from all of
221// them if none is given.
222func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
223 c.mu.Lock()
224 defer c.mu.Unlock()
225
226 for _, channel := range channels {
227 delete(c.channels, channel)
228 }
229 err := c.subscribe(ctx, "unsubscribe", channels...)
230 return err
231}
232
233// PUnsubscribe the client from the given patterns, or from all of
234// them if none is given.
235func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
236 c.mu.Lock()
237 defer c.mu.Unlock()
238
239 for _, pattern := range patterns {
240 delete(c.patterns, pattern)
241 }
242 err := c.subscribe(ctx, "punsubscribe", patterns...)
243 return err
244}
245
246func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
247 cn, err := c.conn(ctx, channels)
248 if err != nil {
249 return err
250 }
251
252 err = c._subscribe(ctx, cn, redisCmd, channels)
253 c.releaseConn(ctx, cn, err, false)
254 return err
255}
256
257func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
258 args := []interface{}{"ping"}
259 if len(payload) == 1 {
260 args = append(args, payload[0])
261 }
262 cmd := NewCmd(ctx, args...)
263
264 cn, err := c.connWithLock(ctx)
265 if err != nil {
266 return err
267 }
268
269 err = c.writeCmd(ctx, cn, cmd)
270 c.releaseConnWithLock(ctx, cn, err, false)
271 return err
272}
273
// Subscription is the server's confirmation of a SUBSCRIBE, UNSUBSCRIBE,
// PSUBSCRIBE or PUNSUBSCRIBE command.
type Subscription struct {
	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel name or pattern the confirmation refers to. It
	// may be empty for "unsubscribe".
	Channel string
	// Count is the number of channels we are currently subscribed to.
	Count int
}

// String renders the confirmation as "kind: channel".
func (m *Subscription) String() string {
	const layout = "%s: %s"
	return fmt.Sprintf(layout, m.Kind, m.Channel)
}
287
// Message is a message published (by any client) to a channel this client is
// subscribed to, either directly or through a pattern.
type Message struct {
	// Channel is the channel the message was published to.
	Channel string
	// Pattern is the matching pattern for messages delivered via PSUBSCRIBE;
	// it is empty for direct channel subscriptions.
	Pattern string
	// Payload is the message body when the server sent it as a single string.
	Payload string
	// PayloadSlice holds the body when the server sent an array of strings
	// instead; Payload is empty in that case.
	PayloadSlice []string
}

// String renders the message as "Message<channel: payload>".
func (m *Message) String() string {
	const layout = "Message<%s: %s>"
	return fmt.Sprintf(layout, m.Channel, m.Payload)
}
299
// Pong is a reply to a PING command written on the pub/sub connection
// (see PubSub.Ping).
type Pong struct {
	// Payload is the reply's payload, or empty if there was none.
	Payload string
}

// String renders the reply as "Pong", or "Pong<payload>" when a payload is
// present.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
311
312func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
313 switch reply := reply.(type) {
314 case string:
315 return &Pong{
316 Payload: reply,
317 }, nil
318 case []interface{}:
319 switch kind := reply[0].(string); kind {
320 case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
321 // Can be nil in case of "unsubscribe".
322 channel, _ := reply[1].(string)
323 return &Subscription{
324 Kind: kind,
325 Channel: channel,
326 Count: int(reply[2].(int64)),
327 }, nil
328 case "message":
329 switch payload := reply[2].(type) {
330 case string:
331 return &Message{
332 Channel: reply[1].(string),
333 Payload: payload,
334 }, nil
335 case []interface{}:
336 ss := make([]string, len(payload))
337 for i, s := range payload {
338 ss[i] = s.(string)
339 }
340 return &Message{
341 Channel: reply[1].(string),
342 PayloadSlice: ss,
343 }, nil
344 default:
345 return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
346 }
347 case "pmessage":
348 return &Message{
349 Pattern: reply[1].(string),
350 Channel: reply[2].(string),
351 Payload: reply[3].(string),
352 }, nil
353 case "pong":
354 return &Pong{
355 Payload: reply[1].(string),
356 }, nil
357 default:
358 return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
359 }
360 default:
361 return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
362 }
363}
364
365// ReceiveTimeout acts like Receive but returns an error if message
366// is not received in time. This is low-level API and in most cases
367// Channel should be used instead.
368func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
369 if c.cmd == nil {
370 c.cmd = NewCmd(ctx)
371 }
372
373 cn, err := c.connWithLock(ctx)
374 if err != nil {
375 return nil, err
376 }
377
378 err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
379 return c.cmd.readReply(rd)
380 })
381
382 c.releaseConnWithLock(ctx, cn, err, timeout > 0)
383 if err != nil {
384 return nil, err
385 }
386
387 return c.newMessage(c.cmd.Val())
388}
389
390// Receive returns a message as a Subscription, Message, Pong or error.
391// See PubSub example for details. This is low-level API and in most cases
392// Channel should be used instead.
393func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
394 return c.ReceiveTimeout(ctx, 0)
395}
396
397// ReceiveMessage returns a Message or error ignoring Subscription and Pong
398// messages. This is low-level API and in most cases Channel should be used
399// instead.
400func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
401 for {
402 msg, err := c.Receive(ctx)
403 if err != nil {
404 return nil, err
405 }
406
407 switch msg := msg.(type) {
408 case *Subscription:
409 // Ignore.
410 case *Pong:
411 // Ignore.
412 case *Message:
413 return msg, nil
414 default:
415 err := fmt.Errorf("redis: unknown message: %T", msg)
416 return nil, err
417 }
418 }
419}
420
421// Channel returns a Go channel for concurrently receiving messages.
422// The channel is closed together with the PubSub. If the Go channel
423// is blocked full for 30 seconds the message is dropped.
424// Receive* APIs can not be used after channel is created.
425//
426// go-redis periodically sends ping messages to test connection health
427// and re-subscribes if ping can not not received for 30 seconds.
428func (c *PubSub) Channel() <-chan *Message {
429 return c.ChannelSize(100)
430}
431
432// ChannelSize is like Channel, but creates a Go channel
433// with specified buffer size.
434func (c *PubSub) ChannelSize(size int) <-chan *Message {
435 c.chOnce.Do(func() {
436 c.initPing()
437 c.initMsgChan(size)
438 })
439 if c.msgCh == nil {
440 err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
441 panic(err)
442 }
443 if cap(c.msgCh) != size {
444 err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
445 panic(err)
446 }
447 return c.msgCh
448}
449
450// ChannelWithSubscriptions is like Channel, but message type can be either
451// *Subscription or *Message. Subscription messages can be used to detect
452// reconnections.
453//
454// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
455func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} {
456 c.chOnce.Do(func() {
457 c.initPing()
458 c.initAllChan(size)
459 })
460 if c.allCh == nil {
461 err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
462 panic(err)
463 }
464 if cap(c.allCh) != size {
465 err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created")
466 panic(err)
467 }
468 return c.allCh
469}
470
471func (c *PubSub) getContext() context.Context {
472 if c.cmd != nil {
473 return c.cmd.ctx
474 }
475 return context.Background()
476}
477
// initPing starts the connection health-check goroutine used by the Go
// channel APIs. The receive loops signal c.ping whenever a reply arrives.
// If no signal arrives for pingTimeout, a PING is sent. If another
// pingTimeout then passes with no signal, the connection is dropped and
// re-established under c.mu. The goroutine returns once c.exit is closed
// (Close does this).
func (c *PubSub) initPing() {
	ctx := context.TODO()
	// Capacity 1: the receive loops use a non-blocking send, so one pending
	// signal is enough.
	c.ping = make(chan struct{}, 1)
	go func() {
		// Create the timer in the stopped state; it is armed with Reset at
		// the top of every iteration.
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		// healthy becomes false once a PING has been sent after an idle
		// period, and true again when any traffic (or a reconnect) follows.
		healthy := true
		for {
			timer.Reset(pingTimeout)
			select {
			case <-c.ping:
				healthy = true
				// Stop the timer, draining it if it already fired, so the
				// next Reset starts from a clean state.
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				pingErr := c.Ping(ctx)
				if healthy {
					// First idle period: give the PING a chance to be answered.
					healthy = false
				} else {
					// Second consecutive idle period: treat the connection as dead.
					if pingErr == nil {
						pingErr = errPingTimeout
					}
					c.mu.Lock()
					c.reconnect(ctx, pingErr)
					healthy = true
					c.mu.Unlock()
				}
			case <-c.exit:
				return
			}
		}
	}()
}
513
// initMsgChan creates c.msgCh with the given buffer size and starts the
// goroutine that feeds it. The goroutine calls Receive in a loop and behaves
// as follows:
//   - every successful reply signals the health check via c.ping;
//   - *Message values are forwarded to c.msgCh, and a message is dropped
//     (with a log line) if the channel stays full for chanSendTimeout;
//   - subscriptions and pongs are ignored;
//   - c.msgCh is closed and the goroutine exits once Receive returns
//     pool.ErrClosed.
//
// initMsgChan must be in sync with initAllChan.
func (c *PubSub) initMsgChan(size int) {
	ctx := context.TODO()
	c.msgCh = make(chan *Message, size)
	go func() {
		// Created stopped; armed with Reset before each send below.
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		var errCount int
		for {
			msg, err := c.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					close(c.msgCh)
					return
				}
				// Retry right away after the first error, then sleep 100ms
				// between consecutive failures.
				if errCount > 0 {
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}

			errCount = 0

			// Any message is as good as a ping. The send is non-blocking
			// because a pending signal is already sufficient.
			select {
			case c.ping <- struct{}{}:
			default:
			}

			switch msg := msg.(type) {
			case *Subscription:
				// Ignore.
			case *Pong:
				// Ignore.
			case *Message:
				timer.Reset(chanSendTimeout)
				select {
				case c.msgCh <- msg:
					// Stop (and drain if it already fired) so the timer is
					// clean for the next Reset.
					if !timer.Stop() {
						<-timer.C
					}
				case <-timer.C:
					internal.Logger.Printf(
						c.getContext(),
						"redis: %s channel is full for %s (message is dropped)",
						c,
						chanSendTimeout,
					)
				}
			default:
				internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg)
			}
		}
	}()
}
571
// initAllChan creates c.allCh with the given buffer size and starts the
// goroutine that feeds it for ChannelWithSubscriptions. It mirrors
// initMsgChan, except that both *Subscription and *Message values are
// forwarded (through sendMessage, which may drop a value if c.allCh stays
// full). Pongs are ignored. c.allCh is closed once Receive returns
// pool.ErrClosed.
//
// initAllChan must be in sync with initMsgChan.
func (c *PubSub) initAllChan(size int) {
	ctx := context.TODO()
	c.allCh = make(chan interface{}, size)
	go func() {
		// Created stopped; sendMessage arms it with Reset before each send
		// and leaves it stopped afterwards.
		timer := time.NewTimer(pingTimeout)
		timer.Stop()

		var errCount int
		for {
			msg, err := c.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					close(c.allCh)
					return
				}
				// Retry right away after the first error, then sleep 100ms
				// between consecutive failures.
				if errCount > 0 {
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}

			errCount = 0

			// Any message is as good as a ping. The send is non-blocking
			// because a pending signal is already sufficient.
			select {
			case c.ping <- struct{}{}:
			default:
			}

			switch msg := msg.(type) {
			case *Subscription:
				c.sendMessage(msg, timer)
			case *Pong:
				// Ignore.
			case *Message:
				c.sendMessage(msg, timer)
			default:
				internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg)
			}
		}
	}()
}
616
617func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) {
618 timer.Reset(pingTimeout)
619 select {
620 case c.allCh <- msg:
621 if !timer.Stop() {
622 <-timer.C
623 }
624 case <-timer.C:
625 internal.Logger.Printf(
626 c.getContext(),
627 "redis: %s channel is full for %s (message is dropped)", c, pingTimeout)
628 }
629}