blob: 635ddefed9f97dbfbecd96ed62564e726211dbb2 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package grpcurl
2
3import (
4 "errors"
5 "fmt"
Scott Baker4a35a702019-11-26 08:17:33 -08006 "io"
Zack Williamse940c7a2019-08-21 14:25:39 -07007 "io/ioutil"
8 "sync"
9
10 "github.com/golang/protobuf/proto"
11 descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
12 "github.com/jhump/protoreflect/desc"
13 "github.com/jhump/protoreflect/desc/protoparse"
14 "github.com/jhump/protoreflect/dynamic"
15 "github.com/jhump/protoreflect/grpcreflect"
16 "golang.org/x/net/context"
17 "google.golang.org/grpc/codes"
18 "google.golang.org/grpc/status"
19)
20
// ErrReflectionNotSupported is the sentinel error returned by DescriptorSource
// operations that depend on the reflection service when the source does not
// actually expose that service. Callers that receive it must fall back to an
// alternate source of descriptors, such as file descriptor sets.
var (
	ErrReflectionNotSupported = errors.New("server does not support the reflection API")
)
26
27// DescriptorSource is a source of protobuf descriptor information. It can be backed by a FileDescriptorSet
28// proto (like a file generated by protoc) or a remote server that supports the reflection API.
29type DescriptorSource interface {
30 // ListServices returns a list of fully-qualified service names. It will be all services in a set of
31 // descriptor files or the set of all services exposed by a gRPC server.
32 ListServices() ([]string, error)
33 // FindSymbol returns a descriptor for the given fully-qualified symbol name.
34 FindSymbol(fullyQualifiedName string) (desc.Descriptor, error)
35 // AllExtensionsForType returns all known extension fields that extend the given message type name.
36 AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error)
37}
38
39// DescriptorSourceFromProtoSets creates a DescriptorSource that is backed by the named files, whose contents
40// are encoded FileDescriptorSet protos.
41func DescriptorSourceFromProtoSets(fileNames ...string) (DescriptorSource, error) {
42 files := &descpb.FileDescriptorSet{}
43 for _, fileName := range fileNames {
44 b, err := ioutil.ReadFile(fileName)
45 if err != nil {
46 return nil, fmt.Errorf("could not load protoset file %q: %v", fileName, err)
47 }
48 var fs descpb.FileDescriptorSet
49 err = proto.Unmarshal(b, &fs)
50 if err != nil {
51 return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", fileName, err)
52 }
53 files.File = append(files.File, fs.File...)
54 }
55 return DescriptorSourceFromFileDescriptorSet(files)
56}
57
58// DescriptorSourceFromProtoFiles creates a DescriptorSource that is backed by the named files,
59// whose contents are Protocol Buffer source files. The given importPaths are used to locate
60// any imported files.
61func DescriptorSourceFromProtoFiles(importPaths []string, fileNames ...string) (DescriptorSource, error) {
62 fileNames, err := protoparse.ResolveFilenames(importPaths, fileNames...)
63 if err != nil {
64 return nil, err
65 }
66 p := protoparse.Parser{
67 ImportPaths: importPaths,
68 InferImportPaths: len(importPaths) == 0,
69 IncludeSourceCodeInfo: true,
70 }
71 fds, err := p.ParseFiles(fileNames...)
72 if err != nil {
73 return nil, fmt.Errorf("could not parse given files: %v", err)
74 }
75 return DescriptorSourceFromFileDescriptors(fds...)
76}
77
78// DescriptorSourceFromFileDescriptorSet creates a DescriptorSource that is backed by the FileDescriptorSet.
79func DescriptorSourceFromFileDescriptorSet(files *descpb.FileDescriptorSet) (DescriptorSource, error) {
80 unresolved := map[string]*descpb.FileDescriptorProto{}
81 for _, fd := range files.File {
82 unresolved[fd.GetName()] = fd
83 }
84 resolved := map[string]*desc.FileDescriptor{}
85 for _, fd := range files.File {
86 _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName())
87 if err != nil {
88 return nil, err
89 }
90 }
91 return &fileSource{files: resolved}, nil
92}
93
94func resolveFileDescriptor(unresolved map[string]*descpb.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
95 if r, ok := resolved[filename]; ok {
96 return r, nil
97 }
98 fd, ok := unresolved[filename]
99 if !ok {
100 return nil, fmt.Errorf("no descriptor found for %q", filename)
101 }
102 deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
103 for _, dep := range fd.GetDependency() {
104 depFd, err := resolveFileDescriptor(unresolved, resolved, dep)
105 if err != nil {
106 return nil, err
107 }
108 deps = append(deps, depFd)
109 }
110 result, err := desc.CreateFileDescriptor(fd, deps...)
111 if err != nil {
112 return nil, err
113 }
114 resolved[filename] = result
115 return result, nil
116}
117
118// DescriptorSourceFromFileDescriptors creates a DescriptorSource that is backed by the given
119// file descriptors
120func DescriptorSourceFromFileDescriptors(files ...*desc.FileDescriptor) (DescriptorSource, error) {
121 fds := map[string]*desc.FileDescriptor{}
122 for _, fd := range files {
123 if err := addFile(fd, fds); err != nil {
124 return nil, err
125 }
126 }
127 return &fileSource{files: fds}, nil
128}
129
130func addFile(fd *desc.FileDescriptor, fds map[string]*desc.FileDescriptor) error {
131 name := fd.GetName()
132 if existing, ok := fds[name]; ok {
133 // already added this file
134 if existing != fd {
135 // doh! duplicate files provided
136 return fmt.Errorf("given files include multiple copies of %q", name)
137 }
138 return nil
139 }
140 fds[name] = fd
141 for _, dep := range fd.GetDependencies() {
142 if err := addFile(dep, fds); err != nil {
143 return err
144 }
145 }
146 return nil
147}
148
149type fileSource struct {
150 files map[string]*desc.FileDescriptor
151 er *dynamic.ExtensionRegistry
152 erInit sync.Once
153}
154
155func (fs *fileSource) ListServices() ([]string, error) {
156 set := map[string]bool{}
157 for _, fd := range fs.files {
158 for _, svc := range fd.GetServices() {
159 set[svc.GetFullyQualifiedName()] = true
160 }
161 }
162 sl := make([]string, 0, len(set))
163 for svc := range set {
164 sl = append(sl, svc)
165 }
166 return sl, nil
167}
168
169// GetAllFiles returns all of the underlying file descriptors. This is
170// more thorough and more efficient than the fallback strategy used by
171// the GetAllFiles package method, for enumerating all files from a
172// descriptor source.
173func (fs *fileSource) GetAllFiles() ([]*desc.FileDescriptor, error) {
174 files := make([]*desc.FileDescriptor, len(fs.files))
175 i := 0
176 for _, fd := range fs.files {
177 files[i] = fd
178 i++
179 }
180 return files, nil
181}
182
183func (fs *fileSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
184 for _, fd := range fs.files {
185 if dsc := fd.FindSymbol(fullyQualifiedName); dsc != nil {
186 return dsc, nil
187 }
188 }
189 return nil, notFound("Symbol", fullyQualifiedName)
190}
191
192func (fs *fileSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
193 fs.erInit.Do(func() {
194 fs.er = &dynamic.ExtensionRegistry{}
195 for _, fd := range fs.files {
196 fs.er.AddExtensionsFromFile(fd)
197 }
198 })
199 return fs.er.AllExtensionsForType(typeName), nil
200}
201
202// DescriptorSourceFromServer creates a DescriptorSource that uses the given gRPC reflection client
203// to interrogate a server for descriptor information. If the server does not support the reflection
204// API then the various DescriptorSource methods will return ErrReflectionNotSupported
205func DescriptorSourceFromServer(_ context.Context, refClient *grpcreflect.Client) DescriptorSource {
206 return serverSource{client: refClient}
207}
208
209type serverSource struct {
210 client *grpcreflect.Client
211}
212
213func (ss serverSource) ListServices() ([]string, error) {
214 svcs, err := ss.client.ListServices()
215 return svcs, reflectionSupport(err)
216}
217
218func (ss serverSource) FindSymbol(fullyQualifiedName string) (desc.Descriptor, error) {
219 file, err := ss.client.FileContainingSymbol(fullyQualifiedName)
220 if err != nil {
221 return nil, reflectionSupport(err)
222 }
223 d := file.FindSymbol(fullyQualifiedName)
224 if d == nil {
225 return nil, notFound("Symbol", fullyQualifiedName)
226 }
227 return d, nil
228}
229
230func (ss serverSource) AllExtensionsForType(typeName string) ([]*desc.FieldDescriptor, error) {
231 var exts []*desc.FieldDescriptor
232 nums, err := ss.client.AllExtensionNumbersForType(typeName)
233 if err != nil {
234 return nil, reflectionSupport(err)
235 }
236 for _, fieldNum := range nums {
237 ext, err := ss.client.ResolveExtension(typeName, fieldNum)
238 if err != nil {
239 return nil, reflectionSupport(err)
240 }
241 exts = append(exts, ext)
242 }
243 return exts, nil
244}
245
246func reflectionSupport(err error) error {
247 if err == nil {
248 return nil
249 }
250 if stat, ok := status.FromError(err); ok && stat.Code() == codes.Unimplemented {
251 return ErrReflectionNotSupported
252 }
253 return err
254}
Scott Baker4a35a702019-11-26 08:17:33 -0800255
256// WriteProtoset will use the given descriptor source to resolve all of the given
257// symbols and write a proto file descriptor set with their definitions to the
258// given output. The output will include descriptors for all files in which the
259// symbols are defined as well as their transitive dependencies.
260func WriteProtoset(out io.Writer, descSource DescriptorSource, symbols ...string) error {
261 // compute set of file descriptors
262 filenames := make([]string, 0, len(symbols))
263 fds := make(map[string]*desc.FileDescriptor, len(symbols))
264 for _, sym := range symbols {
265 d, err := descSource.FindSymbol(sym)
266 if err != nil {
267 return fmt.Errorf("failed to find descriptor for %q: %v", sym, err)
268 }
269 fd := d.GetFile()
270 if _, ok := fds[fd.GetName()]; !ok {
271 fds[fd.GetName()] = fd
272 filenames = append(filenames, fd.GetName())
273 }
274 }
275 // now expand that to include transitive dependencies in topologically sorted
276 // order (such that file always appears after its dependencies)
277 expandedFiles := make(map[string]struct{}, len(fds))
278 allFilesSlice := make([]*descpb.FileDescriptorProto, 0, len(fds))
279 for _, filename := range filenames {
280 allFilesSlice = addFilesToSet(allFilesSlice, expandedFiles, fds[filename])
281 }
282 // now we can serialize to file
283 b, err := proto.Marshal(&descpb.FileDescriptorSet{File: allFilesSlice})
284 if err != nil {
285 return fmt.Errorf("failed to serialize file descriptor set: %v", err)
286 }
287 if _, err := out.Write(b); err != nil {
288 return fmt.Errorf("failed to write file descriptor set: %v", err)
289 }
290 return nil
291}
292
293func addFilesToSet(allFiles []*descpb.FileDescriptorProto, expanded map[string]struct{}, fd *desc.FileDescriptor) []*descpb.FileDescriptorProto {
294 if _, ok := expanded[fd.GetName()]; ok {
295 // already seen this one
296 return allFiles
297 }
298 expanded[fd.GetName()] = struct{}{}
299 // add all dependencies first
300 for _, dep := range fd.GetDependencies() {
301 allFiles = addFilesToSet(allFiles, expanded, dep)
302 }
303 return append(allFiles, fd.AsFileDescriptorProto())
304}