| package redis |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "strings" |
| "sync" |
| "time" |
| |
| "github.com/go-redis/redis/v8/internal" |
| "github.com/go-redis/redis/v8/internal/pool" |
| "github.com/go-redis/redis/v8/internal/proto" |
| ) |
| |
const (
	// pingTimeout is the interval of the connection health check run by
	// Channel/ChannelWithSubscriptions: after one silent interval a PING is
	// sent, after a second one the connection is re-established.
	pingTimeout = 1 * time.Second
	// chanSendTimeout bounds how long the Channel goroutine waits for room in
	// the Go channel before dropping a message.
	chanSendTimeout = 60 * time.Second
)

// errPingTimeout is the reconnect reason used when the health check saw no
// traffic even after sending a PING.
var errPingTimeout = errors.New("redis: ping timeout")
| |
// PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
//
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
type PubSub struct {
	// opt supplies client options such as WriteTimeout.
	opt *Options

	// newConn dials a connection; it is given the currently known channels
	// plus any channels about to be subscribed (see conn).
	newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
	// closeConn disposes of a connection obtained from newConn.
	closeConn func(*pool.Conn) error

	// mu guards the connection and subscription state below; the exported
	// subscribe/unsubscribe methods and Close take it.
	mu sync.Mutex
	// cn is the active connection; nil before the first dial and after a
	// bad connection has been discarded.
	cn *pool.Conn
	// channels and patterns are the subscriptions re-sent by resubscribe
	// whenever a new connection is dialed.
	channels map[string]struct{}
	patterns map[string]struct{}

	// closed is set by Close; conn then returns pool.ErrClosed.
	closed bool
	// exit is closed by Close to stop the ping goroutine started by initPing.
	exit chan struct{}

	// cmd is lazily created and reused by ReceiveTimeout to read replies.
	cmd *Cmd

	// chOnce makes sure only the first of Channel/ChannelSize/
	// ChannelWithSubscriptions starts the background goroutines; it sets
	// either msgCh or allCh, never both.
	chOnce sync.Once
	msgCh  chan *Message
	allCh  chan interface{}
	// ping is signalled on every received message and consumed by the
	// health-check goroutine.
	ping chan struct{}
}
| |
| func (c *PubSub) String() string { |
| channels := mapKeys(c.channels) |
| channels = append(channels, mapKeys(c.patterns)...) |
| return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) |
| } |
| |
| func (c *PubSub) init() { |
| c.exit = make(chan struct{}) |
| } |
| |
| func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { |
| c.mu.Lock() |
| cn, err := c.conn(ctx, nil) |
| c.mu.Unlock() |
| return cn, err |
| } |
| |
| func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) { |
| if c.closed { |
| return nil, pool.ErrClosed |
| } |
| if c.cn != nil { |
| return c.cn, nil |
| } |
| |
| channels := mapKeys(c.channels) |
| channels = append(channels, newChannels...) |
| |
| cn, err := c.newConn(ctx, channels) |
| if err != nil { |
| return nil, err |
| } |
| |
| if err := c.resubscribe(ctx, cn); err != nil { |
| _ = c.closeConn(cn) |
| return nil, err |
| } |
| |
| c.cn = cn |
| return cn, nil |
| } |
| |
| func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error { |
| return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { |
| return writeCmd(wr, cmd) |
| }) |
| } |
| |
| func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { |
| var firstErr error |
| |
| if len(c.channels) > 0 { |
| firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels)) |
| } |
| |
| if len(c.patterns) > 0 { |
| err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns)) |
| if err != nil && firstErr == nil { |
| firstErr = err |
| } |
| } |
| |
| return firstErr |
| } |
| |
// mapKeys returns the keys of m in unspecified (map iteration) order.
// The result is never nil, even for an empty or nil map.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}
| |
| func (c *PubSub) _subscribe( |
| ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, |
| ) error { |
| args := make([]interface{}, 0, 1+len(channels)) |
| args = append(args, redisCmd) |
| for _, channel := range channels { |
| args = append(args, channel) |
| } |
| cmd := NewSliceCmd(ctx, args...) |
| return c.writeCmd(ctx, cn, cmd) |
| } |
| |
| func (c *PubSub) releaseConnWithLock( |
| ctx context.Context, |
| cn *pool.Conn, |
| err error, |
| allowTimeout bool, |
| ) { |
| c.mu.Lock() |
| c.releaseConn(ctx, cn, err, allowTimeout) |
| c.mu.Unlock() |
| } |
| |
| func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) { |
| if c.cn != cn { |
| return |
| } |
| if isBadConn(err, allowTimeout) { |
| c.reconnect(ctx, err) |
| } |
| } |
| |
| func (c *PubSub) reconnect(ctx context.Context, reason error) { |
| _ = c.closeTheCn(reason) |
| _, _ = c.conn(ctx, nil) |
| } |
| |
| func (c *PubSub) closeTheCn(reason error) error { |
| if c.cn == nil { |
| return nil |
| } |
| if !c.closed { |
| internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) |
| } |
| err := c.closeConn(c.cn) |
| c.cn = nil |
| return err |
| } |
| |
| func (c *PubSub) Close() error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| if c.closed { |
| return pool.ErrClosed |
| } |
| c.closed = true |
| close(c.exit) |
| |
| return c.closeTheCn(pool.ErrClosed) |
| } |
| |
| // Subscribe the client to the specified channels. It returns |
| // empty subscription if there are no channels. |
| func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| err := c.subscribe(ctx, "subscribe", channels...) |
| if c.channels == nil { |
| c.channels = make(map[string]struct{}) |
| } |
| for _, s := range channels { |
| c.channels[s] = struct{}{} |
| } |
| return err |
| } |
| |
| // PSubscribe the client to the given patterns. It returns |
| // empty subscription if there are no patterns. |
| func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| err := c.subscribe(ctx, "psubscribe", patterns...) |
| if c.patterns == nil { |
| c.patterns = make(map[string]struct{}) |
| } |
| for _, s := range patterns { |
| c.patterns[s] = struct{}{} |
| } |
| return err |
| } |
| |
| // Unsubscribe the client from the given channels, or from all of |
| // them if none is given. |
| func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| for _, channel := range channels { |
| delete(c.channels, channel) |
| } |
| err := c.subscribe(ctx, "unsubscribe", channels...) |
| return err |
| } |
| |
| // PUnsubscribe the client from the given patterns, or from all of |
| // them if none is given. |
| func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| for _, pattern := range patterns { |
| delete(c.patterns, pattern) |
| } |
| err := c.subscribe(ctx, "punsubscribe", patterns...) |
| return err |
| } |
| |
| func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error { |
| cn, err := c.conn(ctx, channels) |
| if err != nil { |
| return err |
| } |
| |
| err = c._subscribe(ctx, cn, redisCmd, channels) |
| c.releaseConn(ctx, cn, err, false) |
| return err |
| } |
| |
| func (c *PubSub) Ping(ctx context.Context, payload ...string) error { |
| args := []interface{}{"ping"} |
| if len(payload) == 1 { |
| args = append(args, payload[0]) |
| } |
| cmd := NewCmd(ctx, args...) |
| |
| cn, err := c.connWithLock(ctx) |
| if err != nil { |
| return err |
| } |
| |
| err = c.writeCmd(ctx, cn, cmd) |
| c.releaseConnWithLock(ctx, cn, err, false) |
| return err |
| } |
| |
// Subscription is the server's confirmation of a SUBSCRIBE, UNSUBSCRIBE,
// PSUBSCRIBE or PUNSUBSCRIBE.
type Subscription struct {
	// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel (or pattern) the confirmation refers to; may be empty for
	// "unsubscribe" replies.
	Channel string
	// Number of channels we are currently subscribed to, as reported by Redis.
	Count int
}

// String renders the subscription as "kind: channel".
func (m *Subscription) String() string {
	return m.Kind + ": " + m.Channel
}
| |
// Message is a message published to a channel this client is subscribed to,
// either directly or through a pattern subscription.
type Message struct {
	// Channel the message was published to.
	Channel string
	// Pattern that matched Channel; set only for pattern subscriptions.
	Pattern string
	// Payload holds a plain string payload.
	Payload string
	// PayloadSlice holds the payload when the server sent an array instead
	// of a string; Payload is empty in that case.
	PayloadSlice []string
}

// String renders the message as "Message<channel: payload>".
func (m *Message) String() string {
	return "Message<" + m.Channel + ": " + m.Payload + ">"
}
| |
// Pong is the reply to a PING sent on the pub/sub connection (see PubSub.Ping).
type Pong struct {
	// Payload is the reply text (the PING argument, if one was sent).
	Payload string
}

// String renders the pong as "Pong", or "Pong<payload>" when a payload is set.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return "Pong<" + p.Payload + ">"
}
| |
| func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { |
| switch reply := reply.(type) { |
| case string: |
| return &Pong{ |
| Payload: reply, |
| }, nil |
| case []interface{}: |
| switch kind := reply[0].(string); kind { |
| case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": |
| // Can be nil in case of "unsubscribe". |
| channel, _ := reply[1].(string) |
| return &Subscription{ |
| Kind: kind, |
| Channel: channel, |
| Count: int(reply[2].(int64)), |
| }, nil |
| case "message": |
| switch payload := reply[2].(type) { |
| case string: |
| return &Message{ |
| Channel: reply[1].(string), |
| Payload: payload, |
| }, nil |
| case []interface{}: |
| ss := make([]string, len(payload)) |
| for i, s := range payload { |
| ss[i] = s.(string) |
| } |
| return &Message{ |
| Channel: reply[1].(string), |
| PayloadSlice: ss, |
| }, nil |
| default: |
| return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload) |
| } |
| case "pmessage": |
| return &Message{ |
| Pattern: reply[1].(string), |
| Channel: reply[2].(string), |
| Payload: reply[3].(string), |
| }, nil |
| case "pong": |
| return &Pong{ |
| Payload: reply[1].(string), |
| }, nil |
| default: |
| return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind) |
| } |
| default: |
| return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply) |
| } |
| } |
| |
| // ReceiveTimeout acts like Receive but returns an error if message |
| // is not received in time. This is low-level API and in most cases |
| // Channel should be used instead. |
| func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) { |
| if c.cmd == nil { |
| c.cmd = NewCmd(ctx) |
| } |
| |
| cn, err := c.connWithLock(ctx) |
| if err != nil { |
| return nil, err |
| } |
| |
| err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error { |
| return c.cmd.readReply(rd) |
| }) |
| |
| c.releaseConnWithLock(ctx, cn, err, timeout > 0) |
| if err != nil { |
| return nil, err |
| } |
| |
| return c.newMessage(c.cmd.Val()) |
| } |
| |
| // Receive returns a message as a Subscription, Message, Pong or error. |
| // See PubSub example for details. This is low-level API and in most cases |
| // Channel should be used instead. |
| func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { |
| return c.ReceiveTimeout(ctx, 0) |
| } |
| |
| // ReceiveMessage returns a Message or error ignoring Subscription and Pong |
| // messages. This is low-level API and in most cases Channel should be used |
| // instead. |
| func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { |
| for { |
| msg, err := c.Receive(ctx) |
| if err != nil { |
| return nil, err |
| } |
| |
| switch msg := msg.(type) { |
| case *Subscription: |
| // Ignore. |
| case *Pong: |
| // Ignore. |
| case *Message: |
| return msg, nil |
| default: |
| err := fmt.Errorf("redis: unknown message: %T", msg) |
| return nil, err |
| } |
| } |
| } |
| |
| // Channel returns a Go channel for concurrently receiving messages. |
| // The channel is closed together with the PubSub. If the Go channel |
| // is blocked full for 30 seconds the message is dropped. |
| // Receive* APIs can not be used after channel is created. |
| // |
| // go-redis periodically sends ping messages to test connection health |
| // and re-subscribes if ping can not not received for 30 seconds. |
| func (c *PubSub) Channel() <-chan *Message { |
| return c.ChannelSize(100) |
| } |
| |
| // ChannelSize is like Channel, but creates a Go channel |
| // with specified buffer size. |
| func (c *PubSub) ChannelSize(size int) <-chan *Message { |
| c.chOnce.Do(func() { |
| c.initPing() |
| c.initMsgChan(size) |
| }) |
| if c.msgCh == nil { |
| err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") |
| panic(err) |
| } |
| if cap(c.msgCh) != size { |
| err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") |
| panic(err) |
| } |
| return c.msgCh |
| } |
| |
| // ChannelWithSubscriptions is like Channel, but message type can be either |
| // *Subscription or *Message. Subscription messages can be used to detect |
| // reconnections. |
| // |
| // ChannelWithSubscriptions can not be used together with Channel or ChannelSize. |
| func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} { |
| c.chOnce.Do(func() { |
| c.initPing() |
| c.initAllChan(size) |
| }) |
| if c.allCh == nil { |
| err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") |
| panic(err) |
| } |
| if cap(c.allCh) != size { |
| err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") |
| panic(err) |
| } |
| return c.allCh |
| } |
| |
| func (c *PubSub) getContext() context.Context { |
| if c.cmd != nil { |
| return c.cmd.ctx |
| } |
| return context.Background() |
| } |
| |
| func (c *PubSub) initPing() { |
| ctx := context.TODO() |
| c.ping = make(chan struct{}, 1) |
| go func() { |
| timer := time.NewTimer(time.Minute) |
| timer.Stop() |
| |
| healthy := true |
| for { |
| timer.Reset(pingTimeout) |
| select { |
| case <-c.ping: |
| healthy = true |
| if !timer.Stop() { |
| <-timer.C |
| } |
| case <-timer.C: |
| pingErr := c.Ping(ctx) |
| if healthy { |
| healthy = false |
| } else { |
| if pingErr == nil { |
| pingErr = errPingTimeout |
| } |
| c.mu.Lock() |
| c.reconnect(ctx, pingErr) |
| healthy = true |
| c.mu.Unlock() |
| } |
| case <-c.exit: |
| return |
| } |
| } |
| }() |
| } |
| |
| // initMsgChan must be in sync with initAllChan. |
| func (c *PubSub) initMsgChan(size int) { |
| ctx := context.TODO() |
| c.msgCh = make(chan *Message, size) |
| go func() { |
| timer := time.NewTimer(time.Minute) |
| timer.Stop() |
| |
| var errCount int |
| for { |
| msg, err := c.Receive(ctx) |
| if err != nil { |
| if err == pool.ErrClosed { |
| close(c.msgCh) |
| return |
| } |
| if errCount > 0 { |
| time.Sleep(100 * time.Millisecond) |
| } |
| errCount++ |
| continue |
| } |
| |
| errCount = 0 |
| |
| // Any message is as good as a ping. |
| select { |
| case c.ping <- struct{}{}: |
| default: |
| } |
| |
| switch msg := msg.(type) { |
| case *Subscription: |
| // Ignore. |
| case *Pong: |
| // Ignore. |
| case *Message: |
| timer.Reset(chanSendTimeout) |
| select { |
| case c.msgCh <- msg: |
| if !timer.Stop() { |
| <-timer.C |
| } |
| case <-timer.C: |
| internal.Logger.Printf( |
| c.getContext(), |
| "redis: %s channel is full for %s (message is dropped)", |
| c, |
| chanSendTimeout, |
| ) |
| } |
| default: |
| internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) |
| } |
| } |
| }() |
| } |
| |
| // initAllChan must be in sync with initMsgChan. |
| func (c *PubSub) initAllChan(size int) { |
| ctx := context.TODO() |
| c.allCh = make(chan interface{}, size) |
| go func() { |
| timer := time.NewTimer(pingTimeout) |
| timer.Stop() |
| |
| var errCount int |
| for { |
| msg, err := c.Receive(ctx) |
| if err != nil { |
| if err == pool.ErrClosed { |
| close(c.allCh) |
| return |
| } |
| if errCount > 0 { |
| time.Sleep(100 * time.Millisecond) |
| } |
| errCount++ |
| continue |
| } |
| |
| errCount = 0 |
| |
| // Any message is as good as a ping. |
| select { |
| case c.ping <- struct{}{}: |
| default: |
| } |
| |
| switch msg := msg.(type) { |
| case *Subscription: |
| c.sendMessage(msg, timer) |
| case *Pong: |
| // Ignore. |
| case *Message: |
| c.sendMessage(msg, timer) |
| default: |
| internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) |
| } |
| } |
| }() |
| } |
| |
| func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) { |
| timer.Reset(pingTimeout) |
| select { |
| case c.allCh <- msg: |
| if !timer.Stop() { |
| <-timer.C |
| } |
| case <-timer.C: |
| internal.Logger.Printf( |
| c.getContext(), |
| "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) |
| } |
| } |