| package grpcreflect |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "reflect" |
| "runtime" |
| "sync" |
| |
| "github.com/golang/protobuf/proto" |
| dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| "golang.org/x/net/context" |
| "google.golang.org/grpc/codes" |
| rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" |
| "google.golang.org/grpc/status" |
| |
| "github.com/jhump/protoreflect/desc" |
| "github.com/jhump/protoreflect/internal" |
| ) |
| |
| // elementNotFoundError is the error returned by reflective operations where the |
| // server does not recognize a given file name, symbol name, or extension. |
| type elementNotFoundError struct { |
| name string |
| kind elementKind |
| symType symbolType // only used when kind == elementKindSymbol |
| tag int32 // only used when kind == elementKindExtension |
| |
| // only errors with a kind of elementKindFile will have a cause, which means |
| // the named file count not be resolved because of a dependency that could |
| // not be found where cause describes the missing dependency |
| cause *elementNotFoundError |
| } |
| |
| type elementKind int |
| |
| const ( |
| elementKindSymbol elementKind = iota |
| elementKindFile |
| elementKindExtension |
| ) |
| |
| type symbolType string |
| |
| const ( |
| symbolTypeService = "Service" |
| symbolTypeMessage = "Message" |
| symbolTypeEnum = "Enum" |
| symbolTypeUnknown = "Symbol" |
| ) |
| |
| func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error { |
| return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause} |
| } |
| |
| func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error { |
| return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause} |
| } |
| |
| func fileNotFound(file string, cause *elementNotFoundError) error { |
| return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause} |
| } |
| |
| func (e *elementNotFoundError) Error() string { |
| first := true |
| var b bytes.Buffer |
| for ; e != nil; e = e.cause { |
| if first { |
| first = false |
| } else { |
| fmt.Fprint(&b, "\ncaused by: ") |
| } |
| switch e.kind { |
| case elementKindSymbol: |
| fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name) |
| case elementKindExtension: |
| fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name) |
| default: |
| fmt.Fprintf(&b, "File not found: %s", e.name) |
| } |
| } |
| return b.String() |
| } |
| |
| // IsElementNotFoundError determines if the given error indicates that a file |
| // name, symbol name, or extension field was could not be found by the server. |
| func IsElementNotFoundError(err error) bool { |
| _, ok := err.(*elementNotFoundError) |
| return ok |
| } |
| |
| // ProtocolError is an error returned when the server sends a response of the |
| // wrong type. |
| type ProtocolError struct { |
| missingType reflect.Type |
| } |
| |
| func (p ProtocolError) Error() string { |
| return fmt.Sprintf("Protocol error: response was missing %v", p.missingType) |
| } |
| |
| type extDesc struct { |
| extendedMessageName string |
| extensionNumber int32 |
| } |
| |
| // Client is a client connection to a server for performing reflection calls |
| // and resolving remote symbols. |
| type Client struct { |
| ctx context.Context |
| stub rpb.ServerReflectionClient |
| |
| connMu sync.Mutex |
| cancel context.CancelFunc |
| stream rpb.ServerReflection_ServerReflectionInfoClient |
| |
| cacheMu sync.RWMutex |
| protosByName map[string]*dpb.FileDescriptorProto |
| filesByName map[string]*desc.FileDescriptor |
| filesBySymbol map[string]*desc.FileDescriptor |
| filesByExtension map[extDesc]*desc.FileDescriptor |
| } |
| |
| // NewClient creates a new Client with the given root context and using the |
| // given RPC stub for talking to the server. |
| func NewClient(ctx context.Context, stub rpb.ServerReflectionClient) *Client { |
| cr := &Client{ |
| ctx: ctx, |
| stub: stub, |
| protosByName: map[string]*dpb.FileDescriptorProto{}, |
| filesByName: map[string]*desc.FileDescriptor{}, |
| filesBySymbol: map[string]*desc.FileDescriptor{}, |
| filesByExtension: map[extDesc]*desc.FileDescriptor{}, |
| } |
| // don't leak a grpc stream |
| runtime.SetFinalizer(cr, (*Client).Reset) |
| return cr |
| } |
| |
| // FileByFilename asks the server for a file descriptor for the proto file with |
| // the given name. |
| func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) { |
| // hit the cache first |
| cr.cacheMu.RLock() |
| if fd, ok := cr.filesByName[filename]; ok { |
| cr.cacheMu.RUnlock() |
| return fd, nil |
| } |
| fdp, ok := cr.protosByName[filename] |
| cr.cacheMu.RUnlock() |
| // not there? see if we've downloaded the proto |
| if ok { |
| return cr.descriptorFromProto(fdp) |
| } |
| |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ |
| FileByFilename: filename, |
| }, |
| } |
| fd, err := cr.getAndCacheFileDescriptors(req, filename, "") |
| if isNotFound(err) { |
| // file not found? see if we can look up via alternate name |
| if alternate, ok := internal.StdFileAliases[filename]; ok { |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ |
| FileByFilename: alternate, |
| }, |
| } |
| fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename) |
| if isNotFound(err) { |
| err = fileNotFound(filename, nil) |
| } |
| } else { |
| err = fileNotFound(filename, nil) |
| } |
| } else if e, ok := err.(*elementNotFoundError); ok { |
| err = fileNotFound(filename, e) |
| } |
| return fd, err |
| } |
| |
| // FileContainingSymbol asks the server for a file descriptor for the proto file |
| // that declares the given fully-qualified symbol. |
| func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) { |
| // hit the cache first |
| cr.cacheMu.RLock() |
| fd, ok := cr.filesBySymbol[symbol] |
| cr.cacheMu.RUnlock() |
| if ok { |
| return fd, nil |
| } |
| |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{ |
| FileContainingSymbol: symbol, |
| }, |
| } |
| fd, err := cr.getAndCacheFileDescriptors(req, "", "") |
| if isNotFound(err) { |
| err = symbolNotFound(symbol, symbolTypeUnknown, nil) |
| } else if e, ok := err.(*elementNotFoundError); ok { |
| err = symbolNotFound(symbol, symbolTypeUnknown, e) |
| } |
| return fd, err |
| } |
| |
| // FileContainingExtension asks the server for a file descriptor for the proto |
| // file that declares an extension with the given number for the given |
| // fully-qualified message name. |
| func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) { |
| // hit the cache first |
| cr.cacheMu.RLock() |
| fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}] |
| cr.cacheMu.RUnlock() |
| if ok { |
| return fd, nil |
| } |
| |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{ |
| FileContainingExtension: &rpb.ExtensionRequest{ |
| ContainingType: extendedMessageName, |
| ExtensionNumber: extensionNumber, |
| }, |
| }, |
| } |
| fd, err := cr.getAndCacheFileDescriptors(req, "", "") |
| if isNotFound(err) { |
| err = extensionNotFound(extendedMessageName, extensionNumber, nil) |
| } else if e, ok := err.(*elementNotFoundError); ok { |
| err = extensionNotFound(extendedMessageName, extensionNumber, e) |
| } |
| return fd, err |
| } |
| |
| func (cr *Client) getAndCacheFileDescriptors(req *rpb.ServerReflectionRequest, expectedName, alias string) (*desc.FileDescriptor, error) { |
| resp, err := cr.send(req) |
| if err != nil { |
| return nil, err |
| } |
| |
| fdResp := resp.GetFileDescriptorResponse() |
| if fdResp == nil { |
| return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()} |
| } |
| |
| // Response can contain the result file descriptor, but also its transitive |
| // deps. Furthermore, protocol states that subsequent requests do not need |
| // to send transitive deps that have been sent in prior responses. So we |
| // need to cache all file descriptors and then return the first one (which |
| // should be the answer). If we're looking for a file by name, we can be |
| // smarter and make sure to grab one by name instead of just grabbing the |
| // first one. |
| var firstFd *dpb.FileDescriptorProto |
| for _, fdBytes := range fdResp.FileDescriptorProto { |
| fd := &dpb.FileDescriptorProto{} |
| if err = proto.Unmarshal(fdBytes, fd); err != nil { |
| return nil, err |
| } |
| |
| if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName { |
| // we found a file was aliased, so we need to update the proto to reflect that |
| fd.Name = proto.String(alias) |
| } |
| |
| cr.cacheMu.Lock() |
| // see if this file was created and cached concurrently |
| if firstFd == nil { |
| if d, ok := cr.filesByName[fd.GetName()]; ok { |
| cr.cacheMu.Unlock() |
| return d, nil |
| } |
| } |
| // store in cache of raw descriptor protos, but don't overwrite existing protos |
| if existingFd, ok := cr.protosByName[fd.GetName()]; ok { |
| fd = existingFd |
| } else { |
| cr.protosByName[fd.GetName()] = fd |
| } |
| cr.cacheMu.Unlock() |
| if firstFd == nil { |
| firstFd = fd |
| } |
| } |
| if firstFd == nil { |
| return nil, &ProtocolError{reflect.TypeOf(firstFd).Elem()} |
| } |
| |
| return cr.descriptorFromProto(firstFd) |
| } |
| |
| func (cr *Client) descriptorFromProto(fd *dpb.FileDescriptorProto) (*desc.FileDescriptor, error) { |
| deps := make([]*desc.FileDescriptor, len(fd.GetDependency())) |
| for i, depName := range fd.GetDependency() { |
| if dep, err := cr.FileByFilename(depName); err != nil { |
| return nil, err |
| } else { |
| deps[i] = dep |
| } |
| } |
| d, err := desc.CreateFileDescriptor(fd, deps...) |
| if err != nil { |
| return nil, err |
| } |
| d = cr.cacheFile(d) |
| return d, nil |
| } |
| |
| func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor { |
| cr.cacheMu.Lock() |
| defer cr.cacheMu.Unlock() |
| |
| // cache file descriptor by name, but don't overwrite existing entry |
| // (existing entry could come from concurrent caller) |
| if existingFd, ok := cr.filesByName[fd.GetName()]; ok { |
| return existingFd |
| } |
| cr.filesByName[fd.GetName()] = fd |
| |
| // also cache by symbols and extensions |
| for _, m := range fd.GetMessageTypes() { |
| cr.cacheMessageLocked(fd, m) |
| } |
| for _, e := range fd.GetEnumTypes() { |
| cr.filesBySymbol[e.GetFullyQualifiedName()] = fd |
| for _, v := range e.GetValues() { |
| cr.filesBySymbol[v.GetFullyQualifiedName()] = fd |
| } |
| } |
| for _, e := range fd.GetExtensions() { |
| cr.filesBySymbol[e.GetFullyQualifiedName()] = fd |
| cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd |
| } |
| for _, s := range fd.GetServices() { |
| cr.filesBySymbol[s.GetFullyQualifiedName()] = fd |
| for _, m := range s.GetMethods() { |
| cr.filesBySymbol[m.GetFullyQualifiedName()] = fd |
| } |
| } |
| |
| return fd |
| } |
| |
| func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) { |
| cr.filesBySymbol[md.GetFullyQualifiedName()] = fd |
| for _, f := range md.GetFields() { |
| cr.filesBySymbol[f.GetFullyQualifiedName()] = fd |
| } |
| for _, o := range md.GetOneOfs() { |
| cr.filesBySymbol[o.GetFullyQualifiedName()] = fd |
| } |
| for _, e := range md.GetNestedEnumTypes() { |
| cr.filesBySymbol[e.GetFullyQualifiedName()] = fd |
| for _, v := range e.GetValues() { |
| cr.filesBySymbol[v.GetFullyQualifiedName()] = fd |
| } |
| } |
| for _, e := range md.GetNestedExtensions() { |
| cr.filesBySymbol[e.GetFullyQualifiedName()] = fd |
| cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd |
| } |
| for _, m := range md.GetNestedMessageTypes() { |
| cr.cacheMessageLocked(fd, m) // recurse |
| } |
| } |
| |
| // AllExtensionNumbersForType asks the server for all known extension numbers |
| // for the given fully-qualified message name. |
| func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) { |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{ |
| AllExtensionNumbersOfType: extendedMessageName, |
| }, |
| } |
| resp, err := cr.send(req) |
| if err != nil { |
| if isNotFound(err) { |
| return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil) |
| } |
| return nil, err |
| } |
| |
| extResp := resp.GetAllExtensionNumbersResponse() |
| if extResp == nil { |
| return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()} |
| } |
| return extResp.ExtensionNumber, nil |
| } |
| |
| // ListServices asks the server for the fully-qualified names of all exposed |
| // services. |
| func (cr *Client) ListServices() ([]string, error) { |
| req := &rpb.ServerReflectionRequest{ |
| MessageRequest: &rpb.ServerReflectionRequest_ListServices{ |
| // proto doesn't indicate any purpose for this value and server impl |
| // doesn't actually use it... |
| ListServices: "*", |
| }, |
| } |
| resp, err := cr.send(req) |
| if err != nil { |
| return nil, err |
| } |
| |
| listResp := resp.GetListServicesResponse() |
| if listResp == nil { |
| return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()} |
| } |
| serviceNames := make([]string, len(listResp.Service)) |
| for i, s := range listResp.Service { |
| serviceNames[i] = s.Name |
| } |
| return serviceNames, nil |
| } |
| |
| func (cr *Client) send(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { |
| // we allow one immediate retry, in case we have a stale stream |
| // (e.g. closed by server) |
| resp, err := cr.doSend(true, req) |
| if err != nil { |
| return nil, err |
| } |
| |
| // convert error response messages into errors |
| errResp := resp.GetErrorResponse() |
| if errResp != nil { |
| return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage) |
| } |
| |
| return resp, nil |
| } |
| |
| func isNotFound(err error) bool { |
| if err == nil { |
| return false |
| } |
| s, ok := status.FromError(err) |
| return ok && s.Code() == codes.NotFound |
| } |
| |
| func (cr *Client) doSend(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { |
| // TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery |
| // (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus |
| // delivered in correct oder. |
| cr.connMu.Lock() |
| defer cr.connMu.Unlock() |
| return cr.doSendLocked(retry, req) |
| } |
| |
| func (cr *Client) doSendLocked(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { |
| if err := cr.initStreamLocked(); err != nil { |
| return nil, err |
| } |
| |
| if err := cr.stream.Send(req); err != nil { |
| if err == io.EOF { |
| // if send returns EOF, must call Recv to get real underlying error |
| _, err = cr.stream.Recv() |
| } |
| cr.resetLocked() |
| if retry { |
| return cr.doSendLocked(false, req) |
| } |
| return nil, err |
| } |
| |
| if resp, err := cr.stream.Recv(); err != nil { |
| cr.resetLocked() |
| if retry { |
| return cr.doSendLocked(false, req) |
| } |
| return nil, err |
| } else { |
| return resp, nil |
| } |
| } |
| |
| func (cr *Client) initStreamLocked() error { |
| if cr.stream != nil { |
| return nil |
| } |
| var newCtx context.Context |
| newCtx, cr.cancel = context.WithCancel(cr.ctx) |
| var err error |
| cr.stream, err = cr.stub.ServerReflectionInfo(newCtx) |
| return err |
| } |
| |
| // Reset ensures that any active stream with the server is closed, releasing any |
| // resources. |
| func (cr *Client) Reset() { |
| cr.connMu.Lock() |
| defer cr.connMu.Unlock() |
| cr.resetLocked() |
| } |
| |
| func (cr *Client) resetLocked() { |
| if cr.stream != nil { |
| cr.stream.CloseSend() |
| for { |
| // drain the stream, this covers io.EOF too |
| if _, err := cr.stream.Recv(); err != nil { |
| break |
| } |
| } |
| cr.stream = nil |
| } |
| if cr.cancel != nil { |
| cr.cancel() |
| cr.cancel = nil |
| } |
| } |
| |
| // ResolveService asks the server to resolve the given fully-qualified service |
| // name into a service descriptor. |
| func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) { |
| file, err := cr.FileContainingSymbol(serviceName) |
| if err != nil { |
| return nil, setSymbolType(err, serviceName, symbolTypeService) |
| } |
| d := file.FindSymbol(serviceName) |
| if d == nil { |
| return nil, symbolNotFound(serviceName, symbolTypeService, nil) |
| } |
| if s, ok := d.(*desc.ServiceDescriptor); ok { |
| return s, nil |
| } else { |
| return nil, symbolNotFound(serviceName, symbolTypeService, nil) |
| } |
| } |
| |
| // ResolveMessage asks the server to resolve the given fully-qualified message |
| // name into a message descriptor. |
| func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) { |
| file, err := cr.FileContainingSymbol(messageName) |
| if err != nil { |
| return nil, setSymbolType(err, messageName, symbolTypeMessage) |
| } |
| d := file.FindSymbol(messageName) |
| if d == nil { |
| return nil, symbolNotFound(messageName, symbolTypeMessage, nil) |
| } |
| if s, ok := d.(*desc.MessageDescriptor); ok { |
| return s, nil |
| } else { |
| return nil, symbolNotFound(messageName, symbolTypeMessage, nil) |
| } |
| } |
| |
| // ResolveEnum asks the server to resolve the given fully-qualified enum name |
| // into an enum descriptor. |
| func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) { |
| file, err := cr.FileContainingSymbol(enumName) |
| if err != nil { |
| return nil, setSymbolType(err, enumName, symbolTypeEnum) |
| } |
| d := file.FindSymbol(enumName) |
| if d == nil { |
| return nil, symbolNotFound(enumName, symbolTypeEnum, nil) |
| } |
| if s, ok := d.(*desc.EnumDescriptor); ok { |
| return s, nil |
| } else { |
| return nil, symbolNotFound(enumName, symbolTypeEnum, nil) |
| } |
| } |
| |
| func setSymbolType(err error, name string, symType symbolType) error { |
| if e, ok := err.(*elementNotFoundError); ok { |
| if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown { |
| e.symType = symType |
| } |
| } |
| return err |
| } |
| |
| // ResolveEnumValues asks the server to resolve the given fully-qualified enum |
| // name into a map of names to numbers that represents the enum's values. |
| func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) { |
| enumDesc, err := cr.ResolveEnum(enumName) |
| if err != nil { |
| return nil, err |
| } |
| vals := map[string]int32{} |
| for _, valDesc := range enumDesc.GetValues() { |
| vals[valDesc.GetName()] = valDesc.GetNumber() |
| } |
| return vals, nil |
| } |
| |
| // ResolveExtension asks the server to resolve the given extension number and |
| // fully-qualified message name into a field descriptor. |
| func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) { |
| file, err := cr.FileContainingExtension(extendedType, extensionNumber) |
| if err != nil { |
| return nil, err |
| } |
| d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file}) |
| if d == nil { |
| return nil, extensionNotFound(extendedType, extensionNumber, nil) |
| } else { |
| return d, nil |
| } |
| } |
| |
| func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor { |
| // search extensions in this scope |
| for _, ext := range scope.extensions() { |
| if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType { |
| return ext |
| } |
| } |
| |
| // if not found, search nested scopes |
| for _, nested := range scope.nestedScopes() { |
| ext := findExtension(extendedType, extensionNumber, nested) |
| if ext != nil { |
| return ext |
| } |
| } |
| |
| return nil |
| } |
| |
| type extensionScope interface { |
| extensions() []*desc.FieldDescriptor |
| nestedScopes() []extensionScope |
| } |
| |
| // fileDescriptorExtensions implements extensionHolder interface on top of |
| // FileDescriptorProto |
| type fileDescriptorExtensions struct { |
| proto *desc.FileDescriptor |
| } |
| |
| func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor { |
| return fde.proto.GetExtensions() |
| } |
| |
| func (fde fileDescriptorExtensions) nestedScopes() []extensionScope { |
| scopes := make([]extensionScope, len(fde.proto.GetMessageTypes())) |
| for i, m := range fde.proto.GetMessageTypes() { |
| scopes[i] = msgDescriptorExtensions{m} |
| } |
| return scopes |
| } |
| |
| // msgDescriptorExtensions implements extensionHolder interface on top of |
| // DescriptorProto |
| type msgDescriptorExtensions struct { |
| proto *desc.MessageDescriptor |
| } |
| |
| func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor { |
| return mde.proto.GetNestedExtensions() |
| } |
| |
| func (mde msgDescriptorExtensions) nestedScopes() []extensionScope { |
| scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes())) |
| for i, m := range mde.proto.GetNestedMessageTypes() { |
| scopes[i] = msgDescriptorExtensions{m} |
| } |
| return scopes |
| } |