// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource based
// on the command-line parameters, and provides an InvocationEventHandler that supplies
// request data (which can come from command-line args or the process's stdin) and logs
// the events (to the process's stdout).
| 7 | package grpcurl |
| 8 | |
| 9 | import ( |
| 10 | "bytes" |
| 11 | "crypto/tls" |
| 12 | "crypto/x509" |
| 13 | "encoding/base64" |
| 14 | "errors" |
| 15 | "fmt" |
| 16 | "io/ioutil" |
| 17 | "net" |
Scott Baker | 4a35a70 | 2019-11-26 08:17:33 -0800 | [diff] [blame] | 18 | "os" |
| 19 | "regexp" |
Zack Williams | e940c7a | 2019-08-21 14:25:39 -0700 | [diff] [blame] | 20 | "sort" |
| 21 | "strings" |
| 22 | |
| 23 | "github.com/golang/protobuf/proto" |
| 24 | descpb "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| 25 | "github.com/golang/protobuf/ptypes" |
| 26 | "github.com/golang/protobuf/ptypes/empty" |
| 27 | "github.com/golang/protobuf/ptypes/struct" |
| 28 | "github.com/jhump/protoreflect/desc" |
| 29 | "github.com/jhump/protoreflect/desc/protoprint" |
| 30 | "github.com/jhump/protoreflect/dynamic" |
| 31 | "golang.org/x/net/context" |
| 32 | "google.golang.org/grpc" |
| 33 | "google.golang.org/grpc/credentials" |
| 34 | "google.golang.org/grpc/metadata" |
| 35 | ) |
| 36 | |
| 37 | // ListServices uses the given descriptor source to return a sorted list of fully-qualified |
| 38 | // service names. |
| 39 | func ListServices(source DescriptorSource) ([]string, error) { |
| 40 | svcs, err := source.ListServices() |
| 41 | if err != nil { |
| 42 | return nil, err |
| 43 | } |
| 44 | sort.Strings(svcs) |
| 45 | return svcs, nil |
| 46 | } |
| 47 | |
| 48 | type sourceWithFiles interface { |
| 49 | GetAllFiles() ([]*desc.FileDescriptor, error) |
| 50 | } |
| 51 | |
| 52 | var _ sourceWithFiles = (*fileSource)(nil) |
| 53 | |
| 54 | // GetAllFiles uses the given descriptor source to return a list of file descriptors. |
| 55 | func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) { |
| 56 | var files []*desc.FileDescriptor |
| 57 | srcFiles, ok := source.(sourceWithFiles) |
| 58 | |
| 59 | // If an error occurs, we still try to load as many files as we can, so that |
| 60 | // caller can decide whether to ignore error or not. |
| 61 | var firstError error |
| 62 | if ok { |
| 63 | files, firstError = srcFiles.GetAllFiles() |
| 64 | } else { |
| 65 | // Source does not implement GetAllFiles method, so use ListServices |
| 66 | // and grab files from there. |
| 67 | svcNames, err := source.ListServices() |
| 68 | if err != nil { |
| 69 | firstError = err |
| 70 | } else { |
| 71 | allFiles := map[string]*desc.FileDescriptor{} |
| 72 | for _, name := range svcNames { |
| 73 | d, err := source.FindSymbol(name) |
| 74 | if err != nil { |
| 75 | if firstError == nil { |
| 76 | firstError = err |
| 77 | } |
| 78 | } else { |
| 79 | addAllFilesToSet(d.GetFile(), allFiles) |
| 80 | } |
| 81 | } |
| 82 | files = make([]*desc.FileDescriptor, len(allFiles)) |
| 83 | i := 0 |
| 84 | for _, fd := range allFiles { |
| 85 | files[i] = fd |
| 86 | i++ |
| 87 | } |
| 88 | } |
| 89 | } |
| 90 | |
| 91 | sort.Sort(filesByName(files)) |
| 92 | return files, firstError |
| 93 | } |
| 94 | |
| 95 | type filesByName []*desc.FileDescriptor |
| 96 | |
| 97 | func (f filesByName) Len() int { |
| 98 | return len(f) |
| 99 | } |
| 100 | |
| 101 | func (f filesByName) Less(i, j int) bool { |
| 102 | return f[i].GetName() < f[j].GetName() |
| 103 | } |
| 104 | |
| 105 | func (f filesByName) Swap(i, j int) { |
| 106 | f[i], f[j] = f[j], f[i] |
| 107 | } |
| 108 | |
| 109 | func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) { |
| 110 | if _, ok := all[fd.GetName()]; ok { |
| 111 | // already added |
| 112 | return |
| 113 | } |
| 114 | all[fd.GetName()] = fd |
| 115 | for _, dep := range fd.GetDependencies() { |
| 116 | addAllFilesToSet(dep, all) |
| 117 | } |
| 118 | } |
| 119 | |
| 120 | // ListMethods uses the given descriptor source to return a sorted list of method names |
| 121 | // for the specified fully-qualified service name. |
| 122 | func ListMethods(source DescriptorSource, serviceName string) ([]string, error) { |
| 123 | dsc, err := source.FindSymbol(serviceName) |
| 124 | if err != nil { |
| 125 | return nil, err |
| 126 | } |
| 127 | if sd, ok := dsc.(*desc.ServiceDescriptor); !ok { |
| 128 | return nil, notFound("Service", serviceName) |
| 129 | } else { |
| 130 | methods := make([]string, 0, len(sd.GetMethods())) |
| 131 | for _, method := range sd.GetMethods() { |
| 132 | methods = append(methods, method.GetFullyQualifiedName()) |
| 133 | } |
| 134 | sort.Strings(methods) |
| 135 | return methods, nil |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | // MetadataFromHeaders converts a list of header strings (each string in |
| 140 | // "Header-Name: Header-Value" form) into metadata. If a string has a header |
| 141 | // name without a value (e.g. does not contain a colon), the value is assumed |
| 142 | // to be blank. Binary headers (those whose names end in "-bin") should be |
| 143 | // base64-encoded. But if they cannot be base64-decoded, they will be assumed to |
| 144 | // be in raw form and used as is. |
| 145 | func MetadataFromHeaders(headers []string) metadata.MD { |
| 146 | md := make(metadata.MD) |
| 147 | for _, part := range headers { |
| 148 | if part != "" { |
| 149 | pieces := strings.SplitN(part, ":", 2) |
| 150 | if len(pieces) == 1 { |
| 151 | pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) |
| 152 | } |
| 153 | headerName := strings.ToLower(strings.TrimSpace(pieces[0])) |
| 154 | val := strings.TrimSpace(pieces[1]) |
| 155 | if strings.HasSuffix(headerName, "-bin") { |
| 156 | if v, err := decode(val); err == nil { |
| 157 | val = v |
| 158 | } |
| 159 | } |
| 160 | md[headerName] = append(md[headerName], val) |
| 161 | } |
| 162 | } |
| 163 | return md |
| 164 | } |
| 165 | |
// envVarRegex matches environment variable references of the form ${NAME},
// where NAME is one or more word characters ([0-9A-Za-z_]).
var envVarRegex = regexp.MustCompile(`\${\w+}`)

// ExpandHeaders expands environment variable references (${NAME}) contained
// in each header string and returns the expanded headers in the same order.
// Empty header strings are passed through as empty strings.
//
// Expansion is single-pass. Text substituted from an environment variable is
// never itself scanned for further ${...} references, so variable values are
// inserted literally.
//
// If a header refers to a variable that is not set, an error naming the first
// such variable (left to right) is returned.
// TODO: Add escaping for `${`
func ExpandHeaders(headers []string) ([]string, error) {
	expandedHeaders := make([]string, len(headers))
	for idx, header := range headers {
		if header == "" {
			continue
		}

		// missing holds the first unset variable name; \w+ guarantees that a
		// real name is never empty, so "" means "none missing yet".
		var missing string
		expanded := envVarRegex.ReplaceAllStringFunc(header, func(ref string) string {
			if missing != "" {
				return ref
			}
			envVarName := ref[2 : len(ref)-1] // strip leading `${` and trailing `}`
			envVarValue, ok := os.LookupEnv(envVarName)
			if !ok {
				missing = envVarName
				return ref
			}
			return envVarValue
		})
		if missing != "" {
			return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, missing)
		}
		expandedHeaders[idx] = expanded
	}
	return expandedHeaders, nil
}
| 195 | |
// base64Codecs lists the base64 flavors decode tries, in order: padded
// standard, padded URL-safe, unpadded standard, unpadded URL-safe.
var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding}

// decode leniently base64-decodes val. It returns the output of the first
// codec in base64Codecs that accepts the input. If none does, it returns ""
// and the error reported by the first codec tried.
func decode(val string) (string, error) {
	var firstErr error
	for i := 0; i < len(base64Codecs); i++ {
		raw, err := base64Codecs[i].DecodeString(val)
		if err == nil {
			return string(raw), nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return "", firstErr
}
| 215 | |
| 216 | // MetadataToString returns a string representation of the given metadata, for |
| 217 | // displaying to users. |
| 218 | func MetadataToString(md metadata.MD) string { |
| 219 | if len(md) == 0 { |
| 220 | return "(empty)" |
| 221 | } |
| 222 | |
| 223 | keys := make([]string, 0, len(md)) |
| 224 | for k := range md { |
| 225 | keys = append(keys, k) |
| 226 | } |
| 227 | sort.Strings(keys) |
| 228 | |
| 229 | var b bytes.Buffer |
| 230 | first := true |
| 231 | for _, k := range keys { |
| 232 | vs := md[k] |
| 233 | for _, v := range vs { |
| 234 | if first { |
| 235 | first = false |
| 236 | } else { |
| 237 | b.WriteString("\n") |
| 238 | } |
| 239 | b.WriteString(k) |
| 240 | b.WriteString(": ") |
| 241 | if strings.HasSuffix(k, "-bin") { |
| 242 | v = base64.StdEncoding.EncodeToString([]byte(v)) |
| 243 | } |
| 244 | b.WriteString(v) |
| 245 | } |
| 246 | } |
| 247 | return b.String() |
| 248 | } |
| 249 | |
| 250 | var printer = &protoprint.Printer{ |
| 251 | Compact: true, |
| 252 | OmitComments: protoprint.CommentsNonDoc, |
| 253 | SortElements: true, |
| 254 | ForceFullyQualifiedNames: true, |
| 255 | } |
| 256 | |
| 257 | // GetDescriptorText returns a string representation of the given descriptor. |
| 258 | // This returns a snippet of proto source that describes the given element. |
| 259 | func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) { |
| 260 | // Note: DescriptorSource is not used, but remains an argument for backwards |
| 261 | // compatibility with previous implementation. |
| 262 | txt, err := printer.PrintProtoToString(dsc) |
| 263 | if err != nil { |
| 264 | return "", err |
| 265 | } |
| 266 | // callers don't expect trailing newlines |
| 267 | if txt[len(txt)-1] == '\n' { |
| 268 | txt = txt[:len(txt)-1] |
| 269 | } |
| 270 | return txt, nil |
| 271 | } |
| 272 | |
| 273 | // EnsureExtensions uses the given descriptor source to download extensions for |
| 274 | // the given message. It returns a copy of the given message, but as a dynamic |
| 275 | // message that knows about all extensions known to the given descriptor source. |
| 276 | func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message { |
| 277 | // load any server extensions so we can properly describe custom options |
| 278 | dsc, err := desc.LoadMessageDescriptorForMessage(msg) |
| 279 | if err != nil { |
| 280 | return msg |
| 281 | } |
| 282 | |
| 283 | var ext dynamic.ExtensionRegistry |
| 284 | if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil { |
| 285 | return msg |
| 286 | } |
| 287 | |
| 288 | // convert message into dynamic message that knows about applicable extensions |
| 289 | // (that way we can show meaningful info for custom options instead of printing as unknown) |
| 290 | msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) |
| 291 | dm, err := fullyConvertToDynamic(msgFactory, msg) |
| 292 | if err != nil { |
| 293 | return msg |
| 294 | } |
| 295 | return dm |
| 296 | } |
| 297 | |
| 298 | // fetchAllExtensions recursively fetches from the server extensions for the given message type as well as |
| 299 | // for all message types of nested fields. The extensions are added to the given dynamic registry of extensions |
| 300 | // so that all server-known extensions can be correctly parsed by grpcurl. |
| 301 | func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error { |
| 302 | msgTypeName := md.GetFullyQualifiedName() |
| 303 | if alreadyFetched[msgTypeName] { |
| 304 | return nil |
| 305 | } |
| 306 | alreadyFetched[msgTypeName] = true |
| 307 | if len(md.GetExtensionRanges()) > 0 { |
| 308 | fds, err := source.AllExtensionsForType(msgTypeName) |
| 309 | if err != nil { |
| 310 | return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err) |
| 311 | } |
| 312 | for _, fd := range fds { |
| 313 | if err := ext.AddExtension(fd); err != nil { |
| 314 | return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err) |
| 315 | } |
| 316 | } |
| 317 | } |
| 318 | // recursively fetch extensions for the types of any message fields |
| 319 | for _, fd := range md.GetFields() { |
| 320 | if fd.GetMessageType() != nil { |
| 321 | err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched) |
| 322 | if err != nil { |
| 323 | return err |
| 324 | } |
| 325 | } |
| 326 | } |
| 327 | return nil |
| 328 | } |
| 329 | |
| 330 | // fullConvertToDynamic attempts to convert the given message to a dynamic message as well |
| 331 | // as any nested messages it may contain as field values. If the given message factory has |
| 332 | // extensions registered that were not known when the given message was parsed, this effectively |
| 333 | // allows re-parsing to identify those extensions. |
| 334 | func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) { |
| 335 | if _, ok := msg.(*dynamic.Message); ok { |
| 336 | return msg, nil // already a dynamic message |
| 337 | } |
| 338 | md, err := desc.LoadMessageDescriptorForMessage(msg) |
| 339 | if err != nil { |
| 340 | return nil, err |
| 341 | } |
| 342 | newMsg := msgFact.NewMessage(md) |
| 343 | dm, ok := newMsg.(*dynamic.Message) |
| 344 | if !ok { |
| 345 | // if message factory didn't produce a dynamic message, then we should leave msg as is |
| 346 | return msg, nil |
| 347 | } |
| 348 | |
| 349 | if err := dm.ConvertFrom(msg); err != nil { |
| 350 | return nil, err |
| 351 | } |
| 352 | |
| 353 | // recursively convert all field values, too |
| 354 | for _, fd := range md.GetFields() { |
| 355 | if fd.IsMap() { |
| 356 | if fd.GetMapValueType().GetMessageType() != nil { |
| 357 | m := dm.GetField(fd).(map[interface{}]interface{}) |
| 358 | for k, v := range m { |
| 359 | // keys can't be nested messages; so we only need to recurse through map values, not keys |
| 360 | newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message)) |
| 361 | if err != nil { |
| 362 | return nil, err |
| 363 | } |
| 364 | dm.PutMapField(fd, k, newVal) |
| 365 | } |
| 366 | } |
| 367 | } else if fd.IsRepeated() { |
| 368 | if fd.GetMessageType() != nil { |
| 369 | s := dm.GetField(fd).([]interface{}) |
| 370 | for i, e := range s { |
| 371 | newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message)) |
| 372 | if err != nil { |
| 373 | return nil, err |
| 374 | } |
| 375 | dm.SetRepeatedField(fd, i, newVal) |
| 376 | } |
| 377 | } |
| 378 | } else { |
| 379 | if fd.GetMessageType() != nil { |
| 380 | v := dm.GetField(fd) |
| 381 | newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message)) |
| 382 | if err != nil { |
| 383 | return nil, err |
| 384 | } |
| 385 | dm.SetField(fd, newVal) |
| 386 | } |
| 387 | } |
| 388 | } |
| 389 | return dm, nil |
| 390 | } |
| 391 | |
| 392 | // MakeTemplate returns a message instance for the given descriptor that is a |
| 393 | // suitable template for creating an instance of that message in JSON. In |
| 394 | // particular, it ensures that any repeated fields (which include map fields) |
| 395 | // are not empty, so they will render with a single element (to show the types |
| 396 | // and optionally nested fields). It also ensures that nested messages are not |
| 397 | // nil by setting them to a message that is also fleshed out as a template |
| 398 | // message. |
| 399 | func MakeTemplate(md *desc.MessageDescriptor) proto.Message { |
| 400 | return makeTemplate(md, nil) |
| 401 | } |
| 402 | |
| 403 | func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message { |
| 404 | switch md.GetFullyQualifiedName() { |
| 405 | case "google.protobuf.Any": |
| 406 | // empty type URL is not allowed by JSON representation |
| 407 | // so we must give it a dummy type |
| 408 | msg, _ := ptypes.MarshalAny(&empty.Empty{}) |
| 409 | return msg |
| 410 | case "google.protobuf.Value": |
| 411 | // unset kind is not allowed by JSON representation |
| 412 | // so we must give it something |
| 413 | return &structpb.Value{ |
| 414 | Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{ |
| 415 | Fields: map[string]*structpb.Value{ |
| 416 | "google.protobuf.Value": {Kind: &structpb.Value_StringValue{ |
| 417 | StringValue: "supports arbitrary JSON", |
| 418 | }}, |
| 419 | }, |
| 420 | }}, |
| 421 | } |
| 422 | case "google.protobuf.ListValue": |
| 423 | return &structpb.ListValue{ |
| 424 | Values: []*structpb.Value{ |
| 425 | { |
| 426 | Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{ |
| 427 | Fields: map[string]*structpb.Value{ |
| 428 | "google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{ |
| 429 | StringValue: "is an array of arbitrary JSON values", |
| 430 | }}, |
| 431 | }, |
| 432 | }}, |
| 433 | }, |
| 434 | }, |
| 435 | } |
| 436 | case "google.protobuf.Struct": |
| 437 | return &structpb.Struct{ |
| 438 | Fields: map[string]*structpb.Value{ |
| 439 | "google.protobuf.Struct": {Kind: &structpb.Value_StringValue{ |
| 440 | StringValue: "supports arbitrary JSON objects", |
| 441 | }}, |
| 442 | }, |
| 443 | } |
| 444 | } |
| 445 | |
| 446 | dm := dynamic.NewMessage(md) |
| 447 | |
| 448 | // if the message is a recursive structure, we don't want to blow the stack |
| 449 | for _, seen := range path { |
| 450 | if seen == md { |
| 451 | // already visited this type; avoid infinite recursion |
| 452 | return dm |
| 453 | } |
| 454 | } |
| 455 | path = append(path, dm.GetMessageDescriptor()) |
| 456 | |
| 457 | // for repeated fields, add a single element with default value |
| 458 | // and for message fields, add a message with all default fields |
| 459 | // that also has non-nil message and non-empty repeated fields |
| 460 | |
| 461 | for _, fd := range dm.GetMessageDescriptor().GetFields() { |
| 462 | if fd.IsRepeated() { |
| 463 | switch fd.GetType() { |
| 464 | case descpb.FieldDescriptorProto_TYPE_FIXED32, |
| 465 | descpb.FieldDescriptorProto_TYPE_UINT32: |
| 466 | dm.AddRepeatedField(fd, uint32(0)) |
| 467 | |
| 468 | case descpb.FieldDescriptorProto_TYPE_SFIXED32, |
| 469 | descpb.FieldDescriptorProto_TYPE_SINT32, |
| 470 | descpb.FieldDescriptorProto_TYPE_INT32, |
| 471 | descpb.FieldDescriptorProto_TYPE_ENUM: |
| 472 | dm.AddRepeatedField(fd, int32(0)) |
| 473 | |
| 474 | case descpb.FieldDescriptorProto_TYPE_FIXED64, |
| 475 | descpb.FieldDescriptorProto_TYPE_UINT64: |
| 476 | dm.AddRepeatedField(fd, uint64(0)) |
| 477 | |
| 478 | case descpb.FieldDescriptorProto_TYPE_SFIXED64, |
| 479 | descpb.FieldDescriptorProto_TYPE_SINT64, |
| 480 | descpb.FieldDescriptorProto_TYPE_INT64: |
| 481 | dm.AddRepeatedField(fd, int64(0)) |
| 482 | |
| 483 | case descpb.FieldDescriptorProto_TYPE_STRING: |
| 484 | dm.AddRepeatedField(fd, "") |
| 485 | |
| 486 | case descpb.FieldDescriptorProto_TYPE_BYTES: |
| 487 | dm.AddRepeatedField(fd, []byte{}) |
| 488 | |
| 489 | case descpb.FieldDescriptorProto_TYPE_BOOL: |
| 490 | dm.AddRepeatedField(fd, false) |
| 491 | |
| 492 | case descpb.FieldDescriptorProto_TYPE_FLOAT: |
| 493 | dm.AddRepeatedField(fd, float32(0)) |
| 494 | |
| 495 | case descpb.FieldDescriptorProto_TYPE_DOUBLE: |
| 496 | dm.AddRepeatedField(fd, float64(0)) |
| 497 | |
| 498 | case descpb.FieldDescriptorProto_TYPE_MESSAGE, |
| 499 | descpb.FieldDescriptorProto_TYPE_GROUP: |
| 500 | dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path)) |
| 501 | } |
| 502 | } else if fd.GetMessageType() != nil { |
| 503 | dm.SetField(fd, makeTemplate(fd.GetMessageType(), path)) |
| 504 | } |
| 505 | } |
| 506 | return dm |
| 507 | } |
| 508 | |
| 509 | // ClientTransportCredentials builds transport credentials for a gRPC client using the |
| 510 | // given properties. If cacertFile is blank, only standard trusted certs are used to |
| 511 | // verify the server certs. If clientCertFile is blank, the client will not use a client |
| 512 | // certificate. If clientCertFile is not blank then clientKeyFile must not be blank. |
| 513 | func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) { |
| 514 | var tlsConf tls.Config |
| 515 | |
| 516 | if clientCertFile != "" { |
| 517 | // Load the client certificates from disk |
| 518 | certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile) |
| 519 | if err != nil { |
| 520 | return nil, fmt.Errorf("could not load client key pair: %v", err) |
| 521 | } |
| 522 | tlsConf.Certificates = []tls.Certificate{certificate} |
| 523 | } |
| 524 | |
| 525 | if insecureSkipVerify { |
| 526 | tlsConf.InsecureSkipVerify = true |
| 527 | } else if cacertFile != "" { |
| 528 | // Create a certificate pool from the certificate authority |
| 529 | certPool := x509.NewCertPool() |
| 530 | ca, err := ioutil.ReadFile(cacertFile) |
| 531 | if err != nil { |
| 532 | return nil, fmt.Errorf("could not read ca certificate: %v", err) |
| 533 | } |
| 534 | |
| 535 | // Append the certificates from the CA |
| 536 | if ok := certPool.AppendCertsFromPEM(ca); !ok { |
| 537 | return nil, errors.New("failed to append ca certs") |
| 538 | } |
| 539 | |
| 540 | tlsConf.RootCAs = certPool |
| 541 | } |
| 542 | |
| 543 | return credentials.NewTLS(&tlsConf), nil |
| 544 | } |
| 545 | |
| 546 | // ServerTransportCredentials builds transport credentials for a gRPC server using the |
| 547 | // given properties. If cacertFile is blank, the server will not request client certs |
| 548 | // unless requireClientCerts is true. When requireClientCerts is false and cacertFile is |
| 549 | // not blank, the server will verify client certs when presented, but will not require |
| 550 | // client certs. The serverCertFile and serverKeyFile must both not be blank. |
| 551 | func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) { |
| 552 | var tlsConf tls.Config |
| 553 | // TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed |
| 554 | // in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests. |
| 555 | tlsConf.MaxVersion = tls.VersionTLS12 |
| 556 | |
| 557 | // Load the server certificates from disk |
| 558 | certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile) |
| 559 | if err != nil { |
| 560 | return nil, fmt.Errorf("could not load key pair: %v", err) |
| 561 | } |
| 562 | tlsConf.Certificates = []tls.Certificate{certificate} |
| 563 | |
| 564 | if cacertFile != "" { |
| 565 | // Create a certificate pool from the certificate authority |
| 566 | certPool := x509.NewCertPool() |
| 567 | ca, err := ioutil.ReadFile(cacertFile) |
| 568 | if err != nil { |
| 569 | return nil, fmt.Errorf("could not read ca certificate: %v", err) |
| 570 | } |
| 571 | |
| 572 | // Append the certificates from the CA |
| 573 | if ok := certPool.AppendCertsFromPEM(ca); !ok { |
| 574 | return nil, errors.New("failed to append ca certs") |
| 575 | } |
| 576 | |
| 577 | tlsConf.ClientCAs = certPool |
| 578 | } |
| 579 | |
| 580 | if requireClientCerts { |
| 581 | tlsConf.ClientAuth = tls.RequireAndVerifyClientCert |
| 582 | } else if cacertFile != "" { |
| 583 | tlsConf.ClientAuth = tls.VerifyClientCertIfGiven |
| 584 | } else { |
| 585 | tlsConf.ClientAuth = tls.NoClientCert |
| 586 | } |
| 587 | |
| 588 | return credentials.NewTLS(&tlsConf), nil |
| 589 | } |
| 590 | |
| 591 | // BlockingDial is a helper method to dial the given address, using optional TLS credentials, |
| 592 | // and blocking until the returned connection is ready. If the given credentials are nil, the |
| 593 | // connection will be insecure (plain-text). |
| 594 | func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { |
| 595 | // grpc.Dial doesn't provide any information on permanent connection errors (like |
| 596 | // TLS handshake failures). So in order to provide good error messages, we need a |
| 597 | // custom dialer that can provide that info. That means we manage the TLS handshake. |
| 598 | result := make(chan interface{}, 1) |
| 599 | |
| 600 | writeResult := func(res interface{}) { |
| 601 | // non-blocking write: we only need the first result |
| 602 | select { |
| 603 | case result <- res: |
| 604 | default: |
| 605 | } |
| 606 | } |
| 607 | |
| 608 | dialer := func(ctx context.Context, address string) (net.Conn, error) { |
| 609 | conn, err := (&net.Dialer{}).DialContext(ctx, network, address) |
| 610 | if err != nil { |
| 611 | writeResult(err) |
| 612 | return nil, err |
| 613 | } |
| 614 | if creds != nil { |
| 615 | conn, _, err = creds.ClientHandshake(ctx, address, conn) |
| 616 | if err != nil { |
| 617 | writeResult(err) |
| 618 | return nil, err |
| 619 | } |
| 620 | } |
| 621 | return conn, nil |
| 622 | } |
| 623 | |
| 624 | // Even with grpc.FailOnNonTempDialError, this call will usually timeout in |
| 625 | // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to |
| 626 | // know when we're done. So we run it in a goroutine and then use result |
| 627 | // channel to either get the channel or fail-fast. |
| 628 | go func() { |
| 629 | opts = append(opts, |
| 630 | grpc.WithBlock(), |
| 631 | grpc.FailOnNonTempDialError(true), |
| 632 | grpc.WithContextDialer(dialer), |
| 633 | grpc.WithInsecure(), // we are handling TLS, so tell grpc not to |
| 634 | ) |
| 635 | conn, err := grpc.DialContext(ctx, address, opts...) |
| 636 | var res interface{} |
| 637 | if err != nil { |
| 638 | res = err |
| 639 | } else { |
| 640 | res = conn |
| 641 | } |
| 642 | writeResult(res) |
| 643 | }() |
| 644 | |
| 645 | select { |
| 646 | case res := <-result: |
| 647 | if conn, ok := res.(*grpc.ClientConn); ok { |
| 648 | return conn, nil |
| 649 | } |
| 650 | return nil, res.(error) |
| 651 | case <-ctx.Done(): |
| 652 | return nil, ctx.Err() |
| 653 | } |
| 654 | } |