blob: 45a2f5f49a713608c6e88060461a0ab084be6024 [file] [log] [blame]
khenaidooab1f7bd2019-11-14 14:00:27 -05001// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
// gRPC server and client interceptor chaining middleware.
5
6package grpc_middleware
7
8import (
9 "golang.org/x/net/context"
10 "google.golang.org/grpc"
11)
12
13// ChainUnaryServer creates a single interceptor out of a chain of many interceptors.
14//
15// Execution is done in left-to-right order, including passing of context.
16// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
17// will see context changes of one and two.
18func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
19 n := len(interceptors)
20
21 if n > 1 {
22 lastI := n - 1
23 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
24 var (
25 chainHandler grpc.UnaryHandler
26 curI int
27 )
28
29 chainHandler = func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
30 if curI == lastI {
31 return handler(currentCtx, currentReq)
32 }
33 curI++
34 resp, err := interceptors[curI](currentCtx, currentReq, info, chainHandler)
35 curI--
36 return resp, err
37 }
38
39 return interceptors[0](ctx, req, info, chainHandler)
40 }
41 }
42
43 if n == 1 {
44 return interceptors[0]
45 }
46
47 // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
48 return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
49 return handler(ctx, req)
50 }
51}
52
53// ChainStreamServer creates a single interceptor out of a chain of many interceptors.
54//
55// Execution is done in left-to-right order, including passing of context.
56// For example ChainUnaryServer(one, two, three) will execute one before two before three.
57// If you want to pass context between interceptors, use WrapServerStream.
58func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
59 n := len(interceptors)
60
61 if n > 1 {
62 lastI := n - 1
63 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
64 var (
65 chainHandler grpc.StreamHandler
66 curI int
67 )
68
69 chainHandler = func(currentSrv interface{}, currentStream grpc.ServerStream) error {
70 if curI == lastI {
71 return handler(currentSrv, currentStream)
72 }
73 curI++
74 err := interceptors[curI](currentSrv, currentStream, info, chainHandler)
75 curI--
76 return err
77 }
78
79 return interceptors[0](srv, stream, info, chainHandler)
80 }
81 }
82
83 if n == 1 {
84 return interceptors[0]
85 }
86
87 // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
88 return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
89 return handler(srv, stream)
90 }
91}
92
93// ChainUnaryClient creates a single interceptor out of a chain of many interceptors.
94//
95// Execution is done in left-to-right order, including passing of context.
96// For example ChainUnaryClient(one, two, three) will execute one before two before three.
97func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
98 n := len(interceptors)
99
100 if n > 1 {
101 lastI := n - 1
102 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
103 var (
104 chainHandler grpc.UnaryInvoker
105 curI int
106 )
107
108 chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
109 if curI == lastI {
110 return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)
111 }
112 curI++
113 err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)
114 curI--
115 return err
116 }
117
118 return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)
119 }
120 }
121
122 if n == 1 {
123 return interceptors[0]
124 }
125
126 // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
127 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
128 return invoker(ctx, method, req, reply, cc, opts...)
129 }
130}
131
132// ChainStreamClient creates a single interceptor out of a chain of many interceptors.
133//
134// Execution is done in left-to-right order, including passing of context.
135// For example ChainStreamClient(one, two, three) will execute one before two before three.
136func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
137 n := len(interceptors)
138
139 if n > 1 {
140 lastI := n - 1
141 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
142 var (
143 chainHandler grpc.Streamer
144 curI int
145 )
146
147 chainHandler = func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) {
148 if curI == lastI {
149 return streamer(currentCtx, currentDesc, currentConn, currentMethod, currentOpts...)
150 }
151 curI++
152 stream, err := interceptors[curI](currentCtx, currentDesc, currentConn, currentMethod, chainHandler, currentOpts...)
153 curI--
154 return stream, err
155 }
156
157 return interceptors[0](ctx, desc, cc, method, chainHandler, opts...)
158 }
159 }
160
161 if n == 1 {
162 return interceptors[0]
163 }
164
165 // n == 0; Dummy interceptor maintained for backward compatibility to avoid returning nil.
166 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
167 return streamer(ctx, desc, cc, method, opts...)
168 }
169}
170
171// Chain creates a single interceptor out of a chain of many interceptors.
172//
173// WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors.
174// Basically syntactic sugar.
175func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
176 return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...))
177}
178
179// WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors.
180// Basically syntactic sugar.
181func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
182 return grpc.StreamInterceptor(ChainStreamServer(interceptors...))
183}