Zack Williams | e940c7a | 2019-08-21 14:25:39 -0700 | [diff] [blame] | 1 | package grpcurl |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "fmt" |
| 6 | "io" |
| 7 | "strings" |
| 8 | "sync" |
| 9 | "sync/atomic" |
| 10 | |
| 11 | "github.com/golang/protobuf/jsonpb" |
| 12 | "github.com/golang/protobuf/proto" |
| 13 | "github.com/jhump/protoreflect/desc" |
| 14 | "github.com/jhump/protoreflect/dynamic" |
| 15 | "github.com/jhump/protoreflect/dynamic/grpcdynamic" |
| 16 | "github.com/jhump/protoreflect/grpcreflect" |
| 17 | "golang.org/x/net/context" |
| 18 | "google.golang.org/grpc" |
| 19 | "google.golang.org/grpc/codes" |
| 20 | "google.golang.org/grpc/metadata" |
| 21 | "google.golang.org/grpc/status" |
| 22 | ) |
| 23 | |
| 24 | // InvocationEventHandler is a bag of callbacks for handling events that occur in the course |
| 25 | // of invoking an RPC. The handler also provides request data that is sent. The callbacks are |
| 26 | // generally called in the order they are listed below. |
| 27 | type InvocationEventHandler interface { |
| 28 | // OnResolveMethod is called with a descriptor of the method that is being invoked. |
| 29 | OnResolveMethod(*desc.MethodDescriptor) |
| 30 | // OnSendHeaders is called with the request metadata that is being sent. |
| 31 | OnSendHeaders(metadata.MD) |
| 32 | // OnReceiveHeaders is called when response headers have been received. |
| 33 | OnReceiveHeaders(metadata.MD) |
| 34 | // OnReceiveResponse is called for each response message received. |
| 35 | OnReceiveResponse(proto.Message) |
| 36 | // OnReceiveTrailers is called when response trailers and final RPC status have been received. |
| 37 | OnReceiveTrailers(*status.Status, metadata.MD) |
| 38 | } |
| 39 | |
// RequestMessageSupplier is a callback that yields the data of the next request
// message for a gRPC operation as raw bytes. It is deprecated and will be removed
// in a future release.
//
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
// RequestSupplier with InvokeRPC.
type RequestMessageSupplier func() ([]byte, error)
| 47 | |
| 48 | // InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated |
| 49 | // and will be removed in a future release. It just delegates to the similarly named InvokeRPC |
| 50 | // method, whose signature is only slightly different. |
| 51 | // |
| 52 | // Deprecated: use InvokeRPC instead. |
| 53 | func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, |
| 54 | headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { |
| 55 | |
| 56 | return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error { |
| 57 | // New function is almost identical, but the request supplier function works differently. |
| 58 | // So we adapt the logic here to maintain compatibility. |
| 59 | data, err := requestData() |
| 60 | if err != nil { |
| 61 | return err |
| 62 | } |
| 63 | return jsonpb.Unmarshal(bytes.NewReader(data), m) |
| 64 | }) |
| 65 | } |
| 66 | |
| 67 | // RequestSupplier is a function that is called to populate messages for a gRPC operation. The |
| 68 | // function should populate the given message or return a non-nil error. If the supplier has no |
| 69 | // more messages, it should return io.EOF. When it returns io.EOF, it should not in any way |
| 70 | // modify the given message argument. |
| 71 | type RequestSupplier func(proto.Message) error |
| 72 | |
| 73 | // InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source |
| 74 | // is used to determine the type of method and the type of request and response message. The given |
| 75 | // headers are sent as request metadata. Methods on the given event handler are called as the |
| 76 | // invocation proceeds. |
| 77 | // |
| 78 | // The given requestData function supplies the actual data to send. It should return io.EOF when |
| 79 | // there is no more request data. If the method being invoked is a unary or server-streaming RPC |
| 80 | // (e.g. exactly one request message) and there is no request data (e.g. the first invocation of |
| 81 | // the function returns io.EOF), then an empty request message is sent. |
| 82 | // |
| 83 | // If the requestData function and the given event handler coordinate or share any state, they should |
| 84 | // be thread-safe. This is because the requestData function may be called from a different goroutine |
| 85 | // than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where |
| 86 | // one goroutine sends request messages and another consumes the response messages). |
| 87 | func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, |
| 88 | headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { |
| 89 | |
| 90 | md := MetadataFromHeaders(headers) |
| 91 | |
| 92 | svc, mth := parseSymbol(methodName) |
| 93 | if svc == "" || mth == "" { |
| 94 | return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName) |
| 95 | } |
| 96 | dsc, err := source.FindSymbol(svc) |
| 97 | if err != nil { |
| 98 | if isNotFoundError(err) { |
| 99 | return fmt.Errorf("target server does not expose service %q", svc) |
| 100 | } |
| 101 | return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err) |
| 102 | } |
| 103 | sd, ok := dsc.(*desc.ServiceDescriptor) |
| 104 | if !ok { |
| 105 | return fmt.Errorf("target server does not expose service %q", svc) |
| 106 | } |
| 107 | mtd := sd.FindMethodByName(mth) |
| 108 | if mtd == nil { |
| 109 | return fmt.Errorf("service %q does not include a method named %q", svc, mth) |
| 110 | } |
| 111 | |
| 112 | handler.OnResolveMethod(mtd) |
| 113 | |
| 114 | // we also download any applicable extensions so we can provide full support for parsing user-provided data |
| 115 | var ext dynamic.ExtensionRegistry |
| 116 | alreadyFetched := map[string]bool{} |
| 117 | if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil { |
| 118 | return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err) |
| 119 | } |
| 120 | if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil { |
| 121 | return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err) |
| 122 | } |
| 123 | |
| 124 | msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) |
| 125 | req := msgFactory.NewMessage(mtd.GetInputType()) |
| 126 | |
| 127 | handler.OnSendHeaders(md) |
| 128 | ctx = metadata.NewOutgoingContext(ctx, md) |
| 129 | |
| 130 | stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory) |
| 131 | ctx, cancel := context.WithCancel(ctx) |
| 132 | defer cancel() |
| 133 | |
| 134 | if mtd.IsClientStreaming() && mtd.IsServerStreaming() { |
| 135 | return invokeBidi(ctx, stub, mtd, handler, requestData, req) |
| 136 | } else if mtd.IsClientStreaming() { |
| 137 | return invokeClientStream(ctx, stub, mtd, handler, requestData, req) |
| 138 | } else if mtd.IsServerStreaming() { |
| 139 | return invokeServerStream(ctx, stub, mtd, handler, requestData, req) |
| 140 | } else { |
| 141 | return invokeUnary(ctx, stub, mtd, handler, requestData, req) |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| 146 | requestData RequestSupplier, req proto.Message) error { |
| 147 | |
| 148 | err := requestData(req) |
| 149 | if err != nil && err != io.EOF { |
| 150 | return fmt.Errorf("error getting request data: %v", err) |
| 151 | } |
| 152 | if err != io.EOF { |
| 153 | // verify there is no second message, which is a usage error |
| 154 | err := requestData(req) |
| 155 | if err == nil { |
| 156 | return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) |
| 157 | } else if err != io.EOF { |
| 158 | return fmt.Errorf("error getting request data: %v", err) |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | // Now we can actually invoke the RPC! |
| 163 | var respHeaders metadata.MD |
| 164 | var respTrailers metadata.MD |
| 165 | resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders)) |
| 166 | |
| 167 | stat, ok := status.FromError(err) |
| 168 | if !ok { |
| 169 | // Error codes sent from the server will get printed differently below. |
| 170 | // So just bail for other kinds of errors here. |
| 171 | return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| 172 | } |
| 173 | |
| 174 | handler.OnReceiveHeaders(respHeaders) |
| 175 | |
| 176 | if stat.Code() == codes.OK { |
| 177 | handler.OnReceiveResponse(resp) |
| 178 | } |
| 179 | |
| 180 | handler.OnReceiveTrailers(stat, respTrailers) |
| 181 | |
| 182 | return nil |
| 183 | } |
| 184 | |
| 185 | func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| 186 | requestData RequestSupplier, req proto.Message) error { |
| 187 | |
| 188 | // invoke the RPC! |
| 189 | str, err := stub.InvokeRpcClientStream(ctx, md) |
| 190 | |
| 191 | // Upload each request message in the stream |
| 192 | var resp proto.Message |
| 193 | for err == nil { |
| 194 | err = requestData(req) |
| 195 | if err == io.EOF { |
| 196 | resp, err = str.CloseAndReceive() |
| 197 | break |
| 198 | } |
| 199 | if err != nil { |
| 200 | return fmt.Errorf("error getting request data: %v", err) |
| 201 | } |
| 202 | |
| 203 | err = str.SendMsg(req) |
| 204 | if err == io.EOF { |
| 205 | // We get EOF on send if the server says "go away" |
| 206 | // We have to use CloseAndReceive to get the actual code |
| 207 | resp, err = str.CloseAndReceive() |
| 208 | break |
| 209 | } |
| 210 | |
| 211 | req.Reset() |
| 212 | } |
| 213 | |
| 214 | // finally, process response data |
| 215 | stat, ok := status.FromError(err) |
| 216 | if !ok { |
| 217 | // Error codes sent from the server will get printed differently below. |
| 218 | // So just bail for other kinds of errors here. |
| 219 | return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| 220 | } |
| 221 | |
| 222 | if respHeaders, err := str.Header(); err == nil { |
| 223 | handler.OnReceiveHeaders(respHeaders) |
| 224 | } |
| 225 | |
| 226 | if stat.Code() == codes.OK { |
| 227 | handler.OnReceiveResponse(resp) |
| 228 | } |
| 229 | |
| 230 | handler.OnReceiveTrailers(stat, str.Trailer()) |
| 231 | |
| 232 | return nil |
| 233 | } |
| 234 | |
| 235 | func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| 236 | requestData RequestSupplier, req proto.Message) error { |
| 237 | |
| 238 | err := requestData(req) |
| 239 | if err != nil && err != io.EOF { |
| 240 | return fmt.Errorf("error getting request data: %v", err) |
| 241 | } |
| 242 | if err != io.EOF { |
| 243 | // verify there is no second message, which is a usage error |
| 244 | err := requestData(req) |
| 245 | if err == nil { |
| 246 | return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) |
| 247 | } else if err != io.EOF { |
| 248 | return fmt.Errorf("error getting request data: %v", err) |
| 249 | } |
| 250 | } |
| 251 | |
| 252 | // Now we can actually invoke the RPC! |
| 253 | str, err := stub.InvokeRpcServerStream(ctx, md, req) |
| 254 | |
| 255 | if respHeaders, err := str.Header(); err == nil { |
| 256 | handler.OnReceiveHeaders(respHeaders) |
| 257 | } |
| 258 | |
| 259 | // Download each response message |
| 260 | for err == nil { |
| 261 | var resp proto.Message |
| 262 | resp, err = str.RecvMsg() |
| 263 | if err != nil { |
| 264 | if err == io.EOF { |
| 265 | err = nil |
| 266 | } |
| 267 | break |
| 268 | } |
| 269 | handler.OnReceiveResponse(resp) |
| 270 | } |
| 271 | |
| 272 | stat, ok := status.FromError(err) |
| 273 | if !ok { |
| 274 | // Error codes sent from the server will get printed differently below. |
| 275 | // So just bail for other kinds of errors here. |
| 276 | return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| 277 | } |
| 278 | |
| 279 | handler.OnReceiveTrailers(stat, str.Trailer()) |
| 280 | |
| 281 | return nil |
| 282 | } |
| 283 | |
| 284 | func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| 285 | requestData RequestSupplier, req proto.Message) error { |
| 286 | |
| 287 | ctx, cancel := context.WithCancel(ctx) |
| 288 | defer cancel() |
| 289 | |
| 290 | // invoke the RPC! |
| 291 | str, err := stub.InvokeRpcBidiStream(ctx, md) |
| 292 | |
| 293 | var wg sync.WaitGroup |
| 294 | var sendErr atomic.Value |
| 295 | |
| 296 | defer wg.Wait() |
| 297 | |
| 298 | if err == nil { |
| 299 | wg.Add(1) |
| 300 | go func() { |
| 301 | defer wg.Done() |
| 302 | |
| 303 | // Concurrently upload each request message in the stream |
| 304 | var err error |
| 305 | for err == nil { |
| 306 | err = requestData(req) |
| 307 | |
| 308 | if err == io.EOF { |
| 309 | err = str.CloseSend() |
| 310 | break |
| 311 | } |
| 312 | if err != nil { |
| 313 | err = fmt.Errorf("error getting request data: %v", err) |
| 314 | cancel() |
| 315 | break |
| 316 | } |
| 317 | |
| 318 | err = str.SendMsg(req) |
| 319 | |
| 320 | req.Reset() |
| 321 | } |
| 322 | |
| 323 | if err != nil { |
| 324 | sendErr.Store(err) |
| 325 | } |
| 326 | }() |
| 327 | } |
| 328 | |
| 329 | if respHeaders, err := str.Header(); err == nil { |
| 330 | handler.OnReceiveHeaders(respHeaders) |
| 331 | } |
| 332 | |
| 333 | // Download each response message |
| 334 | for err == nil { |
| 335 | var resp proto.Message |
| 336 | resp, err = str.RecvMsg() |
| 337 | if err != nil { |
| 338 | if err == io.EOF { |
| 339 | err = nil |
| 340 | } |
| 341 | break |
| 342 | } |
| 343 | handler.OnReceiveResponse(resp) |
| 344 | } |
| 345 | |
| 346 | if se, ok := sendErr.Load().(error); ok && se != io.EOF { |
| 347 | err = se |
| 348 | } |
| 349 | |
| 350 | stat, ok := status.FromError(err) |
| 351 | if !ok { |
| 352 | // Error codes sent from the server will get printed differently below. |
| 353 | // So just bail for other kinds of errors here. |
| 354 | return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| 355 | } |
| 356 | |
| 357 | handler.OnReceiveTrailers(stat, str.Trailer()) |
| 358 | |
| 359 | return nil |
| 360 | } |
| 361 | |
// notFoundError is the error type used to report that a requested element could
// not be found. Its string value is the complete error message.
type notFoundError string

// notFound builds a notFoundError whose message names the kind of element that
// was looked up and the name that could not be found.
func notFound(kind, name string) error {
	msg := fmt.Sprintf("%s not found: %s", kind, name)
	return notFoundError(msg)
}

// Error implements the error interface by returning the stored message.
func (nf notFoundError) Error() string {
	return string(nf)
}
| 371 | |
| 372 | func isNotFoundError(err error) bool { |
| 373 | if grpcreflect.IsElementNotFoundError(err) { |
| 374 | return true |
| 375 | } |
| 376 | _, ok := err.(notFoundError) |
| 377 | return ok |
| 378 | } |
| 379 | |
// parseSymbol splits a method name of the form "service/method" or
// "service.method" into its service and method parts. The split happens at the
// last '/' if the name contains one, and otherwise at the last '.'. Two empty
// strings are returned when the name contains neither separator.
func parseSymbol(svcAndMethod string) (string, string) {
	// Separators in order of preference; the first one present decides the split.
	for _, sep := range []byte{'/', '.'} {
		if at := strings.LastIndexByte(svcAndMethod, sep); at >= 0 {
			return svcAndMethod[:at], svcAndMethod[at+1:]
		}
	}
	return "", ""
}