// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, supplies request data (which can come from command-line
// args or the process's stdin), and uses an InvocationEventHandler to log the events
// (to the process's stdout).
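//
// A minimal illustrative sketch (not part of the original source), assuming a
// DescriptorSource named source has already been obtained (for example via the
// server reflection service):
//
//	svcs, err := grpcurl.ListServices(source)
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, svc := range svcs {
//		methods, err := grpcurl.ListMethods(source, svc)
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(svc, methods)
//	}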
package grpcurl

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"regexp"
	"sort"
	"strings"

	"github.com/golang/protobuf/proto"
	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	"github.com/golang/protobuf/ptypes"
	"github.com/golang/protobuf/ptypes/empty"
	structpb "github.com/golang/protobuf/ptypes/struct"
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/protoprint"
	"github.com/jhump/protoreflect/dynamic"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
)

// ListServices uses the given descriptor source to return a sorted list of fully-qualified
// service names.
func ListServices(source DescriptorSource) ([]string, error) {
	svcs, err := source.ListServices()
	if err != nil {
		return nil, err
	}
	sort.Strings(svcs)
	return svcs, nil
}

type sourceWithFiles interface {
	GetAllFiles() ([]*desc.FileDescriptor, error)
}

var _ sourceWithFiles = (*fileSource)(nil)

// GetAllFiles uses the given descriptor source to return a list of file descriptors.
func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
	var files []*desc.FileDescriptor
	srcFiles, ok := source.(sourceWithFiles)

	// If an error occurs, we still try to load as many files as we can, so that
	// the caller can decide whether to ignore the error or not.
	var firstError error
	if ok {
		files, firstError = srcFiles.GetAllFiles()
	} else {
		// Source does not implement GetAllFiles method, so use ListServices
		// and grab files from there.
		svcNames, err := source.ListServices()
		if err != nil {
			firstError = err
		} else {
			allFiles := map[string]*desc.FileDescriptor{}
			for _, name := range svcNames {
				d, err := source.FindSymbol(name)
				if err != nil {
					if firstError == nil {
						firstError = err
					}
				} else {
					addAllFilesToSet(d.GetFile(), allFiles)
				}
			}
			files = make([]*desc.FileDescriptor, len(allFiles))
			i := 0
			for _, fd := range allFiles {
				files[i] = fd
				i++
			}
		}
	}

	sort.Sort(filesByName(files))
	return files, firstError
}

type filesByName []*desc.FileDescriptor

func (f filesByName) Len() int {
	return len(f)
}

func (f filesByName) Less(i, j int) bool {
	return f[i].GetName() < f[j].GetName()
}

func (f filesByName) Swap(i, j int) {
	f[i], f[j] = f[j], f[i]
}

func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
	if _, ok := all[fd.GetName()]; ok {
		// already added
		return
	}
	all[fd.GetName()] = fd
	for _, dep := range fd.GetDependencies() {
		addAllFilesToSet(dep, all)
	}
}

// ListMethods uses the given descriptor source to return a sorted list of method names
// for the specified fully-qualified service name.
func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
	dsc, err := source.FindSymbol(serviceName)
	if err != nil {
		return nil, err
	}
	sd, ok := dsc.(*desc.ServiceDescriptor)
	if !ok {
		return nil, notFound("Service", serviceName)
	}
	methods := make([]string, 0, len(sd.GetMethods()))
	for _, method := range sd.GetMethods() {
		methods = append(methods, method.GetFullyQualifiedName())
	}
	sort.Strings(methods)
	return methods, nil
}

// MetadataFromHeaders converts a list of header strings (each string in
// "Header-Name: Header-Value" form) into metadata. If a string has a header
// name without a value (i.e. it does not contain a colon), the value is assumed
// to be blank. Binary headers (those whose names end in "-bin") should be
// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
// be in raw form and used as is.
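//
// An illustrative sketch (not part of the original source); the header names
// and values are made up:
//
//	md := MetadataFromHeaders([]string{
//		"Authorization: Bearer some-token",
//		"Empty-Header",                    // no colon, so the value becomes ""
//		"Trace-Bin: dGhpcyBpcyBiaW5hcnk=", // "-bin" suffix, so the value is base64-decoded
//	})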
func MetadataFromHeaders(headers []string) metadata.MD {
	md := make(metadata.MD)
	for _, part := range headers {
		if part != "" {
			pieces := strings.SplitN(part, ":", 2)
			if len(pieces) == 1 {
				pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
			}
			headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
			val := strings.TrimSpace(pieces[1])
			if strings.HasSuffix(headerName, "-bin") {
				if v, err := decode(val); err == nil {
					val = v
				}
			}
			md[headerName] = append(md[headerName], val)
		}
	}
	return md
}

var envVarRegex = regexp.MustCompile(`\${\w+}`)

// ExpandHeaders expands environment variables contained in the header strings.
// If a referenced environment variable is not found, an error is returned.
// TODO: Add escaping for `${`
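//
// An illustrative sketch (not part of the original source); the variable and
// header names are made up:
//
//	os.Setenv("AUTH_TOKEN", "some-token")
//	hdrs, err := ExpandHeaders([]string{"Authorization: Bearer ${AUTH_TOKEN}"})
//	// on success, hdrs[0] == "Authorization: Bearer some-token"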
func ExpandHeaders(headers []string) ([]string, error) {
	expandedHeaders := make([]string, len(headers))
	for idx, header := range headers {
		if header == "" {
			continue
		}
		results := envVarRegex.FindAllString(header, -1)
		if len(results) == 0 {
			expandedHeaders[idx] = headers[idx]
			continue
		}
		expandedHeader := header
		for _, result := range results {
			envVarName := result[2 : len(result)-1] // strip leading `${` and trailing `}`
			envVarValue, ok := os.LookupEnv(envVarName)
			if !ok {
				return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, envVarName)
			}
			expandedHeader = strings.Replace(expandedHeader, result, envVarValue, -1)
		}
		expandedHeaders[idx] = expandedHeader
	}
	return expandedHeaders, nil
}

var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}

func decode(val string) (string, error) {
	var firstErr error
	var b []byte
	// we are lenient and can accept any of the flavors of base64 encoding
	for _, d := range base64Codecs {
		var err error
		b, err = d.DecodeString(val)
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		return string(b), nil
	}
	return "", firstErr
}

// MetadataToString returns a string representation of the given metadata, for
// displaying to users.
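//
// An illustrative sketch of the output (not part of the original source); the
// headers are made up:
//
//	md := MetadataFromHeaders([]string{"Foo: bar", "Foo: baz", "Data-Bin: AAEC"})
//	fmt.Println(MetadataToString(md))
//	// data-bin: AAEC
//	// foo: bar
//	// foo: baz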
func MetadataToString(md metadata.MD) string {
	if len(md) == 0 {
		return "(empty)"
	}

	keys := make([]string, 0, len(md))
	for k := range md {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b bytes.Buffer
	first := true
	for _, k := range keys {
		vs := md[k]
		for _, v := range vs {
			if first {
				first = false
			} else {
				b.WriteString("\n")
			}
			b.WriteString(k)
			b.WriteString(": ")
			if strings.HasSuffix(k, "-bin") {
				v = base64.StdEncoding.EncodeToString([]byte(v))
			}
			b.WriteString(v)
		}
	}
	return b.String()
}

var printer = &protoprint.Printer{
	Compact:                  true,
	OmitComments:             protoprint.CommentsNonDoc,
	SortElements:             true,
	ForceFullyQualifiedNames: true,
}

// GetDescriptorText returns a string representation of the given descriptor.
// This returns a snippet of proto source that describes the given element.
func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
	// Note: DescriptorSource is not used, but remains an argument for backwards
	// compatibility with the previous implementation.
	txt, err := printer.PrintProtoToString(dsc)
	if err != nil {
		return "", err
	}
	// callers don't expect a trailing newline
	txt = strings.TrimSuffix(txt, "\n")
	return txt, nil
}

// EnsureExtensions uses the given descriptor source to download extensions for
// the given message. It returns a copy of the given message, but as a dynamic
// message that knows about all extensions known to the given descriptor source.
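//
// An illustrative sketch (not part of the original source); the symbol name is
// made up and error handling is elided:
//
//	d, _ := source.FindSymbol("example.SomeMessage")
//	msgDesc := d.(*desc.MessageDescriptor)
//	// re-interpret the descriptor proto so custom options backed by
//	// server-known extensions become visible
//	withExts := EnsureExtensions(source, msgDesc.AsDescriptorProto())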
func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
	// load any server extensions so we can properly describe custom options
	dsc, err := desc.LoadMessageDescriptorForMessage(msg)
	if err != nil {
		return msg
	}

	var ext dynamic.ExtensionRegistry
	if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
		return msg
	}

	// convert message into dynamic message that knows about applicable extensions
	// (that way we can show meaningful info for custom options instead of printing as unknown)
	msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
	dm, err := fullyConvertToDynamic(msgFactory, msg)
	if err != nil {
		return msg
	}
	return dm
}

// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as
// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions
// so that all server-known extensions can be correctly parsed by grpcurl.
func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
	msgTypeName := md.GetFullyQualifiedName()
	if alreadyFetched[msgTypeName] {
		return nil
	}
	alreadyFetched[msgTypeName] = true
	if len(md.GetExtensionRanges()) > 0 {
		fds, err := source.AllExtensionsForType(msgTypeName)
		if err != nil {
			return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
		}
		for _, fd := range fds {
			if err := ext.AddExtension(fd); err != nil {
				return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
			}
		}
	}
	// recursively fetch extensions for the types of any message fields
	for _, fd := range md.GetFields() {
		if fd.GetMessageType() != nil {
			err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

// fullyConvertToDynamic attempts to convert the given message to a dynamic message, as well
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
	if _, ok := msg.(*dynamic.Message); ok {
		return msg, nil // already a dynamic message
	}
	md, err := desc.LoadMessageDescriptorForMessage(msg)
	if err != nil {
		return nil, err
	}
	newMsg := msgFact.NewMessage(md)
	dm, ok := newMsg.(*dynamic.Message)
	if !ok {
		// if message factory didn't produce a dynamic message, then we should leave msg as is
		return msg, nil
	}

	if err := dm.ConvertFrom(msg); err != nil {
		return nil, err
	}

	// recursively convert all field values, too
	for _, fd := range md.GetFields() {
		if fd.IsMap() {
			if fd.GetMapValueType().GetMessageType() != nil {
				m := dm.GetField(fd).(map[interface{}]interface{})
				for k, v := range m {
					// keys can't be nested messages, so we only need to recurse through map values, not keys
					newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.PutMapField(fd, k, newVal)
				}
			}
		} else if fd.IsRepeated() {
			if fd.GetMessageType() != nil {
				s := dm.GetField(fd).([]interface{})
				for i, e := range s {
					newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.SetRepeatedField(fd, i, newVal)
				}
			}
		} else {
			if fd.GetMessageType() != nil {
				v := dm.GetField(fd)
				newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
				if err != nil {
					return nil, err
				}
				dm.SetField(fd, newVal)
			}
		}
	}
	return dm, nil
}

// MakeTemplate returns a message instance for the given descriptor that is a
// suitable template for creating an instance of that message in JSON. In
// particular, it ensures that any repeated fields (which include map fields)
// are not empty, so they will render with a single element (to show the types
// and optionally nested fields). It also ensures that nested messages are not
// nil by setting them to a message that is also fleshed out as a template
// message.
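//
// An illustrative sketch (not part of the original source), rendering a template
// as JSON with the jsonpb marshaler; msgDesc is assumed to be a
// *desc.MessageDescriptor obtained elsewhere:
//
//	tmpl := MakeTemplate(msgDesc)
//	js, err := (&jsonpb.Marshaler{EmitDefaults: true, Indent: "  "}).MarshalToString(tmpl)
//	if err == nil {
//		fmt.Println(js) // skeleton JSON showing the message's fields
//	}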
func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
	return makeTemplate(md, nil)
}

func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
	switch md.GetFullyQualifiedName() {
	case "google.protobuf.Any":
		// empty type URL is not allowed by JSON representation
		// so we must give it a dummy type
		msg, _ := ptypes.MarshalAny(&empty.Empty{})
		return msg
	case "google.protobuf.Value":
		// unset kind is not allowed by JSON representation
		// so we must give it something
		return &structpb.Value{
			Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
				Fields: map[string]*structpb.Value{
					"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
						StringValue: "supports arbitrary JSON",
					}},
				},
			}},
		}
	case "google.protobuf.ListValue":
		return &structpb.ListValue{
			Values: []*structpb.Value{
				{
					Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
						Fields: map[string]*structpb.Value{
							"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
								StringValue: "is an array of arbitrary JSON values",
							}},
						},
					}},
				},
			},
		}
	case "google.protobuf.Struct":
		return &structpb.Struct{
			Fields: map[string]*structpb.Value{
				"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
					StringValue: "supports arbitrary JSON objects",
				}},
			},
		}
	}

	dm := dynamic.NewMessage(md)

	// if the message is a recursive structure, we don't want to blow the stack
	for _, seen := range path {
		if seen == md {
			// already visited this type; avoid infinite recursion
			return dm
		}
	}
	path = append(path, dm.GetMessageDescriptor())

	// for repeated fields, add a single element with default value
	// and for message fields, add a message with all default fields
	// that also has non-nil message and non-empty repeated fields

	for _, fd := range dm.GetMessageDescriptor().GetFields() {
		if fd.IsRepeated() {
			switch fd.GetType() {
			case descpb.FieldDescriptorProto_TYPE_FIXED32,
				descpb.FieldDescriptorProto_TYPE_UINT32:
				dm.AddRepeatedField(fd, uint32(0))

			case descpb.FieldDescriptorProto_TYPE_SFIXED32,
				descpb.FieldDescriptorProto_TYPE_SINT32,
				descpb.FieldDescriptorProto_TYPE_INT32,
				descpb.FieldDescriptorProto_TYPE_ENUM:
				dm.AddRepeatedField(fd, int32(0))

			case descpb.FieldDescriptorProto_TYPE_FIXED64,
				descpb.FieldDescriptorProto_TYPE_UINT64:
				dm.AddRepeatedField(fd, uint64(0))

			case descpb.FieldDescriptorProto_TYPE_SFIXED64,
				descpb.FieldDescriptorProto_TYPE_SINT64,
				descpb.FieldDescriptorProto_TYPE_INT64:
				dm.AddRepeatedField(fd, int64(0))

			case descpb.FieldDescriptorProto_TYPE_STRING:
				dm.AddRepeatedField(fd, "")

			case descpb.FieldDescriptorProto_TYPE_BYTES:
				dm.AddRepeatedField(fd, []byte{})

			case descpb.FieldDescriptorProto_TYPE_BOOL:
				dm.AddRepeatedField(fd, false)

			case descpb.FieldDescriptorProto_TYPE_FLOAT:
				dm.AddRepeatedField(fd, float32(0))

			case descpb.FieldDescriptorProto_TYPE_DOUBLE:
				dm.AddRepeatedField(fd, float64(0))

			case descpb.FieldDescriptorProto_TYPE_MESSAGE,
				descpb.FieldDescriptorProto_TYPE_GROUP:
				dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
			}
		} else if fd.GetMessageType() != nil {
			dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
		}
	}
	return dm
}

// ClientTransportCredentials builds transport credentials for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
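//
// An illustrative sketch (not part of the original source); the file names are
// made up:
//
//	creds, err := ClientTransportCredentials(false, "ca.pem", "client-cert.pem", "client-key.pem")
//	if err != nil {
//		log.Fatal(err)
//	}
//	// creds can then be passed to BlockingDial (see below)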
func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
	var tlsConf tls.Config

	if clientCertFile != "" {
		// Load the client certificates from disk
		certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("could not load client key pair: %v", err)
		}
		tlsConf.Certificates = []tls.Certificate{certificate}
	}

	if insecureSkipVerify {
		tlsConf.InsecureSkipVerify = true
	} else if cacertFile != "" {
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := ioutil.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %v", err)
		}

		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}

		tlsConf.RootCAs = certPool
	}

	return credentials.NewTLS(&tlsConf), nil
}

// ServerTransportCredentials builds transport credentials for a gRPC server using the
// given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require
// client certs. The serverCertFile and serverKeyFile must both not be blank.
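//
// An illustrative sketch (not part of the original source); the file names are
// made up:
//
//	creds, err := ServerTransportCredentials("ca.pem", "server-cert.pem", "server-key.pem", true)
//	if err != nil {
//		log.Fatal(err)
//	}
//	svr := grpc.NewServer(grpc.Creds(creds))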
func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
	var tlsConf tls.Config
	// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
	// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
	tlsConf.MaxVersion = tls.VersionTLS12

	// Load the server certificates from disk
	certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
	if err != nil {
		return nil, fmt.Errorf("could not load key pair: %v", err)
	}
	tlsConf.Certificates = []tls.Certificate{certificate}

	if cacertFile != "" {
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := ioutil.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %v", err)
		}

		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}

		tlsConf.ClientCAs = certPool
	}

	if requireClientCerts {
		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
	} else if cacertFile != "" {
		tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
	} else {
		tlsConf.ClientAuth = tls.NoClientCert
	}

	return credentials.NewTLS(&tlsConf), nil
}

// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
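//
// An illustrative sketch (not part of the original source); the address is made
// up, and creds may be nil (or built with ClientTransportCredentials above):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	cc, err := BlockingDial(ctx, "tcp", "localhost:8080", creds)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer cc.Close()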
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	// grpc.Dial doesn't provide any information on permanent connection errors (like
	// TLS handshake failures). So in order to provide good error messages, we need a
	// custom dialer that can provide that info. That means we manage the TLS handshake.
	result := make(chan interface{}, 1)

	writeResult := func(res interface{}) {
		// non-blocking write: we only need the first result
		select {
		case result <- res:
		default:
		}
	}

	dialer := func(ctx context.Context, address string) (net.Conn, error) {
		conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
		if err != nil {
			writeResult(err)
			return nil, err
		}
		if creds != nil {
			conn, _, err = creds.ClientHandshake(ctx, address, conn)
			if err != nil {
				writeResult(err)
				return nil, err
			}
		}
		return conn, nil
	}

	// Even with grpc.FailOnNonTempDialError, this call will usually time out in
	// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
	// know when we're done. Instead, we run the dial in a goroutine and then use
	// the result channel to either get the connection or fail fast.
	go func() {
		opts = append(opts,
			grpc.WithBlock(),
			grpc.FailOnNonTempDialError(true),
			grpc.WithContextDialer(dialer),
			grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
		)
		conn, err := grpc.DialContext(ctx, address, opts...)
		var res interface{}
		if err != nil {
			res = err
		} else {
			res = conn
		}
		writeResult(res)
	}()

	select {
	case res := <-result:
		if conn, ok := res.(*grpc.ClientConn); ok {
			return conn, nil
		}
		return nil, res.(error)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}