blob: 3fca3eb0f0dfc563a4d9f4108d2a0ea2ab6709ea [file] [log] [blame]
khenaidoof3333552021-12-15 16:52:31 -05001package grpcreflect
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "reflect"
8 "runtime"
9 "sync"
10
11 "github.com/golang/protobuf/proto"
12 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
13 "golang.org/x/net/context"
14 "google.golang.org/grpc/codes"
15 rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
16 "google.golang.org/grpc/status"
17
18 "github.com/jhump/protoreflect/desc"
19 "github.com/jhump/protoreflect/internal"
20)
21
// elementNotFoundError is the error returned by reflective operations where the
// server does not recognize a given file name, symbol name, or extension.
type elementNotFoundError struct {
	// name is the file name, the symbol name, or (for extensions) the name
	// of the extended message, depending on kind.
	name string
	// kind selects which of the three message formats Error() produces.
	kind    elementKind
	symType symbolType // only used when kind == elementKindSymbol
	tag     int32      // only used when kind == elementKindExtension

	// cause is optional. When set, it means the named element could not be
	// resolved because of another element (such as a dependency) that could
	// not be found; cause describes that missing element. Error() renders
	// the whole chain, one "caused by" line per link.
	cause *elementNotFoundError
}
35
// elementKind identifies what sort of element an elementNotFoundError refers
// to.
type elementKind int

const (
	// elementKindSymbol is the zero value: a fully-qualified symbol name.
	elementKindSymbol elementKind = iota
	// elementKindFile refers to a proto file name.
	elementKindFile
	// elementKindExtension refers to an extension, identified by extended
	// message name and tag number.
	elementKindExtension
)
43
// symbolType is the human-readable category of a symbol. It is used as the
// leading word of "<type> not found: <name>" error messages.
type symbolType string

// The constants are explicitly typed as symbolType so that they cannot be
// silently used as (or mixed up with) plain strings.
const (
	symbolTypeService symbolType = "Service"
	symbolTypeMessage symbolType = "Message"
	symbolTypeEnum    symbolType = "Enum"
	// symbolTypeUnknown is used when the category of the requested symbol is
	// not (yet) known; setSymbolType may later replace it with a specific one.
	symbolTypeUnknown symbolType = "Symbol"
)
52
53func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error {
54 return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause}
55}
56
57func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error {
58 return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause}
59}
60
61func fileNotFound(file string, cause *elementNotFoundError) error {
62 return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause}
63}
64
65func (e *elementNotFoundError) Error() string {
66 first := true
67 var b bytes.Buffer
68 for ; e != nil; e = e.cause {
69 if first {
70 first = false
71 } else {
72 fmt.Fprint(&b, "\ncaused by: ")
73 }
74 switch e.kind {
75 case elementKindSymbol:
76 fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name)
77 case elementKindExtension:
78 fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name)
79 default:
80 fmt.Fprintf(&b, "File not found: %s", e.name)
81 }
82 }
83 return b.String()
84}
85
86// IsElementNotFoundError determines if the given error indicates that a file
87// name, symbol name, or extension field was could not be found by the server.
88func IsElementNotFoundError(err error) bool {
89 _, ok := err.(*elementNotFoundError)
90 return ok
91}
92
// ProtocolError is an error returned when the server sends a response of the
// wrong type.
type ProtocolError struct {
	// missingType is the response type that was expected but absent.
	missingType reflect.Type
}

// Error implements the error interface, naming the response type that was
// missing from the server's reply.
func (pe ProtocolError) Error() string {
	const format = "Protocol error: response was missing %v"
	return fmt.Sprintf(format, pe.missingType)
}
102
// extDesc identifies an extension by the fully-qualified name of the message
// it extends and its field number. It is the key type of the client's
// extension cache.
type extDesc struct {
	extendedMessageName string
	extensionNumber     int32
}
107
// Client is a client connection to a server for performing reflection calls
// and resolving remote symbols.
type Client struct {
	// ctx is the root context; each reflection stream runs under a
	// cancellable context derived from it.
	ctx context.Context
	// stub is the RPC stub used to open reflection streams.
	stub rpb.ServerReflectionClient

	// connMu is held while the stream is created, used, or torn down, so
	// that each request is paired with its own response.
	connMu sync.Mutex
	// cancel cancels the context of the current stream; nil when unset.
	cancel context.CancelFunc
	// stream is the active reflection stream; nil until the first request
	// and again after a reset.
	stream rpb.ServerReflection_ServerReflectionInfoClient

	// cacheMu guards the four caches below.
	cacheMu sync.RWMutex
	// protosByName holds raw file descriptor protos received from the
	// server, keyed by file name.
	protosByName map[string]*dpb.FileDescriptorProto
	// filesByName holds fully built file descriptors, keyed by file name.
	filesByName map[string]*desc.FileDescriptor
	// filesBySymbol maps the fully-qualified name of every cached element to
	// the file that declares it.
	filesBySymbol map[string]*desc.FileDescriptor
	// filesByExtension maps (extended message, field number) to the file
	// that declares that extension.
	filesByExtension map[extDesc]*desc.FileDescriptor
}
124
125// NewClient creates a new Client with the given root context and using the
126// given RPC stub for talking to the server.
127func NewClient(ctx context.Context, stub rpb.ServerReflectionClient) *Client {
128 cr := &Client{
129 ctx: ctx,
130 stub: stub,
131 protosByName: map[string]*dpb.FileDescriptorProto{},
132 filesByName: map[string]*desc.FileDescriptor{},
133 filesBySymbol: map[string]*desc.FileDescriptor{},
134 filesByExtension: map[extDesc]*desc.FileDescriptor{},
135 }
136 // don't leak a grpc stream
137 runtime.SetFinalizer(cr, (*Client).Reset)
138 return cr
139}
140
141// FileByFilename asks the server for a file descriptor for the proto file with
142// the given name.
143func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) {
144 // hit the cache first
145 cr.cacheMu.RLock()
146 if fd, ok := cr.filesByName[filename]; ok {
147 cr.cacheMu.RUnlock()
148 return fd, nil
149 }
150 fdp, ok := cr.protosByName[filename]
151 cr.cacheMu.RUnlock()
152 // not there? see if we've downloaded the proto
153 if ok {
154 return cr.descriptorFromProto(fdp)
155 }
156
157 req := &rpb.ServerReflectionRequest{
158 MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
159 FileByFilename: filename,
160 },
161 }
162 fd, err := cr.getAndCacheFileDescriptors(req, filename, "")
163 if isNotFound(err) {
164 // file not found? see if we can look up via alternate name
165 if alternate, ok := internal.StdFileAliases[filename]; ok {
166 req := &rpb.ServerReflectionRequest{
167 MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
168 FileByFilename: alternate,
169 },
170 }
171 fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename)
172 if isNotFound(err) {
173 err = fileNotFound(filename, nil)
174 }
175 } else {
176 err = fileNotFound(filename, nil)
177 }
178 } else if e, ok := err.(*elementNotFoundError); ok {
179 err = fileNotFound(filename, e)
180 }
181 return fd, err
182}
183
184// FileContainingSymbol asks the server for a file descriptor for the proto file
185// that declares the given fully-qualified symbol.
186func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) {
187 // hit the cache first
188 cr.cacheMu.RLock()
189 fd, ok := cr.filesBySymbol[symbol]
190 cr.cacheMu.RUnlock()
191 if ok {
192 return fd, nil
193 }
194
195 req := &rpb.ServerReflectionRequest{
196 MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
197 FileContainingSymbol: symbol,
198 },
199 }
200 fd, err := cr.getAndCacheFileDescriptors(req, "", "")
201 if isNotFound(err) {
202 err = symbolNotFound(symbol, symbolTypeUnknown, nil)
203 } else if e, ok := err.(*elementNotFoundError); ok {
204 err = symbolNotFound(symbol, symbolTypeUnknown, e)
205 }
206 return fd, err
207}
208
209// FileContainingExtension asks the server for a file descriptor for the proto
210// file that declares an extension with the given number for the given
211// fully-qualified message name.
212func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) {
213 // hit the cache first
214 cr.cacheMu.RLock()
215 fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}]
216 cr.cacheMu.RUnlock()
217 if ok {
218 return fd, nil
219 }
220
221 req := &rpb.ServerReflectionRequest{
222 MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{
223 FileContainingExtension: &rpb.ExtensionRequest{
224 ContainingType: extendedMessageName,
225 ExtensionNumber: extensionNumber,
226 },
227 },
228 }
229 fd, err := cr.getAndCacheFileDescriptors(req, "", "")
230 if isNotFound(err) {
231 err = extensionNotFound(extendedMessageName, extensionNumber, nil)
232 } else if e, ok := err.(*elementNotFoundError); ok {
233 err = extensionNotFound(extendedMessageName, extensionNumber, e)
234 }
235 return fd, err
236}
237
// getAndCacheFileDescriptors sends req, stores every file descriptor proto in
// the response in the raw proto cache, and returns a built descriptor for the
// first file in the response.
//
// expectedName and alias are used together for aliased lookups: when both are
// non-empty and differ, a returned file named expectedName is renamed to alias
// before it is cached. Callers that are not looking up by name pass "" for
// both.
//
// A *ProtocolError is returned if the response is not a file descriptor
// response or contains no files.
func (cr *Client) getAndCacheFileDescriptors(req *rpb.ServerReflectionRequest, expectedName, alias string) (*desc.FileDescriptor, error) {
	resp, err := cr.send(req)
	if err != nil {
		return nil, err
	}

	fdResp := resp.GetFileDescriptorResponse()
	if fdResp == nil {
		// fdResp is a typed nil pointer, so its type is still available to
		// name the missing response in the error
		return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()}
	}

	// Response can contain the result file descriptor, but also its transitive
	// deps. Furthermore, protocol states that subsequent requests do not need
	// to send transitive deps that have been sent in prior responses. So we
	// need to cache all file descriptors and then return the first one (which
	// should be the answer).
	// NOTE(review): the code below always takes the first file in the
	// response; it does not select a file by expectedName.
	var firstFd *dpb.FileDescriptorProto
	for _, fdBytes := range fdResp.FileDescriptorProto {
		fd := &dpb.FileDescriptorProto{}
		if err = proto.Unmarshal(fdBytes, fd); err != nil {
			return nil, err
		}

		if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName {
			// we found a file that was aliased, so we need to update the proto to reflect that
			fd.Name = proto.String(alias)
		}

		cr.cacheMu.Lock()
		// see if this file was created and cached concurrently
		// NOTE(review): this check only applies to the first file, and the
		// early return skips caching the remaining protos in this response.
		if firstFd == nil {
			if d, ok := cr.filesByName[fd.GetName()]; ok {
				cr.cacheMu.Unlock()
				return d, nil
			}
		}
		// store in cache of raw descriptor protos, but don't overwrite existing protos
		if existingFd, ok := cr.protosByName[fd.GetName()]; ok {
			fd = existingFd
		} else {
			cr.protosByName[fd.GetName()] = fd
		}
		cr.cacheMu.Unlock()
		if firstFd == nil {
			firstFd = fd
		}
	}
	if firstFd == nil {
		// the response carried no files at all
		return nil, &ProtocolError{reflect.TypeOf(firstFd).Elem()}
	}

	return cr.descriptorFromProto(firstFd)
}
293
294func (cr *Client) descriptorFromProto(fd *dpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
295 deps := make([]*desc.FileDescriptor, len(fd.GetDependency()))
296 for i, depName := range fd.GetDependency() {
297 if dep, err := cr.FileByFilename(depName); err != nil {
298 return nil, err
299 } else {
300 deps[i] = dep
301 }
302 }
303 d, err := desc.CreateFileDescriptor(fd, deps...)
304 if err != nil {
305 return nil, err
306 }
307 d = cr.cacheFile(d)
308 return d, nil
309}
310
311func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor {
312 cr.cacheMu.Lock()
313 defer cr.cacheMu.Unlock()
314
315 // cache file descriptor by name, but don't overwrite existing entry
316 // (existing entry could come from concurrent caller)
317 if existingFd, ok := cr.filesByName[fd.GetName()]; ok {
318 return existingFd
319 }
320 cr.filesByName[fd.GetName()] = fd
321
322 // also cache by symbols and extensions
323 for _, m := range fd.GetMessageTypes() {
324 cr.cacheMessageLocked(fd, m)
325 }
326 for _, e := range fd.GetEnumTypes() {
327 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
328 for _, v := range e.GetValues() {
329 cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
330 }
331 }
332 for _, e := range fd.GetExtensions() {
333 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
334 cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
335 }
336 for _, s := range fd.GetServices() {
337 cr.filesBySymbol[s.GetFullyQualifiedName()] = fd
338 for _, m := range s.GetMethods() {
339 cr.filesBySymbol[m.GetFullyQualifiedName()] = fd
340 }
341 }
342
343 return fd
344}
345
346func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) {
347 cr.filesBySymbol[md.GetFullyQualifiedName()] = fd
348 for _, f := range md.GetFields() {
349 cr.filesBySymbol[f.GetFullyQualifiedName()] = fd
350 }
351 for _, o := range md.GetOneOfs() {
352 cr.filesBySymbol[o.GetFullyQualifiedName()] = fd
353 }
354 for _, e := range md.GetNestedEnumTypes() {
355 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
356 for _, v := range e.GetValues() {
357 cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
358 }
359 }
360 for _, e := range md.GetNestedExtensions() {
361 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
362 cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
363 }
364 for _, m := range md.GetNestedMessageTypes() {
365 cr.cacheMessageLocked(fd, m) // recurse
366 }
367}
368
369// AllExtensionNumbersForType asks the server for all known extension numbers
370// for the given fully-qualified message name.
371func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) {
372 req := &rpb.ServerReflectionRequest{
373 MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{
374 AllExtensionNumbersOfType: extendedMessageName,
375 },
376 }
377 resp, err := cr.send(req)
378 if err != nil {
379 if isNotFound(err) {
380 return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil)
381 }
382 return nil, err
383 }
384
385 extResp := resp.GetAllExtensionNumbersResponse()
386 if extResp == nil {
387 return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()}
388 }
389 return extResp.ExtensionNumber, nil
390}
391
392// ListServices asks the server for the fully-qualified names of all exposed
393// services.
394func (cr *Client) ListServices() ([]string, error) {
395 req := &rpb.ServerReflectionRequest{
396 MessageRequest: &rpb.ServerReflectionRequest_ListServices{
397 // proto doesn't indicate any purpose for this value and server impl
398 // doesn't actually use it...
399 ListServices: "*",
400 },
401 }
402 resp, err := cr.send(req)
403 if err != nil {
404 return nil, err
405 }
406
407 listResp := resp.GetListServicesResponse()
408 if listResp == nil {
409 return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()}
410 }
411 serviceNames := make([]string, len(listResp.Service))
412 for i, s := range listResp.Service {
413 serviceNames[i] = s.Name
414 }
415 return serviceNames, nil
416}
417
418func (cr *Client) send(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
419 // we allow one immediate retry, in case we have a stale stream
420 // (e.g. closed by server)
421 resp, err := cr.doSend(true, req)
422 if err != nil {
423 return nil, err
424 }
425
426 // convert error response messages into errors
427 errResp := resp.GetErrorResponse()
428 if errResp != nil {
429 return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage)
430 }
431
432 return resp, nil
433}
434
435func isNotFound(err error) bool {
436 if err == nil {
437 return false
438 }
439 s, ok := status.FromError(err)
440 return ok && s.Code() == codes.NotFound
441}
442
// doSend sends req and waits for its response while holding cr.connMu, so
// that only one request/response exchange is in flight at a time. If retry is
// true, a failed exchange is attempted once more on a fresh stream.
func (cr *Client) doSend(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
	// TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery
	// (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus
	// delivered in correct order.
	cr.connMu.Lock()
	defer cr.connMu.Unlock()
	return cr.doSendLocked(retry, req)
}
451
452func (cr *Client) doSendLocked(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
453 if err := cr.initStreamLocked(); err != nil {
454 return nil, err
455 }
456
457 if err := cr.stream.Send(req); err != nil {
458 if err == io.EOF {
459 // if send returns EOF, must call Recv to get real underlying error
460 _, err = cr.stream.Recv()
461 }
462 cr.resetLocked()
463 if retry {
464 return cr.doSendLocked(false, req)
465 }
466 return nil, err
467 }
468
469 if resp, err := cr.stream.Recv(); err != nil {
470 cr.resetLocked()
471 if retry {
472 return cr.doSendLocked(false, req)
473 }
474 return nil, err
475 } else {
476 return resp, nil
477 }
478}
479
480func (cr *Client) initStreamLocked() error {
481 if cr.stream != nil {
482 return nil
483 }
484 var newCtx context.Context
485 newCtx, cr.cancel = context.WithCancel(cr.ctx)
486 var err error
487 cr.stream, err = cr.stub.ServerReflectionInfo(newCtx)
488 return err
489}
490
// Reset ensures that any active stream with the server is closed, releasing any
// resources. It takes cr.connMu, so it waits for an in-flight request to
// finish. NewClient also registers it as the client's finalizer.
func (cr *Client) Reset() {
	cr.connMu.Lock()
	defer cr.connMu.Unlock()
	cr.resetLocked()
}
498
499func (cr *Client) resetLocked() {
500 if cr.stream != nil {
501 cr.stream.CloseSend()
502 for {
503 // drain the stream, this covers io.EOF too
504 if _, err := cr.stream.Recv(); err != nil {
505 break
506 }
507 }
508 cr.stream = nil
509 }
510 if cr.cancel != nil {
511 cr.cancel()
512 cr.cancel = nil
513 }
514}
515
516// ResolveService asks the server to resolve the given fully-qualified service
517// name into a service descriptor.
518func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) {
519 file, err := cr.FileContainingSymbol(serviceName)
520 if err != nil {
521 return nil, setSymbolType(err, serviceName, symbolTypeService)
522 }
523 d := file.FindSymbol(serviceName)
524 if d == nil {
525 return nil, symbolNotFound(serviceName, symbolTypeService, nil)
526 }
527 if s, ok := d.(*desc.ServiceDescriptor); ok {
528 return s, nil
529 } else {
530 return nil, symbolNotFound(serviceName, symbolTypeService, nil)
531 }
532}
533
534// ResolveMessage asks the server to resolve the given fully-qualified message
535// name into a message descriptor.
536func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) {
537 file, err := cr.FileContainingSymbol(messageName)
538 if err != nil {
539 return nil, setSymbolType(err, messageName, symbolTypeMessage)
540 }
541 d := file.FindSymbol(messageName)
542 if d == nil {
543 return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
544 }
545 if s, ok := d.(*desc.MessageDescriptor); ok {
546 return s, nil
547 } else {
548 return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
549 }
550}
551
552// ResolveEnum asks the server to resolve the given fully-qualified enum name
553// into an enum descriptor.
554func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) {
555 file, err := cr.FileContainingSymbol(enumName)
556 if err != nil {
557 return nil, setSymbolType(err, enumName, symbolTypeEnum)
558 }
559 d := file.FindSymbol(enumName)
560 if d == nil {
561 return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
562 }
563 if s, ok := d.(*desc.EnumDescriptor); ok {
564 return s, nil
565 } else {
566 return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
567 }
568}
569
570func setSymbolType(err error, name string, symType symbolType) error {
571 if e, ok := err.(*elementNotFoundError); ok {
572 if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown {
573 e.symType = symType
574 }
575 }
576 return err
577}
578
579// ResolveEnumValues asks the server to resolve the given fully-qualified enum
580// name into a map of names to numbers that represents the enum's values.
581func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) {
582 enumDesc, err := cr.ResolveEnum(enumName)
583 if err != nil {
584 return nil, err
585 }
586 vals := map[string]int32{}
587 for _, valDesc := range enumDesc.GetValues() {
588 vals[valDesc.GetName()] = valDesc.GetNumber()
589 }
590 return vals, nil
591}
592
593// ResolveExtension asks the server to resolve the given extension number and
594// fully-qualified message name into a field descriptor.
595func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) {
596 file, err := cr.FileContainingExtension(extendedType, extensionNumber)
597 if err != nil {
598 return nil, err
599 }
600 d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file})
601 if d == nil {
602 return nil, extensionNotFound(extendedType, extensionNumber, nil)
603 } else {
604 return d, nil
605 }
606}
607
608func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor {
609 // search extensions in this scope
610 for _, ext := range scope.extensions() {
611 if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType {
612 return ext
613 }
614 }
615
616 // if not found, search nested scopes
617 for _, nested := range scope.nestedScopes() {
618 ext := findExtension(extendedType, extensionNumber, nested)
619 if ext != nil {
620 return ext
621 }
622 }
623
624 return nil
625}
626
// extensionScope is a place where extensions can be declared. findExtension
// walks it recursively to locate an extension.
type extensionScope interface {
	// extensions returns the extensions declared directly in this scope.
	extensions() []*desc.FieldDescriptor
	// nestedScopes returns the scopes nested directly inside this one.
	nestedScopes() []extensionScope
}
631
632// fileDescriptorExtensions implements extensionHolder interface on top of
633// FileDescriptorProto
634type fileDescriptorExtensions struct {
635 proto *desc.FileDescriptor
636}
637
638func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor {
639 return fde.proto.GetExtensions()
640}
641
642func (fde fileDescriptorExtensions) nestedScopes() []extensionScope {
643 scopes := make([]extensionScope, len(fde.proto.GetMessageTypes()))
644 for i, m := range fde.proto.GetMessageTypes() {
645 scopes[i] = msgDescriptorExtensions{m}
646 }
647 return scopes
648}
649
650// msgDescriptorExtensions implements extensionHolder interface on top of
651// DescriptorProto
652type msgDescriptorExtensions struct {
653 proto *desc.MessageDescriptor
654}
655
656func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor {
657 return mde.proto.GetNestedExtensions()
658}
659
660func (mde msgDescriptorExtensions) nestedScopes() []extensionScope {
661 scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes()))
662 for i, m := range mde.proto.GetNestedMessageTypes() {
663 scopes[i] = msgDescriptorExtensions{m}
664 }
665 return scopes
666}