blob: 62d8312018bffc4528b4a4304bd0e5c48150e7d4 [file] [log] [blame]
nikesh.krishnan6dd882b2023-03-14 10:02:41 +05301// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_retry
5
6import (
7 "context"
8 "fmt"
9 "io"
10 "sync"
11 "time"
12
13 "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
14 "golang.org/x/net/trace"
15 "google.golang.org/grpc"
16 "google.golang.org/grpc/codes"
17 "google.golang.org/grpc/metadata"
18 "google.golang.org/grpc/status"
19)
20
const (
	// AttemptMetadataKey is the outgoing gRPC metadata key under which the retry
	// interceptors record the attempt number of a retried call (attempt > 0) when
	// header inclusion is enabled in the call options. The value is the attempt
	// number formatted as a decimal string. The unusual spelling is kept as-is
	// because peers may match on this exact key.
	AttemptMetadataKey = "x-retry-attempty"
)
24
// UnaryClientInterceptor returns a new retrying unary client interceptor.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
//
// Every attempt after the first waits for the configured backoff and then invokes the call with a
// context derived by perCallContext. An attempt is retried when its status code is one of the
// configured retriable codes, or when it failed with a context error (DeadlineExceeded/Canceled)
// while a per-call timeout is configured and the parent context is still alive. Errors caused by
// the parent context being done are returned immediately. After the last attempt, the last error
// is returned.
func UnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
	return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// short circuit for simplicity, and avoiding allocations.
		if callOpts.max == 0 {
			return invoker(parentCtx, method, req, reply, cc, grpcOpts...)
		}
		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
				return err
			}
			// perCallContext discards the CancelFunc of the per-call timeout context it may
			// create, so that context (and its timer) would otherwise live until the deadline
			// fires. Deriving it from a cancelable per-attempt context lets us release it as
			// soon as this unary invocation has returned.
			attemptCtx, cancelAttempt := context.WithCancel(parentCtx)
			callCtx := perCallContext(attemptCtx, callOpts, attempt)
			lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...)
			cancelAttempt()
			// TODO(mwitkow): Maybe dial and transport errors should be retriable?
			if lastErr == nil {
				return nil
			}
			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
					// its the parent context deadline or cancellation.
					return lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case try again.
					logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return lastErr
			}
		}
		return lastErr
	}
}
69
// StreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
//
// Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
// to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams,
// BidiStreams), the retry interceptor will fail the call.
//
// Establishing the stream is retried with the same rules as the unary interceptor. Once a stream is
// established it is wrapped in a serverStreamingRetryingStream, whose RecvMsg may re-establish the
// stream if the first receive fails with a retriable error.
func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor {
	intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
	return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		grpcOpts, retryOpts := filterCallOptions(opts)
		callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
		// short circuit for simplicity, and avoiding allocations.
		if callOpts.max == 0 {
			return streamer(parentCtx, desc, cc, method, grpcOpts...)
		}
		if desc.ClientStreams {
			return nil, status.Errorf(codes.Unimplemented, "grpc_retry: cannot retry on ClientStreams, set grpc_retry.Disable()")
		}

		var lastErr error
		for attempt := uint(0); attempt < callOpts.max; attempt++ {
			if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil {
				return nil, err
			}
			// Pass the real attempt number (not 0) so retried stream establishments carry the
			// attempt metadata header, consistent with the unary interceptor and RecvMsg retries.
			callCtx := perCallContext(parentCtx, callOpts, attempt)

			var newStreamer grpc.ClientStream
			newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...)
			if lastErr == nil {
				retryingStreamer := &serverStreamingRetryingStream{
					ClientStream: newStreamer,
					callOpts:     callOpts,
					parentCtx:    parentCtx,
					streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
						return streamer(ctx, desc, cc, method, grpcOpts...)
					},
				}
				return retryingStreamer, nil
			}

			logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr)
			if isContextError(lastErr) {
				if parentCtx.Err() != nil {
					logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err())
					// its the parent context deadline or cancellation.
					return nil, lastErr
				} else if callOpts.perCallTimeout != 0 {
					// We have set a perCallTimeout in the retry middleware, which would result in a context error if
					// the deadline was exceeded, in which case try again.
					logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt)
					continue
				}
			}
			if !isRetriable(lastErr, callOpts) {
				return nil, lastErr
			}
		}
		return nil, lastErr
	}
}
132
133// type serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a
134// proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
135// a new ClientStream according to the retry policy.
136type serverStreamingRetryingStream struct {
137 grpc.ClientStream
138 bufferedSends []interface{} // single message that the client can sen
139 receivedGood bool // indicates whether any prior receives were successful
140 wasClosedSend bool // indicates that CloseSend was closed
141 parentCtx context.Context
142 callOpts *options
143 streamerCall func(ctx context.Context) (grpc.ClientStream, error)
144 mu sync.RWMutex
145}
146
147func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
148 s.mu.Lock()
149 s.ClientStream = clientStream
150 s.mu.Unlock()
151}
152
153func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
154 s.mu.RLock()
155 defer s.mu.RUnlock()
156 return s.ClientStream
157}
158
159func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
160 s.mu.Lock()
161 s.bufferedSends = append(s.bufferedSends, m)
162 s.mu.Unlock()
163 return s.getStream().SendMsg(m)
164}
165
166func (s *serverStreamingRetryingStream) CloseSend() error {
167 s.mu.Lock()
168 s.wasClosedSend = true
169 s.mu.Unlock()
170 return s.getStream().CloseSend()
171}
172
173func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
174 return s.getStream().Header()
175}
176
177func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
178 return s.getStream().Trailer()
179}
180
181func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
182 attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
183 if !attemptRetry {
184 return lastErr // success or hard failure
185 }
186 // We start off from attempt 1, because zeroth was already made on normal SendMsg().
187 for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
188 if err := waitRetryBackoff(attempt, s.parentCtx, s.callOpts); err != nil {
189 return err
190 }
191 callCtx := perCallContext(s.parentCtx, s.callOpts, attempt)
192 newStream, err := s.reestablishStreamAndResendBuffer(callCtx)
193 if err != nil {
194 // Retry dial and transport errors of establishing stream as grpc doesn't retry.
195 if isRetriable(err, s.callOpts) {
196 continue
197 }
198 return err
199 }
200
201 s.setStream(newStream)
202 attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
203 //fmt.Printf("Received message and indicate: %v %v\n", attemptRetry, lastErr)
204 if !attemptRetry {
205 return lastErr
206 }
207 }
208 return lastErr
209}
210
// receiveMsgAndIndicateRetry performs one RecvMsg on the current stream and returns the
// receive error together with whether the caller should re-establish the stream and retry.
//
// A nil error or io.EOF marks the stream as having received successfully and is never retried.
// Once any receive has succeeded, later failures are never retried either. A context error
// is not retried if the parent context is done. It is retried if a per-call timeout is
// configured. Otherwise retry is decided by isRetriable.
func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
	s.mu.RLock()
	alreadyReceived := s.receivedGood
	s.mu.RUnlock()

	recvErr := s.getStream().RecvMsg(m)
	switch {
	case recvErr == nil || recvErr == io.EOF:
		s.mu.Lock()
		s.receivedGood = true
		s.mu.Unlock()
		return false, recvErr
	case alreadyReceived:
		// previous RecvMsg in the stream succeeded, no retry logic should interfere
		return false, recvErr
	case isContextError(recvErr) && s.parentCtx.Err() != nil:
		logTrace(s.parentCtx, "grpc_retry parent context error: %v", s.parentCtx.Err())
		return false, recvErr
	case isContextError(recvErr) && s.callOpts.perCallTimeout != 0:
		// The per-call timeout set by this middleware expired; the parent is still alive.
		logTrace(s.parentCtx, "grpc_retry context error from retry call")
		return true, recvErr
	default:
		return isRetriable(recvErr, s.callOpts), recvErr
	}
}
238
// reestablishStreamAndResendBuffer opens a fresh stream with callCtx, resends every buffered
// message in order, and then closes the send side. It returns the new stream, or the first
// error hit while opening, resending, or closing.
func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(
	callCtx context.Context,
) (grpc.ClientStream, error) {
	s.mu.RLock()
	pending := s.bufferedSends
	s.mu.RUnlock()

	stream, err := s.streamerCall(callCtx)
	if err != nil {
		logTrace(callCtx, "grpc_retry failed redialing new stream: %v", err)
		return nil, err
	}
	for i := range pending {
		if sendErr := stream.SendMsg(pending[i]); sendErr != nil {
			logTrace(callCtx, "grpc_retry failed resending message: %v", sendErr)
			return nil, sendErr
		}
	}
	if err = stream.CloseSend(); err != nil {
		logTrace(callCtx, "grpc_retry failed CloseSend on new stream %v", err)
		return nil, err
	}
	return stream, nil
}
262
263func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options) error {
264 var waitTime time.Duration = 0
265 if attempt > 0 {
266 waitTime = callOpts.backoffFunc(parentCtx, attempt)
267 }
268 if waitTime > 0 {
269 logTrace(parentCtx, "grpc_retry attempt: %d, backoff for %v", attempt, waitTime)
270 timer := time.NewTimer(waitTime)
271 select {
272 case <-parentCtx.Done():
273 timer.Stop()
274 return contextErrToGrpcErr(parentCtx.Err())
275 case <-timer.C:
276 }
277 }
278 return nil
279}
280
281func isRetriable(err error, callOpts *options) bool {
282 errCode := status.Code(err)
283 if isContextError(err) {
284 // context errors are not retriable based on user settings.
285 return false
286 }
287 for _, code := range callOpts.codes {
288 if code == errCode {
289 return true
290 }
291 }
292 return false
293}
294
295func isContextError(err error) bool {
296 code := status.Code(err)
297 return code == codes.DeadlineExceeded || code == codes.Canceled
298}
299
// perCallContext derives the context for one attempt. It applies callOpts.perCallTimeout
// when non-zero. For attempts after the first, when callOpts.includeHeader is set, it adds
// the attempt number to the outgoing metadata under AttemptMetadataKey.
func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
	callCtx := parentCtx
	if callOpts.perCallTimeout != 0 {
		// NOTE(review): the CancelFunc is discarded (go vet lostcancel), so the timeout context
		// is only released when its deadline fires or an ancestor context is canceled.
		callCtx, _ = context.WithTimeout(callCtx, callOpts.perCallTimeout)
	}
	if attempt == 0 || !callOpts.includeHeader {
		return callCtx
	}
	md := metautils.ExtractOutgoing(callCtx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
	return md.ToOutgoing(callCtx)
}
311
// contextErrToGrpcErr converts a context error into a gRPC status error that keeps err's
// message. context.DeadlineExceeded maps to codes.DeadlineExceeded, context.Canceled maps
// to codes.Canceled, and anything else maps to codes.Unknown.
func contextErrToGrpcErr(err error) error {
	code := codes.Unknown
	switch err {
	case context.DeadlineExceeded:
		code = codes.DeadlineExceeded
	case context.Canceled:
		code = codes.Canceled
	}
	return status.Error(code, err.Error())
}
322
// logTrace lazily appends a formatted event to the golang.org/x/net/trace Trace stored
// in ctx. It does nothing when ctx carries no trace.
func logTrace(ctx context.Context, format string, a ...interface{}) {
	if tr, ok := trace.FromContext(ctx); ok {
		tr.LazyPrintf(format, a...)
	}
}