Zack Williams | e940c7a | 2019-08-21 14:25:39 -0700 | [diff] [blame^] | 1 | // Package grpcurl provides the core functionality exposed by the grpcurl command, for |
| 2 | // dynamically connecting to a server, using the reflection service to inspect the server, |
| 3 | // and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based |
// on the command-line parameters, and supplies an InvocationEventHandler that provides request
// data (which can come from command-line args or the process's stdin) and logs the
// events (to the process's stdout).
| 7 | package grpcurl |
| 8 | |
| 9 | import ( |
| 10 | "bytes" |
| 11 | "crypto/tls" |
| 12 | "crypto/x509" |
| 13 | "encoding/base64" |
| 14 | "errors" |
| 15 | "fmt" |
| 16 | "io/ioutil" |
| 17 | "net" |
| 18 | "sort" |
| 19 | "strings" |
| 20 | |
| 21 | "github.com/golang/protobuf/proto" |
| 22 | descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| 23 | "github.com/golang/protobuf/ptypes" |
| 24 | "github.com/golang/protobuf/ptypes/empty" |
| 25 | "github.com/golang/protobuf/ptypes/struct" |
| 26 | "github.com/jhump/protoreflect/desc" |
| 27 | "github.com/jhump/protoreflect/desc/protoprint" |
| 28 | "github.com/jhump/protoreflect/dynamic" |
| 29 | "golang.org/x/net/context" |
| 30 | "google.golang.org/grpc" |
| 31 | "google.golang.org/grpc/credentials" |
| 32 | "google.golang.org/grpc/metadata" |
| 33 | ) |
| 34 | |
| 35 | // ListServices uses the given descriptor source to return a sorted list of fully-qualified |
| 36 | // service names. |
| 37 | func ListServices(source DescriptorSource) ([]string, error) { |
| 38 | svcs, err := source.ListServices() |
| 39 | if err != nil { |
| 40 | return nil, err |
| 41 | } |
| 42 | sort.Strings(svcs) |
| 43 | return svcs, nil |
| 44 | } |
| 45 | |
| 46 | type sourceWithFiles interface { |
| 47 | GetAllFiles() ([]*desc.FileDescriptor, error) |
| 48 | } |
| 49 | |
| 50 | var _ sourceWithFiles = (*fileSource)(nil) |
| 51 | |
| 52 | // GetAllFiles uses the given descriptor source to return a list of file descriptors. |
| 53 | func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) { |
| 54 | var files []*desc.FileDescriptor |
| 55 | srcFiles, ok := source.(sourceWithFiles) |
| 56 | |
| 57 | // If an error occurs, we still try to load as many files as we can, so that |
| 58 | // caller can decide whether to ignore error or not. |
| 59 | var firstError error |
| 60 | if ok { |
| 61 | files, firstError = srcFiles.GetAllFiles() |
| 62 | } else { |
| 63 | // Source does not implement GetAllFiles method, so use ListServices |
| 64 | // and grab files from there. |
| 65 | svcNames, err := source.ListServices() |
| 66 | if err != nil { |
| 67 | firstError = err |
| 68 | } else { |
| 69 | allFiles := map[string]*desc.FileDescriptor{} |
| 70 | for _, name := range svcNames { |
| 71 | d, err := source.FindSymbol(name) |
| 72 | if err != nil { |
| 73 | if firstError == nil { |
| 74 | firstError = err |
| 75 | } |
| 76 | } else { |
| 77 | addAllFilesToSet(d.GetFile(), allFiles) |
| 78 | } |
| 79 | } |
| 80 | files = make([]*desc.FileDescriptor, len(allFiles)) |
| 81 | i := 0 |
| 82 | for _, fd := range allFiles { |
| 83 | files[i] = fd |
| 84 | i++ |
| 85 | } |
| 86 | } |
| 87 | } |
| 88 | |
| 89 | sort.Sort(filesByName(files)) |
| 90 | return files, firstError |
| 91 | } |
| 92 | |
| 93 | type filesByName []*desc.FileDescriptor |
| 94 | |
| 95 | func (f filesByName) Len() int { |
| 96 | return len(f) |
| 97 | } |
| 98 | |
| 99 | func (f filesByName) Less(i, j int) bool { |
| 100 | return f[i].GetName() < f[j].GetName() |
| 101 | } |
| 102 | |
| 103 | func (f filesByName) Swap(i, j int) { |
| 104 | f[i], f[j] = f[j], f[i] |
| 105 | } |
| 106 | |
| 107 | func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) { |
| 108 | if _, ok := all[fd.GetName()]; ok { |
| 109 | // already added |
| 110 | return |
| 111 | } |
| 112 | all[fd.GetName()] = fd |
| 113 | for _, dep := range fd.GetDependencies() { |
| 114 | addAllFilesToSet(dep, all) |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | // ListMethods uses the given descriptor source to return a sorted list of method names |
| 119 | // for the specified fully-qualified service name. |
| 120 | func ListMethods(source DescriptorSource, serviceName string) ([]string, error) { |
| 121 | dsc, err := source.FindSymbol(serviceName) |
| 122 | if err != nil { |
| 123 | return nil, err |
| 124 | } |
| 125 | if sd, ok := dsc.(*desc.ServiceDescriptor); !ok { |
| 126 | return nil, notFound("Service", serviceName) |
| 127 | } else { |
| 128 | methods := make([]string, 0, len(sd.GetMethods())) |
| 129 | for _, method := range sd.GetMethods() { |
| 130 | methods = append(methods, method.GetFullyQualifiedName()) |
| 131 | } |
| 132 | sort.Strings(methods) |
| 133 | return methods, nil |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | // MetadataFromHeaders converts a list of header strings (each string in |
| 138 | // "Header-Name: Header-Value" form) into metadata. If a string has a header |
| 139 | // name without a value (e.g. does not contain a colon), the value is assumed |
| 140 | // to be blank. Binary headers (those whose names end in "-bin") should be |
| 141 | // base64-encoded. But if they cannot be base64-decoded, they will be assumed to |
| 142 | // be in raw form and used as is. |
| 143 | func MetadataFromHeaders(headers []string) metadata.MD { |
| 144 | md := make(metadata.MD) |
| 145 | for _, part := range headers { |
| 146 | if part != "" { |
| 147 | pieces := strings.SplitN(part, ":", 2) |
| 148 | if len(pieces) == 1 { |
| 149 | pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) |
| 150 | } |
| 151 | headerName := strings.ToLower(strings.TrimSpace(pieces[0])) |
| 152 | val := strings.TrimSpace(pieces[1]) |
| 153 | if strings.HasSuffix(headerName, "-bin") { |
| 154 | if v, err := decode(val); err == nil { |
| 155 | val = v |
| 156 | } |
| 157 | } |
| 158 | md[headerName] = append(md[headerName], val) |
| 159 | } |
| 160 | } |
| 161 | return md |
| 162 | } |
| 163 | |
// base64Codecs lists the base64 flavors that decode accepts, in the order they
// are tried: padded standard, padded URL-safe, unpadded standard, and unpadded
// URL-safe.
var base64Codecs = []*base64.Encoding{
	base64.StdEncoding,
	base64.URLEncoding,
	base64.RawStdEncoding,
	base64.RawURLEncoding,
}

// decode leniently base64-decodes val. It tries each codec in base64Codecs in
// turn and returns the output of the first one that succeeds. If none succeeds,
// it returns "" and the error from the first codec tried.
func decode(val string) (string, error) {
	var firstErr error
	for i, codec := range base64Codecs {
		decoded, err := codec.DecodeString(val)
		if err == nil {
			return string(decoded), nil
		}
		// Reaching here at all means codec 0 failed, so its error is the first one.
		if i == 0 {
			firstErr = err
		}
	}
	return "", firstErr
}
| 183 | |
| 184 | // MetadataToString returns a string representation of the given metadata, for |
| 185 | // displaying to users. |
| 186 | func MetadataToString(md metadata.MD) string { |
| 187 | if len(md) == 0 { |
| 188 | return "(empty)" |
| 189 | } |
| 190 | |
| 191 | keys := make([]string, 0, len(md)) |
| 192 | for k := range md { |
| 193 | keys = append(keys, k) |
| 194 | } |
| 195 | sort.Strings(keys) |
| 196 | |
| 197 | var b bytes.Buffer |
| 198 | first := true |
| 199 | for _, k := range keys { |
| 200 | vs := md[k] |
| 201 | for _, v := range vs { |
| 202 | if first { |
| 203 | first = false |
| 204 | } else { |
| 205 | b.WriteString("\n") |
| 206 | } |
| 207 | b.WriteString(k) |
| 208 | b.WriteString(": ") |
| 209 | if strings.HasSuffix(k, "-bin") { |
| 210 | v = base64.StdEncoding.EncodeToString([]byte(v)) |
| 211 | } |
| 212 | b.WriteString(v) |
| 213 | } |
| 214 | } |
| 215 | return b.String() |
| 216 | } |
| 217 | |
| 218 | var printer = &protoprint.Printer{ |
| 219 | Compact: true, |
| 220 | OmitComments: protoprint.CommentsNonDoc, |
| 221 | SortElements: true, |
| 222 | ForceFullyQualifiedNames: true, |
| 223 | } |
| 224 | |
| 225 | // GetDescriptorText returns a string representation of the given descriptor. |
| 226 | // This returns a snippet of proto source that describes the given element. |
| 227 | func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) { |
| 228 | // Note: DescriptorSource is not used, but remains an argument for backwards |
| 229 | // compatibility with previous implementation. |
| 230 | txt, err := printer.PrintProtoToString(dsc) |
| 231 | if err != nil { |
| 232 | return "", err |
| 233 | } |
| 234 | // callers don't expect trailing newlines |
| 235 | if txt[len(txt)-1] == '\n' { |
| 236 | txt = txt[:len(txt)-1] |
| 237 | } |
| 238 | return txt, nil |
| 239 | } |
| 240 | |
| 241 | // EnsureExtensions uses the given descriptor source to download extensions for |
| 242 | // the given message. It returns a copy of the given message, but as a dynamic |
| 243 | // message that knows about all extensions known to the given descriptor source. |
| 244 | func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message { |
| 245 | // load any server extensions so we can properly describe custom options |
| 246 | dsc, err := desc.LoadMessageDescriptorForMessage(msg) |
| 247 | if err != nil { |
| 248 | return msg |
| 249 | } |
| 250 | |
| 251 | var ext dynamic.ExtensionRegistry |
| 252 | if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil { |
| 253 | return msg |
| 254 | } |
| 255 | |
| 256 | // convert message into dynamic message that knows about applicable extensions |
| 257 | // (that way we can show meaningful info for custom options instead of printing as unknown) |
| 258 | msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) |
| 259 | dm, err := fullyConvertToDynamic(msgFactory, msg) |
| 260 | if err != nil { |
| 261 | return msg |
| 262 | } |
| 263 | return dm |
| 264 | } |
| 265 | |
| 266 | // fetchAllExtensions recursively fetches from the server extensions for the given message type as well as |
| 267 | // for all message types of nested fields. The extensions are added to the given dynamic registry of extensions |
| 268 | // so that all server-known extensions can be correctly parsed by grpcurl. |
| 269 | func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error { |
| 270 | msgTypeName := md.GetFullyQualifiedName() |
| 271 | if alreadyFetched[msgTypeName] { |
| 272 | return nil |
| 273 | } |
| 274 | alreadyFetched[msgTypeName] = true |
| 275 | if len(md.GetExtensionRanges()) > 0 { |
| 276 | fds, err := source.AllExtensionsForType(msgTypeName) |
| 277 | if err != nil { |
| 278 | return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err) |
| 279 | } |
| 280 | for _, fd := range fds { |
| 281 | if err := ext.AddExtension(fd); err != nil { |
| 282 | return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err) |
| 283 | } |
| 284 | } |
| 285 | } |
| 286 | // recursively fetch extensions for the types of any message fields |
| 287 | for _, fd := range md.GetFields() { |
| 288 | if fd.GetMessageType() != nil { |
| 289 | err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched) |
| 290 | if err != nil { |
| 291 | return err |
| 292 | } |
| 293 | } |
| 294 | } |
| 295 | return nil |
| 296 | } |
| 297 | |
// fullyConvertToDynamic attempts to convert the given message to a dynamic message as well
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
//
// msg is returned unchanged if it is already a *dynamic.Message, or if msgFact does not
// produce a *dynamic.Message for msg's type. If loading the descriptor, ConvertFrom, or the
// conversion of any nested value fails, the error is returned with a nil message.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
	if _, ok := msg.(*dynamic.Message); ok {
		return msg, nil // already a dynamic message
	}
	md, err := desc.LoadMessageDescriptorForMessage(msg)
	if err != nil {
		return nil, err
	}
	newMsg := msgFact.NewMessage(md)
	dm, ok := newMsg.(*dynamic.Message)
	if !ok {
		// if message factory didn't produce a dynamic message, then we should leave msg as is
		return msg, nil
	}

	if err := dm.ConvertFrom(msg); err != nil {
		return nil, err
	}

	// recursively convert all field values, too
	for _, fd := range md.GetFields() {
		if fd.IsMap() {
			// Only map values can be messages; keys never are.
			if fd.GetMapValueType().GetMessageType() != nil {
				m := dm.GetField(fd).(map[interface{}]interface{})
				// NOTE(review): PutMapField is called while ranging over m. Go permits
				// overwriting existing keys during a range; whether m aliases dm's
				// internal storage is not visible from here.
				for k, v := range m {
					// keys can't be nested messages; so we only need to recurse through map values, not keys
					newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.PutMapField(fd, k, newVal)
				}
			}
		} else if fd.IsRepeated() {
			// Replace each message element in place, preserving its index.
			if fd.GetMessageType() != nil {
				s := dm.GetField(fd).([]interface{})
				for i, e := range s {
					newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.SetRepeatedField(fd, i, newVal)
				}
			}
		} else {
			// NOTE(review): for an unset singular message field this relies on GetField
			// returning a value that asserts to proto.Message; confirm against the
			// dynamic package's default-value behavior.
			if fd.GetMessageType() != nil {
				v := dm.GetField(fd)
				newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
				if err != nil {
					return nil, err
				}
				dm.SetField(fd, newVal)
			}
		}
	}
	return dm, nil
}
| 359 | |
| 360 | // MakeTemplate returns a message instance for the given descriptor that is a |
| 361 | // suitable template for creating an instance of that message in JSON. In |
| 362 | // particular, it ensures that any repeated fields (which include map fields) |
| 363 | // are not empty, so they will render with a single element (to show the types |
| 364 | // and optionally nested fields). It also ensures that nested messages are not |
| 365 | // nil by setting them to a message that is also fleshed out as a template |
| 366 | // message. |
| 367 | func MakeTemplate(md *desc.MessageDescriptor) proto.Message { |
| 368 | return makeTemplate(md, nil) |
| 369 | } |
| 370 | |
| 371 | func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message { |
| 372 | switch md.GetFullyQualifiedName() { |
| 373 | case "google.protobuf.Any": |
| 374 | // empty type URL is not allowed by JSON representation |
| 375 | // so we must give it a dummy type |
| 376 | msg, _ := ptypes.MarshalAny(&empty.Empty{}) |
| 377 | return msg |
| 378 | case "google.protobuf.Value": |
| 379 | // unset kind is not allowed by JSON representation |
| 380 | // so we must give it something |
| 381 | return &structpb.Value{ |
| 382 | Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{ |
| 383 | Fields: map[string]*structpb.Value{ |
| 384 | "google.protobuf.Value": {Kind: &structpb.Value_StringValue{ |
| 385 | StringValue: "supports arbitrary JSON", |
| 386 | }}, |
| 387 | }, |
| 388 | }}, |
| 389 | } |
| 390 | case "google.protobuf.ListValue": |
| 391 | return &structpb.ListValue{ |
| 392 | Values: []*structpb.Value{ |
| 393 | { |
| 394 | Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{ |
| 395 | Fields: map[string]*structpb.Value{ |
| 396 | "google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{ |
| 397 | StringValue: "is an array of arbitrary JSON values", |
| 398 | }}, |
| 399 | }, |
| 400 | }}, |
| 401 | }, |
| 402 | }, |
| 403 | } |
| 404 | case "google.protobuf.Struct": |
| 405 | return &structpb.Struct{ |
| 406 | Fields: map[string]*structpb.Value{ |
| 407 | "google.protobuf.Struct": {Kind: &structpb.Value_StringValue{ |
| 408 | StringValue: "supports arbitrary JSON objects", |
| 409 | }}, |
| 410 | }, |
| 411 | } |
| 412 | } |
| 413 | |
| 414 | dm := dynamic.NewMessage(md) |
| 415 | |
| 416 | // if the message is a recursive structure, we don't want to blow the stack |
| 417 | for _, seen := range path { |
| 418 | if seen == md { |
| 419 | // already visited this type; avoid infinite recursion |
| 420 | return dm |
| 421 | } |
| 422 | } |
| 423 | path = append(path, dm.GetMessageDescriptor()) |
| 424 | |
| 425 | // for repeated fields, add a single element with default value |
| 426 | // and for message fields, add a message with all default fields |
| 427 | // that also has non-nil message and non-empty repeated fields |
| 428 | |
| 429 | for _, fd := range dm.GetMessageDescriptor().GetFields() { |
| 430 | if fd.IsRepeated() { |
| 431 | switch fd.GetType() { |
| 432 | case descpb.FieldDescriptorProto_TYPE_FIXED32, |
| 433 | descpb.FieldDescriptorProto_TYPE_UINT32: |
| 434 | dm.AddRepeatedField(fd, uint32(0)) |
| 435 | |
| 436 | case descpb.FieldDescriptorProto_TYPE_SFIXED32, |
| 437 | descpb.FieldDescriptorProto_TYPE_SINT32, |
| 438 | descpb.FieldDescriptorProto_TYPE_INT32, |
| 439 | descpb.FieldDescriptorProto_TYPE_ENUM: |
| 440 | dm.AddRepeatedField(fd, int32(0)) |
| 441 | |
| 442 | case descpb.FieldDescriptorProto_TYPE_FIXED64, |
| 443 | descpb.FieldDescriptorProto_TYPE_UINT64: |
| 444 | dm.AddRepeatedField(fd, uint64(0)) |
| 445 | |
| 446 | case descpb.FieldDescriptorProto_TYPE_SFIXED64, |
| 447 | descpb.FieldDescriptorProto_TYPE_SINT64, |
| 448 | descpb.FieldDescriptorProto_TYPE_INT64: |
| 449 | dm.AddRepeatedField(fd, int64(0)) |
| 450 | |
| 451 | case descpb.FieldDescriptorProto_TYPE_STRING: |
| 452 | dm.AddRepeatedField(fd, "") |
| 453 | |
| 454 | case descpb.FieldDescriptorProto_TYPE_BYTES: |
| 455 | dm.AddRepeatedField(fd, []byte{}) |
| 456 | |
| 457 | case descpb.FieldDescriptorProto_TYPE_BOOL: |
| 458 | dm.AddRepeatedField(fd, false) |
| 459 | |
| 460 | case descpb.FieldDescriptorProto_TYPE_FLOAT: |
| 461 | dm.AddRepeatedField(fd, float32(0)) |
| 462 | |
| 463 | case descpb.FieldDescriptorProto_TYPE_DOUBLE: |
| 464 | dm.AddRepeatedField(fd, float64(0)) |
| 465 | |
| 466 | case descpb.FieldDescriptorProto_TYPE_MESSAGE, |
| 467 | descpb.FieldDescriptorProto_TYPE_GROUP: |
| 468 | dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path)) |
| 469 | } |
| 470 | } else if fd.GetMessageType() != nil { |
| 471 | dm.SetField(fd, makeTemplate(fd.GetMessageType(), path)) |
| 472 | } |
| 473 | } |
| 474 | return dm |
| 475 | } |
| 476 | |
| 477 | // ClientTransportCredentials builds transport credentials for a gRPC client using the |
| 478 | // given properties. If cacertFile is blank, only standard trusted certs are used to |
| 479 | // verify the server certs. If clientCertFile is blank, the client will not use a client |
| 480 | // certificate. If clientCertFile is not blank then clientKeyFile must not be blank. |
| 481 | func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) { |
| 482 | var tlsConf tls.Config |
| 483 | |
| 484 | if clientCertFile != "" { |
| 485 | // Load the client certificates from disk |
| 486 | certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile) |
| 487 | if err != nil { |
| 488 | return nil, fmt.Errorf("could not load client key pair: %v", err) |
| 489 | } |
| 490 | tlsConf.Certificates = []tls.Certificate{certificate} |
| 491 | } |
| 492 | |
| 493 | if insecureSkipVerify { |
| 494 | tlsConf.InsecureSkipVerify = true |
| 495 | } else if cacertFile != "" { |
| 496 | // Create a certificate pool from the certificate authority |
| 497 | certPool := x509.NewCertPool() |
| 498 | ca, err := ioutil.ReadFile(cacertFile) |
| 499 | if err != nil { |
| 500 | return nil, fmt.Errorf("could not read ca certificate: %v", err) |
| 501 | } |
| 502 | |
| 503 | // Append the certificates from the CA |
| 504 | if ok := certPool.AppendCertsFromPEM(ca); !ok { |
| 505 | return nil, errors.New("failed to append ca certs") |
| 506 | } |
| 507 | |
| 508 | tlsConf.RootCAs = certPool |
| 509 | } |
| 510 | |
| 511 | return credentials.NewTLS(&tlsConf), nil |
| 512 | } |
| 513 | |
| 514 | // ServerTransportCredentials builds transport credentials for a gRPC server using the |
| 515 | // given properties. If cacertFile is blank, the server will not request client certs |
| 516 | // unless requireClientCerts is true. When requireClientCerts is false and cacertFile is |
| 517 | // not blank, the server will verify client certs when presented, but will not require |
| 518 | // client certs. The serverCertFile and serverKeyFile must both not be blank. |
| 519 | func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) { |
| 520 | var tlsConf tls.Config |
| 521 | // TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed |
| 522 | // in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests. |
| 523 | tlsConf.MaxVersion = tls.VersionTLS12 |
| 524 | |
| 525 | // Load the server certificates from disk |
| 526 | certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile) |
| 527 | if err != nil { |
| 528 | return nil, fmt.Errorf("could not load key pair: %v", err) |
| 529 | } |
| 530 | tlsConf.Certificates = []tls.Certificate{certificate} |
| 531 | |
| 532 | if cacertFile != "" { |
| 533 | // Create a certificate pool from the certificate authority |
| 534 | certPool := x509.NewCertPool() |
| 535 | ca, err := ioutil.ReadFile(cacertFile) |
| 536 | if err != nil { |
| 537 | return nil, fmt.Errorf("could not read ca certificate: %v", err) |
| 538 | } |
| 539 | |
| 540 | // Append the certificates from the CA |
| 541 | if ok := certPool.AppendCertsFromPEM(ca); !ok { |
| 542 | return nil, errors.New("failed to append ca certs") |
| 543 | } |
| 544 | |
| 545 | tlsConf.ClientCAs = certPool |
| 546 | } |
| 547 | |
| 548 | if requireClientCerts { |
| 549 | tlsConf.ClientAuth = tls.RequireAndVerifyClientCert |
| 550 | } else if cacertFile != "" { |
| 551 | tlsConf.ClientAuth = tls.VerifyClientCertIfGiven |
| 552 | } else { |
| 553 | tlsConf.ClientAuth = tls.NoClientCert |
| 554 | } |
| 555 | |
| 556 | return credentials.NewTLS(&tlsConf), nil |
| 557 | } |
| 558 | |
| 559 | // BlockingDial is a helper method to dial the given address, using optional TLS credentials, |
| 560 | // and blocking until the returned connection is ready. If the given credentials are nil, the |
| 561 | // connection will be insecure (plain-text). |
| 562 | func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { |
| 563 | // grpc.Dial doesn't provide any information on permanent connection errors (like |
| 564 | // TLS handshake failures). So in order to provide good error messages, we need a |
| 565 | // custom dialer that can provide that info. That means we manage the TLS handshake. |
| 566 | result := make(chan interface{}, 1) |
| 567 | |
| 568 | writeResult := func(res interface{}) { |
| 569 | // non-blocking write: we only need the first result |
| 570 | select { |
| 571 | case result <- res: |
| 572 | default: |
| 573 | } |
| 574 | } |
| 575 | |
| 576 | dialer := func(ctx context.Context, address string) (net.Conn, error) { |
| 577 | conn, err := (&net.Dialer{}).DialContext(ctx, network, address) |
| 578 | if err != nil { |
| 579 | writeResult(err) |
| 580 | return nil, err |
| 581 | } |
| 582 | if creds != nil { |
| 583 | conn, _, err = creds.ClientHandshake(ctx, address, conn) |
| 584 | if err != nil { |
| 585 | writeResult(err) |
| 586 | return nil, err |
| 587 | } |
| 588 | } |
| 589 | return conn, nil |
| 590 | } |
| 591 | |
| 592 | // Even with grpc.FailOnNonTempDialError, this call will usually timeout in |
| 593 | // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to |
| 594 | // know when we're done. So we run it in a goroutine and then use result |
| 595 | // channel to either get the channel or fail-fast. |
| 596 | go func() { |
| 597 | opts = append(opts, |
| 598 | grpc.WithBlock(), |
| 599 | grpc.FailOnNonTempDialError(true), |
| 600 | grpc.WithContextDialer(dialer), |
| 601 | grpc.WithInsecure(), // we are handling TLS, so tell grpc not to |
| 602 | ) |
| 603 | conn, err := grpc.DialContext(ctx, address, opts...) |
| 604 | var res interface{} |
| 605 | if err != nil { |
| 606 | res = err |
| 607 | } else { |
| 608 | res = conn |
| 609 | } |
| 610 | writeResult(res) |
| 611 | }() |
| 612 | |
| 613 | select { |
| 614 | case res := <-result: |
| 615 | if conn, ok := res.(*grpc.ClientConn); ok { |
| 616 | return conn, nil |
| 617 | } |
| 618 | return nil, res.(error) |
| 619 | case <-ctx.Done(): |
| 620 | return nil, ctx.Err() |
| 621 | } |
| 622 | } |