| package grpcurl |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "strings" |
| "sync" |
| "sync/atomic" |
| |
| "github.com/golang/protobuf/jsonpb" |
| "github.com/golang/protobuf/proto" |
| "github.com/jhump/protoreflect/desc" |
| "github.com/jhump/protoreflect/dynamic" |
| "github.com/jhump/protoreflect/dynamic/grpcdynamic" |
| "github.com/jhump/protoreflect/grpcreflect" |
| "golang.org/x/net/context" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/metadata" |
| "google.golang.org/grpc/status" |
| ) |
| |
| // InvocationEventHandler is a bag of callbacks for handling events that occur in the course |
| // of invoking an RPC. The handler also provides request data that is sent. The callbacks are |
| // generally called in the order they are listed below. |
| type InvocationEventHandler interface { |
| // OnResolveMethod is called with a descriptor of the method that is being invoked. |
| OnResolveMethod(*desc.MethodDescriptor) |
| // OnSendHeaders is called with the request metadata that is being sent. |
| OnSendHeaders(metadata.MD) |
| // OnReceiveHeaders is called when response headers have been received. |
| OnReceiveHeaders(metadata.MD) |
| // OnReceiveResponse is called for each response message received. |
| OnReceiveResponse(proto.Message) |
| // OnReceiveTrailers is called when response trailers and final RPC status have been received. |
| OnReceiveTrailers(*status.Status, metadata.MD) |
| } |
| |
// RequestMessageSupplier is a function that is called to obtain the next
// request message for a gRPC operation, in serialized form. InvokeRpc decodes
// the returned bytes as JSON. This type is deprecated and will be removed in a
// future release.
//
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
// RequestSupplier with InvokeRPC.
type RequestMessageSupplier func() (data []byte, err error)
| |
| // InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated |
| // and will be removed in a future release. It just delegates to the similarly named InvokeRPC |
| // method, whose signature is only slightly different. |
| // |
| // Deprecated: use InvokeRPC instead. |
| func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string, |
| headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error { |
| |
| return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error { |
| // New function is almost identical, but the request supplier function works differently. |
| // So we adapt the logic here to maintain compatibility. |
| data, err := requestData() |
| if err != nil { |
| return err |
| } |
| return jsonpb.Unmarshal(bytes.NewReader(data), m) |
| }) |
| } |
| |
| // RequestSupplier is a function that is called to populate messages for a gRPC operation. The |
| // function should populate the given message or return a non-nil error. If the supplier has no |
| // more messages, it should return io.EOF. When it returns io.EOF, it should not in any way |
| // modify the given message argument. |
| type RequestSupplier func(proto.Message) error |
| |
| // InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source |
| // is used to determine the type of method and the type of request and response message. The given |
| // headers are sent as request metadata. Methods on the given event handler are called as the |
| // invocation proceeds. |
| // |
| // The given requestData function supplies the actual data to send. It should return io.EOF when |
| // there is no more request data. If the method being invoked is a unary or server-streaming RPC |
| // (e.g. exactly one request message) and there is no request data (e.g. the first invocation of |
| // the function returns io.EOF), then an empty request message is sent. |
| // |
| // If the requestData function and the given event handler coordinate or share any state, they should |
| // be thread-safe. This is because the requestData function may be called from a different goroutine |
| // than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where |
| // one goroutine sends request messages and another consumes the response messages). |
| func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string, |
| headers []string, handler InvocationEventHandler, requestData RequestSupplier) error { |
| |
| md := MetadataFromHeaders(headers) |
| |
| svc, mth := parseSymbol(methodName) |
| if svc == "" || mth == "" { |
| return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName) |
| } |
| dsc, err := source.FindSymbol(svc) |
| if err != nil { |
| if isNotFoundError(err) { |
| return fmt.Errorf("target server does not expose service %q", svc) |
| } |
| return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err) |
| } |
| sd, ok := dsc.(*desc.ServiceDescriptor) |
| if !ok { |
| return fmt.Errorf("target server does not expose service %q", svc) |
| } |
| mtd := sd.FindMethodByName(mth) |
| if mtd == nil { |
| return fmt.Errorf("service %q does not include a method named %q", svc, mth) |
| } |
| |
| handler.OnResolveMethod(mtd) |
| |
| // we also download any applicable extensions so we can provide full support for parsing user-provided data |
| var ext dynamic.ExtensionRegistry |
| alreadyFetched := map[string]bool{} |
| if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil { |
| return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err) |
| } |
| if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil { |
| return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err) |
| } |
| |
| msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) |
| req := msgFactory.NewMessage(mtd.GetInputType()) |
| |
| handler.OnSendHeaders(md) |
| ctx = metadata.NewOutgoingContext(ctx, md) |
| |
| stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory) |
| ctx, cancel := context.WithCancel(ctx) |
| defer cancel() |
| |
| if mtd.IsClientStreaming() && mtd.IsServerStreaming() { |
| return invokeBidi(ctx, stub, mtd, handler, requestData, req) |
| } else if mtd.IsClientStreaming() { |
| return invokeClientStream(ctx, stub, mtd, handler, requestData, req) |
| } else if mtd.IsServerStreaming() { |
| return invokeServerStream(ctx, stub, mtd, handler, requestData, req) |
| } else { |
| return invokeUnary(ctx, stub, mtd, handler, requestData, req) |
| } |
| } |
| |
| func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| requestData RequestSupplier, req proto.Message) error { |
| |
| err := requestData(req) |
| if err != nil && err != io.EOF { |
| return fmt.Errorf("error getting request data: %v", err) |
| } |
| if err != io.EOF { |
| // verify there is no second message, which is a usage error |
| err := requestData(req) |
| if err == nil { |
| return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) |
| } else if err != io.EOF { |
| return fmt.Errorf("error getting request data: %v", err) |
| } |
| } |
| |
| // Now we can actually invoke the RPC! |
| var respHeaders metadata.MD |
| var respTrailers metadata.MD |
| resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders)) |
| |
| stat, ok := status.FromError(err) |
| if !ok { |
| // Error codes sent from the server will get printed differently below. |
| // So just bail for other kinds of errors here. |
| return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| } |
| |
| handler.OnReceiveHeaders(respHeaders) |
| |
| if stat.Code() == codes.OK { |
| handler.OnReceiveResponse(resp) |
| } |
| |
| handler.OnReceiveTrailers(stat, respTrailers) |
| |
| return nil |
| } |
| |
| func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| requestData RequestSupplier, req proto.Message) error { |
| |
| // invoke the RPC! |
| str, err := stub.InvokeRpcClientStream(ctx, md) |
| |
| // Upload each request message in the stream |
| var resp proto.Message |
| for err == nil { |
| err = requestData(req) |
| if err == io.EOF { |
| resp, err = str.CloseAndReceive() |
| break |
| } |
| if err != nil { |
| return fmt.Errorf("error getting request data: %v", err) |
| } |
| |
| err = str.SendMsg(req) |
| if err == io.EOF { |
| // We get EOF on send if the server says "go away" |
| // We have to use CloseAndReceive to get the actual code |
| resp, err = str.CloseAndReceive() |
| break |
| } |
| |
| req.Reset() |
| } |
| |
| // finally, process response data |
| stat, ok := status.FromError(err) |
| if !ok { |
| // Error codes sent from the server will get printed differently below. |
| // So just bail for other kinds of errors here. |
| return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| } |
| |
| if respHeaders, err := str.Header(); err == nil { |
| handler.OnReceiveHeaders(respHeaders) |
| } |
| |
| if stat.Code() == codes.OK { |
| handler.OnReceiveResponse(resp) |
| } |
| |
| handler.OnReceiveTrailers(stat, str.Trailer()) |
| |
| return nil |
| } |
| |
| func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| requestData RequestSupplier, req proto.Message) error { |
| |
| err := requestData(req) |
| if err != nil && err != io.EOF { |
| return fmt.Errorf("error getting request data: %v", err) |
| } |
| if err != io.EOF { |
| // verify there is no second message, which is a usage error |
| err := requestData(req) |
| if err == nil { |
| return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName()) |
| } else if err != io.EOF { |
| return fmt.Errorf("error getting request data: %v", err) |
| } |
| } |
| |
| // Now we can actually invoke the RPC! |
| str, err := stub.InvokeRpcServerStream(ctx, md, req) |
| |
| if respHeaders, err := str.Header(); err == nil { |
| handler.OnReceiveHeaders(respHeaders) |
| } |
| |
| // Download each response message |
| for err == nil { |
| var resp proto.Message |
| resp, err = str.RecvMsg() |
| if err != nil { |
| if err == io.EOF { |
| err = nil |
| } |
| break |
| } |
| handler.OnReceiveResponse(resp) |
| } |
| |
| stat, ok := status.FromError(err) |
| if !ok { |
| // Error codes sent from the server will get printed differently below. |
| // So just bail for other kinds of errors here. |
| return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| } |
| |
| handler.OnReceiveTrailers(stat, str.Trailer()) |
| |
| return nil |
| } |
| |
| func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler, |
| requestData RequestSupplier, req proto.Message) error { |
| |
| ctx, cancel := context.WithCancel(ctx) |
| defer cancel() |
| |
| // invoke the RPC! |
| str, err := stub.InvokeRpcBidiStream(ctx, md) |
| |
| var wg sync.WaitGroup |
| var sendErr atomic.Value |
| |
| defer wg.Wait() |
| |
| if err == nil { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| |
| // Concurrently upload each request message in the stream |
| var err error |
| for err == nil { |
| err = requestData(req) |
| |
| if err == io.EOF { |
| err = str.CloseSend() |
| break |
| } |
| if err != nil { |
| err = fmt.Errorf("error getting request data: %v", err) |
| cancel() |
| break |
| } |
| |
| err = str.SendMsg(req) |
| |
| req.Reset() |
| } |
| |
| if err != nil { |
| sendErr.Store(err) |
| } |
| }() |
| } |
| |
| if respHeaders, err := str.Header(); err == nil { |
| handler.OnReceiveHeaders(respHeaders) |
| } |
| |
| // Download each response message |
| for err == nil { |
| var resp proto.Message |
| resp, err = str.RecvMsg() |
| if err != nil { |
| if err == io.EOF { |
| err = nil |
| } |
| break |
| } |
| handler.OnReceiveResponse(resp) |
| } |
| |
| if se, ok := sendErr.Load().(error); ok && se != io.EOF { |
| err = se |
| } |
| |
| stat, ok := status.FromError(err) |
| if !ok { |
| // Error codes sent from the server will get printed differently below. |
| // So just bail for other kinds of errors here. |
| return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err) |
| } |
| |
| handler.OnReceiveTrailers(stat, str.Trailer()) |
| |
| return nil |
| } |
| |
// notFoundError is the error type used to report that a named element could not
// be found. Its underlying string is the complete error message.
type notFoundError string

// notFound returns a notFoundError whose message has the form
// "<kind> not found: <name>".
func notFound(kind, name string) error {
	msg := fmt.Sprintf("%s not found: %s", kind, name)
	return notFoundError(msg)
}

// Error implements the error interface by returning the stored message.
func (e notFoundError) Error() string {
	return string(e)
}
| |
| func isNotFoundError(err error) bool { |
| if grpcreflect.IsElementNotFoundError(err) { |
| return true |
| } |
| _, ok := err.(notFoundError) |
| return ok |
| } |
| |
// parseSymbol splits a "service/method" or "service.method" name into its
// service and method parts. The split happens at the last "/" if there is one,
// otherwise at the last ".". If the name contains neither separator, two empty
// strings are returned.
func parseSymbol(svcAndMethod string) (string, string) {
	sep := strings.LastIndex(svcAndMethod, "/")
	if sep < 0 {
		// No slash present: fall back to the dotted form.
		sep = strings.LastIndex(svcAndMethod, ".")
	}
	if sep < 0 {
		return "", ""
	}
	service, method := svcAndMethod[:sep], svcAndMethod[sep+1:]
	return service, method
}