blob: d2f16cb4db1a835e90fc999642e14a8e89c904b9 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package grpcurl
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "strings"
8 "sync"
9 "sync/atomic"
10
11 "github.com/golang/protobuf/jsonpb"
12 "github.com/golang/protobuf/proto"
13 "github.com/jhump/protoreflect/desc"
14 "github.com/jhump/protoreflect/dynamic"
15 "github.com/jhump/protoreflect/dynamic/grpcdynamic"
16 "github.com/jhump/protoreflect/grpcreflect"
17 "golang.org/x/net/context"
18 "google.golang.org/grpc"
19 "google.golang.org/grpc/codes"
20 "google.golang.org/grpc/metadata"
21 "google.golang.org/grpc/status"
22)
23
24// InvocationEventHandler is a bag of callbacks for handling events that occur in the course
25// of invoking an RPC. The handler also provides request data that is sent. The callbacks are
26// generally called in the order they are listed below.
27type InvocationEventHandler interface {
28 // OnResolveMethod is called with a descriptor of the method that is being invoked.
29 OnResolveMethod(*desc.MethodDescriptor)
30 // OnSendHeaders is called with the request metadata that is being sent.
31 OnSendHeaders(metadata.MD)
32 // OnReceiveHeaders is called when response headers have been received.
33 OnReceiveHeaders(metadata.MD)
34 // OnReceiveResponse is called for each response message received.
35 OnReceiveResponse(proto.Message)
36 // OnReceiveTrailers is called when response trailers and final RPC status have been received.
37 OnReceiveTrailers(*status.Status, metadata.MD)
38}
39
40// RequestMessageSupplier is a function that is called to retrieve request
41// messages for a GRPC operation. This type is deprecated and will be removed in
42// a future release.
43//
44// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
45// RequestSupplier with InvokeRPC.
46type RequestMessageSupplier func() ([]byte, error)
47
48// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated
49// and will be removed in a future release. It just delegates to the similarly named InvokeRPC
50// method, whose signature is only slightly different.
51//
52// Deprecated: use InvokeRPC instead.
53func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string,
54 headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error {
55
56 return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error {
57 // New function is almost identical, but the request supplier function works differently.
58 // So we adapt the logic here to maintain compatibility.
59 data, err := requestData()
60 if err != nil {
61 return err
62 }
63 return jsonpb.Unmarshal(bytes.NewReader(data), m)
64 })
65}
66
67// RequestSupplier is a function that is called to populate messages for a gRPC operation. The
68// function should populate the given message or return a non-nil error. If the supplier has no
69// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way
70// modify the given message argument.
71type RequestSupplier func(proto.Message) error
72
73// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source
74// is used to determine the type of method and the type of request and response message. The given
75// headers are sent as request metadata. Methods on the given event handler are called as the
76// invocation proceeds.
77//
78// The given requestData function supplies the actual data to send. It should return io.EOF when
79// there is no more request data. If the method being invoked is a unary or server-streaming RPC
80// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of
81// the function returns io.EOF), then an empty request message is sent.
82//
83// If the requestData function and the given event handler coordinate or share any state, they should
84// be thread-safe. This is because the requestData function may be called from a different goroutine
85// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where
86// one goroutine sends request messages and another consumes the response messages).
87func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string,
88 headers []string, handler InvocationEventHandler, requestData RequestSupplier) error {
89
90 md := MetadataFromHeaders(headers)
91
92 svc, mth := parseSymbol(methodName)
93 if svc == "" || mth == "" {
94 return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName)
95 }
96 dsc, err := source.FindSymbol(svc)
97 if err != nil {
98 if isNotFoundError(err) {
99 return fmt.Errorf("target server does not expose service %q", svc)
100 }
101 return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err)
102 }
103 sd, ok := dsc.(*desc.ServiceDescriptor)
104 if !ok {
105 return fmt.Errorf("target server does not expose service %q", svc)
106 }
107 mtd := sd.FindMethodByName(mth)
108 if mtd == nil {
109 return fmt.Errorf("service %q does not include a method named %q", svc, mth)
110 }
111
112 handler.OnResolveMethod(mtd)
113
114 // we also download any applicable extensions so we can provide full support for parsing user-provided data
115 var ext dynamic.ExtensionRegistry
116 alreadyFetched := map[string]bool{}
117 if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil {
118 return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err)
119 }
120 if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil {
121 return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err)
122 }
123
124 msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
125 req := msgFactory.NewMessage(mtd.GetInputType())
126
127 handler.OnSendHeaders(md)
128 ctx = metadata.NewOutgoingContext(ctx, md)
129
130 stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory)
131 ctx, cancel := context.WithCancel(ctx)
132 defer cancel()
133
134 if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
135 return invokeBidi(ctx, stub, mtd, handler, requestData, req)
136 } else if mtd.IsClientStreaming() {
137 return invokeClientStream(ctx, stub, mtd, handler, requestData, req)
138 } else if mtd.IsServerStreaming() {
139 return invokeServerStream(ctx, stub, mtd, handler, requestData, req)
140 } else {
141 return invokeUnary(ctx, stub, mtd, handler, requestData, req)
142 }
143}
144
145func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
146 requestData RequestSupplier, req proto.Message) error {
147
148 err := requestData(req)
149 if err != nil && err != io.EOF {
150 return fmt.Errorf("error getting request data: %v", err)
151 }
152 if err != io.EOF {
153 // verify there is no second message, which is a usage error
154 err := requestData(req)
155 if err == nil {
156 return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
157 } else if err != io.EOF {
158 return fmt.Errorf("error getting request data: %v", err)
159 }
160 }
161
162 // Now we can actually invoke the RPC!
163 var respHeaders metadata.MD
164 var respTrailers metadata.MD
165 resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders))
166
167 stat, ok := status.FromError(err)
168 if !ok {
169 // Error codes sent from the server will get printed differently below.
170 // So just bail for other kinds of errors here.
171 return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
172 }
173
174 handler.OnReceiveHeaders(respHeaders)
175
176 if stat.Code() == codes.OK {
177 handler.OnReceiveResponse(resp)
178 }
179
180 handler.OnReceiveTrailers(stat, respTrailers)
181
182 return nil
183}
184
185func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
186 requestData RequestSupplier, req proto.Message) error {
187
188 // invoke the RPC!
189 str, err := stub.InvokeRpcClientStream(ctx, md)
190
191 // Upload each request message in the stream
192 var resp proto.Message
193 for err == nil {
194 err = requestData(req)
195 if err == io.EOF {
196 resp, err = str.CloseAndReceive()
197 break
198 }
199 if err != nil {
200 return fmt.Errorf("error getting request data: %v", err)
201 }
202
203 err = str.SendMsg(req)
204 if err == io.EOF {
205 // We get EOF on send if the server says "go away"
206 // We have to use CloseAndReceive to get the actual code
207 resp, err = str.CloseAndReceive()
208 break
209 }
210
211 req.Reset()
212 }
213
214 // finally, process response data
215 stat, ok := status.FromError(err)
216 if !ok {
217 // Error codes sent from the server will get printed differently below.
218 // So just bail for other kinds of errors here.
219 return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
220 }
221
222 if respHeaders, err := str.Header(); err == nil {
223 handler.OnReceiveHeaders(respHeaders)
224 }
225
226 if stat.Code() == codes.OK {
227 handler.OnReceiveResponse(resp)
228 }
229
230 handler.OnReceiveTrailers(stat, str.Trailer())
231
232 return nil
233}
234
235func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
236 requestData RequestSupplier, req proto.Message) error {
237
238 err := requestData(req)
239 if err != nil && err != io.EOF {
240 return fmt.Errorf("error getting request data: %v", err)
241 }
242 if err != io.EOF {
243 // verify there is no second message, which is a usage error
244 err := requestData(req)
245 if err == nil {
246 return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
247 } else if err != io.EOF {
248 return fmt.Errorf("error getting request data: %v", err)
249 }
250 }
251
252 // Now we can actually invoke the RPC!
253 str, err := stub.InvokeRpcServerStream(ctx, md, req)
254
255 if respHeaders, err := str.Header(); err == nil {
256 handler.OnReceiveHeaders(respHeaders)
257 }
258
259 // Download each response message
260 for err == nil {
261 var resp proto.Message
262 resp, err = str.RecvMsg()
263 if err != nil {
264 if err == io.EOF {
265 err = nil
266 }
267 break
268 }
269 handler.OnReceiveResponse(resp)
270 }
271
272 stat, ok := status.FromError(err)
273 if !ok {
274 // Error codes sent from the server will get printed differently below.
275 // So just bail for other kinds of errors here.
276 return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
277 }
278
279 handler.OnReceiveTrailers(stat, str.Trailer())
280
281 return nil
282}
283
284func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
285 requestData RequestSupplier, req proto.Message) error {
286
287 ctx, cancel := context.WithCancel(ctx)
288 defer cancel()
289
290 // invoke the RPC!
291 str, err := stub.InvokeRpcBidiStream(ctx, md)
292
293 var wg sync.WaitGroup
294 var sendErr atomic.Value
295
296 defer wg.Wait()
297
298 if err == nil {
299 wg.Add(1)
300 go func() {
301 defer wg.Done()
302
303 // Concurrently upload each request message in the stream
304 var err error
305 for err == nil {
306 err = requestData(req)
307
308 if err == io.EOF {
309 err = str.CloseSend()
310 break
311 }
312 if err != nil {
313 err = fmt.Errorf("error getting request data: %v", err)
314 cancel()
315 break
316 }
317
318 err = str.SendMsg(req)
319
320 req.Reset()
321 }
322
323 if err != nil {
324 sendErr.Store(err)
325 }
326 }()
327 }
328
329 if respHeaders, err := str.Header(); err == nil {
330 handler.OnReceiveHeaders(respHeaders)
331 }
332
333 // Download each response message
334 for err == nil {
335 var resp proto.Message
336 resp, err = str.RecvMsg()
337 if err != nil {
338 if err == io.EOF {
339 err = nil
340 }
341 break
342 }
343 handler.OnReceiveResponse(resp)
344 }
345
346 if se, ok := sendErr.Load().(error); ok && se != io.EOF {
347 err = se
348 }
349
350 stat, ok := status.FromError(err)
351 if !ok {
352 // Error codes sent from the server will get printed differently below.
353 // So just bail for other kinds of errors here.
354 return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
355 }
356
357 handler.OnReceiveTrailers(stat, str.Trailer())
358
359 return nil
360}
361
// notFoundError is this package's error type for "element not found"
// conditions. The string value is the complete error message.
type notFoundError string

// notFound returns a notFoundError whose message reads "<kind> not found: <name>".
func notFound(kind, name string) error {
	msg := fmt.Sprintf("%s not found: %s", kind, name)
	return notFoundError(msg)
}

// Error implements the error interface, returning the message verbatim.
func (e notFoundError) Error() string {
	return string(e)
}
371
372func isNotFoundError(err error) bool {
373 if grpcreflect.IsElementNotFoundError(err) {
374 return true
375 }
376 _, ok := err.(notFoundError)
377 return ok
378}
379
// parseSymbol splits a method reference into its service and method names.
// The separator is the last '/' if the string contains one, otherwise the last
// '.'. Both "pkg.Service/Method" and "pkg.Service.Method" therefore yield
// ("pkg.Service", "Method"). A single leading '/' is ignored, so the canonical
// gRPC path form "/pkg.Service/Method" is accepted too. Without stripping it,
// the service name would keep the slash and could not name a valid symbol.
// If no separator is found, both results are empty.
func parseSymbol(svcAndMethod string) (string, string) {
	sym := strings.TrimPrefix(svcAndMethod, "/")
	pos := strings.LastIndex(sym, "/")
	if pos < 0 {
		pos = strings.LastIndex(sym, ".")
		if pos < 0 {
			return "", ""
		}
	}
	return sym[:pos], sym[pos+1:]
}