blob: 64947de63f6c001f26dfa04d1a4191dd2572ac10 [file] [log] [blame]
// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, and supplies an InvocationEventHandler that provides
// request data (which can come from command-line args or the process's stdin) and logs
// the events (to the process's stdout).
7package grpcurl
8
9import (
10 "bytes"
11 "crypto/tls"
12 "crypto/x509"
13 "encoding/base64"
14 "errors"
15 "fmt"
16 "io/ioutil"
17 "net"
18 "sort"
19 "strings"
20
21 "github.com/golang/protobuf/proto"
22 descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
23 "github.com/golang/protobuf/ptypes"
24 "github.com/golang/protobuf/ptypes/empty"
25 "github.com/golang/protobuf/ptypes/struct"
26 "github.com/jhump/protoreflect/desc"
27 "github.com/jhump/protoreflect/desc/protoprint"
28 "github.com/jhump/protoreflect/dynamic"
29 "golang.org/x/net/context"
30 "google.golang.org/grpc"
31 "google.golang.org/grpc/credentials"
32 "google.golang.org/grpc/metadata"
33)
34
35// ListServices uses the given descriptor source to return a sorted list of fully-qualified
36// service names.
37func ListServices(source DescriptorSource) ([]string, error) {
38 svcs, err := source.ListServices()
39 if err != nil {
40 return nil, err
41 }
42 sort.Strings(svcs)
43 return svcs, nil
44}
45
// sourceWithFiles is implemented by descriptor sources that can enumerate
// every file they know about directly. The package-level GetAllFiles function
// type-asserts to this interface and prefers it over walking the source's
// services.
type sourceWithFiles interface {
	GetAllFiles() ([]*desc.FileDescriptor, error)
}

// Compile-time check that *fileSource implements sourceWithFiles.
var _ sourceWithFiles = (*fileSource)(nil)
51
52// GetAllFiles uses the given descriptor source to return a list of file descriptors.
53func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
54 var files []*desc.FileDescriptor
55 srcFiles, ok := source.(sourceWithFiles)
56
57 // If an error occurs, we still try to load as many files as we can, so that
58 // caller can decide whether to ignore error or not.
59 var firstError error
60 if ok {
61 files, firstError = srcFiles.GetAllFiles()
62 } else {
63 // Source does not implement GetAllFiles method, so use ListServices
64 // and grab files from there.
65 svcNames, err := source.ListServices()
66 if err != nil {
67 firstError = err
68 } else {
69 allFiles := map[string]*desc.FileDescriptor{}
70 for _, name := range svcNames {
71 d, err := source.FindSymbol(name)
72 if err != nil {
73 if firstError == nil {
74 firstError = err
75 }
76 } else {
77 addAllFilesToSet(d.GetFile(), allFiles)
78 }
79 }
80 files = make([]*desc.FileDescriptor, len(allFiles))
81 i := 0
82 for _, fd := range allFiles {
83 files[i] = fd
84 i++
85 }
86 }
87 }
88
89 sort.Sort(filesByName(files))
90 return files, firstError
91}
92
93type filesByName []*desc.FileDescriptor
94
95func (f filesByName) Len() int {
96 return len(f)
97}
98
99func (f filesByName) Less(i, j int) bool {
100 return f[i].GetName() < f[j].GetName()
101}
102
103func (f filesByName) Swap(i, j int) {
104 f[i], f[j] = f[j], f[i]
105}
106
107func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
108 if _, ok := all[fd.GetName()]; ok {
109 // already added
110 return
111 }
112 all[fd.GetName()] = fd
113 for _, dep := range fd.GetDependencies() {
114 addAllFilesToSet(dep, all)
115 }
116}
117
118// ListMethods uses the given descriptor source to return a sorted list of method names
119// for the specified fully-qualified service name.
120func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
121 dsc, err := source.FindSymbol(serviceName)
122 if err != nil {
123 return nil, err
124 }
125 if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
126 return nil, notFound("Service", serviceName)
127 } else {
128 methods := make([]string, 0, len(sd.GetMethods()))
129 for _, method := range sd.GetMethods() {
130 methods = append(methods, method.GetFullyQualifiedName())
131 }
132 sort.Strings(methods)
133 return methods, nil
134 }
135}
136
137// MetadataFromHeaders converts a list of header strings (each string in
138// "Header-Name: Header-Value" form) into metadata. If a string has a header
139// name without a value (e.g. does not contain a colon), the value is assumed
140// to be blank. Binary headers (those whose names end in "-bin") should be
141// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
142// be in raw form and used as is.
143func MetadataFromHeaders(headers []string) metadata.MD {
144 md := make(metadata.MD)
145 for _, part := range headers {
146 if part != "" {
147 pieces := strings.SplitN(part, ":", 2)
148 if len(pieces) == 1 {
149 pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
150 }
151 headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
152 val := strings.TrimSpace(pieces[1])
153 if strings.HasSuffix(headerName, "-bin") {
154 if v, err := decode(val); err == nil {
155 val = v
156 }
157 }
158 md[headerName] = append(md[headerName], val)
159 }
160 }
161 return md
162}
163
// base64Codecs lists the base64 flavors accepted by decode, in the order they
// are tried: padded standard, padded URL-safe, unpadded standard, unpadded URL-safe.
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}

// decode leniently base64-decodes val by trying each encoding in base64Codecs
// and returning the result of the first one that succeeds. If none succeeds,
// it returns "" and the error from the first encoding tried.
func decode(val string) (string, error) {
	var firstErr error
	for _, enc := range base64Codecs {
		raw, err := enc.DecodeString(val)
		if err == nil {
			return string(raw), nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return "", firstErr
}
183
184// MetadataToString returns a string representation of the given metadata, for
185// displaying to users.
186func MetadataToString(md metadata.MD) string {
187 if len(md) == 0 {
188 return "(empty)"
189 }
190
191 keys := make([]string, 0, len(md))
192 for k := range md {
193 keys = append(keys, k)
194 }
195 sort.Strings(keys)
196
197 var b bytes.Buffer
198 first := true
199 for _, k := range keys {
200 vs := md[k]
201 for _, v := range vs {
202 if first {
203 first = false
204 } else {
205 b.WriteString("\n")
206 }
207 b.WriteString(k)
208 b.WriteString(": ")
209 if strings.HasSuffix(k, "-bin") {
210 v = base64.StdEncoding.EncodeToString([]byte(v))
211 }
212 b.WriteString(v)
213 }
214 }
215 return b.String()
216}
217
// printer renders descriptors as proto source snippets (used by GetDescriptorText).
// Output is compact, non-doc comments are omitted, elements are sorted, and type
// references are always written as fully-qualified names.
var printer = &protoprint.Printer{
	Compact:                  true,
	OmitComments:             protoprint.CommentsNonDoc,
	SortElements:             true,
	ForceFullyQualifiedNames: true,
}
224
225// GetDescriptorText returns a string representation of the given descriptor.
226// This returns a snippet of proto source that describes the given element.
227func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
228 // Note: DescriptorSource is not used, but remains an argument for backwards
229 // compatibility with previous implementation.
230 txt, err := printer.PrintProtoToString(dsc)
231 if err != nil {
232 return "", err
233 }
234 // callers don't expect trailing newlines
235 if txt[len(txt)-1] == '\n' {
236 txt = txt[:len(txt)-1]
237 }
238 return txt, nil
239}
240
241// EnsureExtensions uses the given descriptor source to download extensions for
242// the given message. It returns a copy of the given message, but as a dynamic
243// message that knows about all extensions known to the given descriptor source.
244func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
245 // load any server extensions so we can properly describe custom options
246 dsc, err := desc.LoadMessageDescriptorForMessage(msg)
247 if err != nil {
248 return msg
249 }
250
251 var ext dynamic.ExtensionRegistry
252 if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
253 return msg
254 }
255
256 // convert message into dynamic message that knows about applicable extensions
257 // (that way we can show meaningful info for custom options instead of printing as unknown)
258 msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
259 dm, err := fullyConvertToDynamic(msgFactory, msg)
260 if err != nil {
261 return msg
262 }
263 return dm
264}
265
266// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as
267// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions
268// so that all server-known extensions can be correctly parsed by grpcurl.
269func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
270 msgTypeName := md.GetFullyQualifiedName()
271 if alreadyFetched[msgTypeName] {
272 return nil
273 }
274 alreadyFetched[msgTypeName] = true
275 if len(md.GetExtensionRanges()) > 0 {
276 fds, err := source.AllExtensionsForType(msgTypeName)
277 if err != nil {
278 return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
279 }
280 for _, fd := range fds {
281 if err := ext.AddExtension(fd); err != nil {
282 return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
283 }
284 }
285 }
286 // recursively fetch extensions for the types of any message fields
287 for _, fd := range md.GetFields() {
288 if fd.GetMessageType() != nil {
289 err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
290 if err != nil {
291 return err
292 }
293 }
294 }
295 return nil
296}
297
298// fullConvertToDynamic attempts to convert the given message to a dynamic message as well
299// as any nested messages it may contain as field values. If the given message factory has
300// extensions registered that were not known when the given message was parsed, this effectively
301// allows re-parsing to identify those extensions.
302func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
303 if _, ok := msg.(*dynamic.Message); ok {
304 return msg, nil // already a dynamic message
305 }
306 md, err := desc.LoadMessageDescriptorForMessage(msg)
307 if err != nil {
308 return nil, err
309 }
310 newMsg := msgFact.NewMessage(md)
311 dm, ok := newMsg.(*dynamic.Message)
312 if !ok {
313 // if message factory didn't produce a dynamic message, then we should leave msg as is
314 return msg, nil
315 }
316
317 if err := dm.ConvertFrom(msg); err != nil {
318 return nil, err
319 }
320
321 // recursively convert all field values, too
322 for _, fd := range md.GetFields() {
323 if fd.IsMap() {
324 if fd.GetMapValueType().GetMessageType() != nil {
325 m := dm.GetField(fd).(map[interface{}]interface{})
326 for k, v := range m {
327 // keys can't be nested messages; so we only need to recurse through map values, not keys
328 newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
329 if err != nil {
330 return nil, err
331 }
332 dm.PutMapField(fd, k, newVal)
333 }
334 }
335 } else if fd.IsRepeated() {
336 if fd.GetMessageType() != nil {
337 s := dm.GetField(fd).([]interface{})
338 for i, e := range s {
339 newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
340 if err != nil {
341 return nil, err
342 }
343 dm.SetRepeatedField(fd, i, newVal)
344 }
345 }
346 } else {
347 if fd.GetMessageType() != nil {
348 v := dm.GetField(fd)
349 newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
350 if err != nil {
351 return nil, err
352 }
353 dm.SetField(fd, newVal)
354 }
355 }
356 }
357 return dm, nil
358}
359
360// MakeTemplate returns a message instance for the given descriptor that is a
361// suitable template for creating an instance of that message in JSON. In
362// particular, it ensures that any repeated fields (which include map fields)
363// are not empty, so they will render with a single element (to show the types
364// and optionally nested fields). It also ensures that nested messages are not
365// nil by setting them to a message that is also fleshed out as a template
366// message.
367func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
368 return makeTemplate(md, nil)
369}
370
371func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
372 switch md.GetFullyQualifiedName() {
373 case "google.protobuf.Any":
374 // empty type URL is not allowed by JSON representation
375 // so we must give it a dummy type
376 msg, _ := ptypes.MarshalAny(&empty.Empty{})
377 return msg
378 case "google.protobuf.Value":
379 // unset kind is not allowed by JSON representation
380 // so we must give it something
381 return &structpb.Value{
382 Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
383 Fields: map[string]*structpb.Value{
384 "google.protobuf.Value": {Kind: &structpb.Value_StringValue{
385 StringValue: "supports arbitrary JSON",
386 }},
387 },
388 }},
389 }
390 case "google.protobuf.ListValue":
391 return &structpb.ListValue{
392 Values: []*structpb.Value{
393 {
394 Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
395 Fields: map[string]*structpb.Value{
396 "google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
397 StringValue: "is an array of arbitrary JSON values",
398 }},
399 },
400 }},
401 },
402 },
403 }
404 case "google.protobuf.Struct":
405 return &structpb.Struct{
406 Fields: map[string]*structpb.Value{
407 "google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
408 StringValue: "supports arbitrary JSON objects",
409 }},
410 },
411 }
412 }
413
414 dm := dynamic.NewMessage(md)
415
416 // if the message is a recursive structure, we don't want to blow the stack
417 for _, seen := range path {
418 if seen == md {
419 // already visited this type; avoid infinite recursion
420 return dm
421 }
422 }
423 path = append(path, dm.GetMessageDescriptor())
424
425 // for repeated fields, add a single element with default value
426 // and for message fields, add a message with all default fields
427 // that also has non-nil message and non-empty repeated fields
428
429 for _, fd := range dm.GetMessageDescriptor().GetFields() {
430 if fd.IsRepeated() {
431 switch fd.GetType() {
432 case descpb.FieldDescriptorProto_TYPE_FIXED32,
433 descpb.FieldDescriptorProto_TYPE_UINT32:
434 dm.AddRepeatedField(fd, uint32(0))
435
436 case descpb.FieldDescriptorProto_TYPE_SFIXED32,
437 descpb.FieldDescriptorProto_TYPE_SINT32,
438 descpb.FieldDescriptorProto_TYPE_INT32,
439 descpb.FieldDescriptorProto_TYPE_ENUM:
440 dm.AddRepeatedField(fd, int32(0))
441
442 case descpb.FieldDescriptorProto_TYPE_FIXED64,
443 descpb.FieldDescriptorProto_TYPE_UINT64:
444 dm.AddRepeatedField(fd, uint64(0))
445
446 case descpb.FieldDescriptorProto_TYPE_SFIXED64,
447 descpb.FieldDescriptorProto_TYPE_SINT64,
448 descpb.FieldDescriptorProto_TYPE_INT64:
449 dm.AddRepeatedField(fd, int64(0))
450
451 case descpb.FieldDescriptorProto_TYPE_STRING:
452 dm.AddRepeatedField(fd, "")
453
454 case descpb.FieldDescriptorProto_TYPE_BYTES:
455 dm.AddRepeatedField(fd, []byte{})
456
457 case descpb.FieldDescriptorProto_TYPE_BOOL:
458 dm.AddRepeatedField(fd, false)
459
460 case descpb.FieldDescriptorProto_TYPE_FLOAT:
461 dm.AddRepeatedField(fd, float32(0))
462
463 case descpb.FieldDescriptorProto_TYPE_DOUBLE:
464 dm.AddRepeatedField(fd, float64(0))
465
466 case descpb.FieldDescriptorProto_TYPE_MESSAGE,
467 descpb.FieldDescriptorProto_TYPE_GROUP:
468 dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
469 }
470 } else if fd.GetMessageType() != nil {
471 dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
472 }
473 }
474 return dm
475}
476
477// ClientTransportCredentials builds transport credentials for a gRPC client using the
478// given properties. If cacertFile is blank, only standard trusted certs are used to
479// verify the server certs. If clientCertFile is blank, the client will not use a client
480// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
481func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
482 var tlsConf tls.Config
483
484 if clientCertFile != "" {
485 // Load the client certificates from disk
486 certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
487 if err != nil {
488 return nil, fmt.Errorf("could not load client key pair: %v", err)
489 }
490 tlsConf.Certificates = []tls.Certificate{certificate}
491 }
492
493 if insecureSkipVerify {
494 tlsConf.InsecureSkipVerify = true
495 } else if cacertFile != "" {
496 // Create a certificate pool from the certificate authority
497 certPool := x509.NewCertPool()
498 ca, err := ioutil.ReadFile(cacertFile)
499 if err != nil {
500 return nil, fmt.Errorf("could not read ca certificate: %v", err)
501 }
502
503 // Append the certificates from the CA
504 if ok := certPool.AppendCertsFromPEM(ca); !ok {
505 return nil, errors.New("failed to append ca certs")
506 }
507
508 tlsConf.RootCAs = certPool
509 }
510
511 return credentials.NewTLS(&tlsConf), nil
512}
513
514// ServerTransportCredentials builds transport credentials for a gRPC server using the
515// given properties. If cacertFile is blank, the server will not request client certs
516// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
517// not blank, the server will verify client certs when presented, but will not require
518// client certs. The serverCertFile and serverKeyFile must both not be blank.
519func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
520 var tlsConf tls.Config
521 // TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
522 // in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
523 tlsConf.MaxVersion = tls.VersionTLS12
524
525 // Load the server certificates from disk
526 certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
527 if err != nil {
528 return nil, fmt.Errorf("could not load key pair: %v", err)
529 }
530 tlsConf.Certificates = []tls.Certificate{certificate}
531
532 if cacertFile != "" {
533 // Create a certificate pool from the certificate authority
534 certPool := x509.NewCertPool()
535 ca, err := ioutil.ReadFile(cacertFile)
536 if err != nil {
537 return nil, fmt.Errorf("could not read ca certificate: %v", err)
538 }
539
540 // Append the certificates from the CA
541 if ok := certPool.AppendCertsFromPEM(ca); !ok {
542 return nil, errors.New("failed to append ca certs")
543 }
544
545 tlsConf.ClientCAs = certPool
546 }
547
548 if requireClientCerts {
549 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
550 } else if cacertFile != "" {
551 tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
552 } else {
553 tlsConf.ClientAuth = tls.NoClientCert
554 }
555
556 return credentials.NewTLS(&tlsConf), nil
557}
558
559// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
560// and blocking until the returned connection is ready. If the given credentials are nil, the
561// connection will be insecure (plain-text).
562func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
563 // grpc.Dial doesn't provide any information on permanent connection errors (like
564 // TLS handshake failures). So in order to provide good error messages, we need a
565 // custom dialer that can provide that info. That means we manage the TLS handshake.
566 result := make(chan interface{}, 1)
567
568 writeResult := func(res interface{}) {
569 // non-blocking write: we only need the first result
570 select {
571 case result <- res:
572 default:
573 }
574 }
575
576 dialer := func(ctx context.Context, address string) (net.Conn, error) {
577 conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
578 if err != nil {
579 writeResult(err)
580 return nil, err
581 }
582 if creds != nil {
583 conn, _, err = creds.ClientHandshake(ctx, address, conn)
584 if err != nil {
585 writeResult(err)
586 return nil, err
587 }
588 }
589 return conn, nil
590 }
591
592 // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
593 // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
594 // know when we're done. So we run it in a goroutine and then use result
595 // channel to either get the channel or fail-fast.
596 go func() {
597 opts = append(opts,
598 grpc.WithBlock(),
599 grpc.FailOnNonTempDialError(true),
600 grpc.WithContextDialer(dialer),
601 grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
602 )
603 conn, err := grpc.DialContext(ctx, address, opts...)
604 var res interface{}
605 if err != nil {
606 res = err
607 } else {
608 res = conn
609 }
610 writeResult(res)
611 }()
612
613 select {
614 case res := <-result:
615 if conn, ok := res.(*grpc.ClientConn); ok {
616 return conn, nil
617 }
618 return nil, res.(error)
619 case <-ctx.Done():
620 return nil, ctx.Err()
621 }
622}