blob: c23ae3d25c19f2709aee2eb20fc35f3350b60699 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package grpcurl
2
3import (
4 "errors"
5 "fmt"
6 "io/ioutil"
7 "sync"
8
9 "github.com/golang/protobuf/proto"
10 descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
11 "github.com/jhump/protoreflect/desc"
12 "github.com/jhump/protoreflect/desc/protoparse"
13 "github.com/jhump/protoreflect/dynamic"
14 "github.com/jhump/protoreflect/grpcreflect"
15 "golang.org/x/net/context"
16 "google.golang.org/grpc/codes"
17 "google.golang.org/grpc/status"
18)
19
var (
	// ErrReflectionNotSupported is the error reported by reflection-backed
	// DescriptorSource operations when the target server does not expose the
	// gRPC reflection service. Callers that receive it must switch to some
	// other descriptor source, such as protoset (file descriptor set) files.
	ErrReflectionNotSupported = errors.New("server does not support the reflection API")
)
25
26// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet
27// proto (like a file generated by protoc) or a remote server that supports the reflection API.
28type DescriptorSource interface {
29 // ListServices returns a list of fully-qualified service names. It will be all services in a set of
30 // descriptor files or the set of all services exposed by a gRPC server.
31 ListServices() ([]string, error)
32 // FindSymbol returns a descriptor for the given fully-qualified symbol name.
33 FindSymbol(fullyQualifiedName string) (desc.Descriptor, error)
34 // AllExtensionsForType returns all known extension fields that extend the given message type name.
35 AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
36}
37
38// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents
39// are encoded FileDescriptorSet protos.
40func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) {
41 files := &descpb.FileDescriptorSet{}
42 for _, fileName := range fileNames {
43 b, err := ioutil.ReadFile(fileName)
44 if err != nil {
45 return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err)
46 }
47 var fs descpb.FileDescriptorSet
48 err = proto.Unmarshal(b, &fs)
49 if err != nil {
50 return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err)
51 }
52 files.File = append(files.File, fs.File...)
53 }
54 return DescriptorSourceFromFileDescriptorSet(files)
55}
56
57// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files,
58// whose contents are Protocol Buffer source files. The given importPaths are used to locate
59// any imported files.
60func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) {
61 fileNames, err := protoparse.ResolveFilenames(importPaths, fileNames...)
62 if err != nil {
63 return nil, err
64 }
65 p := protoparse.Parser{
66 ImportPaths: importPaths,
67 InferImportPaths: len(importPaths) == 0,
68 IncludeSourceCodeInfo: true,
69 }
70 fds, err := p.ParseFiles(fileNames...)
71 if err != nil {
72 return nil, fmt.Errorf("could not parse given files: %v", err)
73 }
74 return DescriptorSourceFromFileDescriptors(fds...)
75}
76
77// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet.
78func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) {
79 unresolved := map[string]*descpb.FileDescriptorProto{}
80 for _, fd := range files.File {
81 unresolved[fd.GetName()] = fd
82 }
83 resolved := map[string]*desc.FileDescriptor{}
84 for _, fd := range files.File {
85 _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName())
86 if err != nil {
87 return nil, err
88 }
89 }
90 return &fileSource{files: resolved}, nil
91}
92
93func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
94 if r, ok := resolved[filename]; ok {
95 return r, nil
96 }
97 fd, ok := unresolved[filename]
98 if !ok {
99 return nil, fmt.Errorf("no descriptor found for %q", filename)
100 }
101 deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
102 for _, dep := range fd.GetDependency() {
103 depFd, err := resolveFileDescriptor(unresolved, resolved, dep)
104 if err != nil {
105 return nil, err
106 }
107 deps = append(deps, depFd)
108 }
109 result, err := desc.CreateFileDescriptor(fd, deps...)
110 if err != nil {
111 return nil, err
112 }
113 resolved[filename] = result
114 return result, nil
115}
116
117// DescriptorSourceFromFileDescriptors creates a DescriptorSource that is backed by the given
118// file descriptors
119func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) {
120 fds := map[string]*desc.FileDescriptor{}
121 for _, fd := range files {
122 if err := addFile(fd, fds); err != nil {
123 return nil, err
124 }
125 }
126 return &fileSource{files: fds}, nil
127}
128
129func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error {
130 name := fd.GetName()
131 if existing, ok := fds[name]; ok {
132 // already added this file
133 if existing != fd {
134 // doh! duplicate files provided
135 return fmt.Errorf("given files include multiple copies of %q", name)
136 }
137 return nil
138 }
139 fds[name] = fd
140 for _, dep := range fd.GetDependencies() {
141 if err := addFile(dep, fds); err != nil {
142 return err
143 }
144 }
145 return nil
146}
147
148type fileSource struct {
149 files map[string]*desc.FileDescriptor
150 er *dynamic.ExtensionRegistry
151 erInit sync.Once
152}
153
154func (fs *fileSource) ListServices() ([]string, error) {
155 set := map[string]bool{}
156 for _, fd := range fs.files {
157 for _, svc := range fd.GetServices() {
158 set[svc.GetFullyQualifiedName()] = true
159 }
160 }
161 sl := make([]string, 0, len(set))
162 for svc := range set {
163 sl = append(sl, svc)
164 }
165 return sl, nil
166}
167
168// GetAllFiles returns all of the underlying file descriptors. This is
169// more thorough and more efficient than the fallback strategy used by
170// the GetAllFiles package method, for enumerating all files from a
171// descriptor source.
172func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) {
173 files := make([]*desc.FileDescriptor, len(fs.files))
174 i := 0
175 for _, fd := range fs.files {
176 files[i] = fd
177 i++
178 }
179 return files, nil
180}
181
182func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
183 for _, fd := range fs.files {
184 if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil {
185 return dsc, nil
186 }
187 }
188 return nil, notFound("Symbol", fullyQualifiedName)
189}
190
191func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
192 fs.erInit.Do(func() {
193 fs.er = &dynamic.ExtensionRegistry{}
194 for _, fd := range fs.files {
195 fs.er.AddExtensionsFromFile(fd)
196 }
197 })
198 return fs.er.AllExtensionsForType(typeName), nil
199}
200
201// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client
202// to interrogate a server for descriptor information. If the server does not support the reflection
203// API then the various DescriptorSource methods will return ErrReflectionNotSupported
204func DescriptorSourceFromServer(_ context.Context, refClient *grpcreflect.Client) DescriptorSource {
205 return serverSource{client: refClient}
206}
207
208type serverSource struct {
209 client *grpcreflect.Client
210}
211
212func (ss serverSource) ListServices() ([]string, error) {
213 svcs, err := ss.client.ListServices()
214 return svcs, reflectionSupport(err)
215}
216
217func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
218 file, err := ss.client.FileContainingSymbol(fullyQualifiedName)
219 if err != nil {
220 return nil, reflectionSupport(err)
221 }
222 d := file.FindSymbol(fullyQualifiedName)
223 if d == nil {
224 return nil, notFound("Symbol", fullyQualifiedName)
225 }
226 return d, nil
227}
228
229func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
230 var exts []*desc.FieldDescriptor
231 nums, err := ss.client.AllExtensionNumbersForType(typeName)
232 if err != nil {
233 return nil, reflectionSupport(err)
234 }
235 for _, fieldNum := range nums {
236 ext, err := ss.client.ResolveExtension(typeName, fieldNum)
237 if err != nil {
238 return nil, reflectionSupport(err)
239 }
240 exts = append(exts, ext)
241 }
242 return exts, nil
243}
244
245func reflectionSupport(err error) error {
246 if err == nil {
247 return nil
248 }
249 if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented {
250 return ErrReflectionNotSupported
251 }
252 return err
253}