blob: dd22a2da78496cd3b14ac45483ecc4c1ac094a80 [file] [log] [blame]
khenaidoo106c61a2021-08-11 18:05:46 -04001/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19//go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto
20
21/*
Package reflection implements the gRPC server reflection service.
23
24The service implemented is defined in:
25https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
26
27To register server reflection on a gRPC server:
28 import "google.golang.org/grpc/reflection"
29
30 s := grpc.NewServer()
31 pb.RegisterYourOwnServer(s, &server{})
32
33 // Register reflection service on gRPC server.
34 reflection.Register(s)
35
36 s.Serve(lis)
37
38*/
39package reflection // import "google.golang.org/grpc/reflection"
40
41import (
42 "bytes"
43 "compress/gzip"
44 "fmt"
45 "io"
46 "io/ioutil"
47 "reflect"
48 "sort"
49 "sync"
50
51 "github.com/golang/protobuf/proto"
52 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
53 "google.golang.org/grpc"
54 "google.golang.org/grpc/codes"
55 rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
56 "google.golang.org/grpc/status"
57)
58
// serverReflectionServer is the ServerReflection service implementation
// registered by Register; it answers reflection queries about the services
// registered on a single gRPC server.
type serverReflectionServer struct {
	// s is the server whose registered services are reported.
	s *grpc.Server

	// initSymbols guards the lazy, one-time construction of serviceNames and
	// symbols performed by getSymbols.
	initSymbols sync.Once
	// serviceNames holds the names of all services registered on s, sorted.
	serviceNames []string
	// symbols maps fully-qualified symbol names to the file descriptor that
	// declares them.
	symbols map[string]*dpb.FileDescriptorProto
}
66
67// Register registers the server reflection service on the given gRPC server.
68func Register(s *grpc.Server) {
69 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
70 s: s,
71 })
72}
73
// protoMessage is used for type assertion on proto messages.
// Generated proto message types implement Descriptor(), but Descriptor() is
// not part of the proto.Message interface, so this interface is needed to
// call it. The first result is the gzipped FileDescriptorProto of the file
// declaring the message (it is passed straight to decodeFileDesc). The second
// result is unused in this file; presumably it is the message's index path
// within that file — TODO confirm.
type protoMessage interface {
	Descriptor() ([]byte, []int)
}
81
82func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
83 s.initSymbols.Do(func() {
84 serviceInfo := s.s.GetServiceInfo()
85
86 s.symbols = map[string]*dpb.FileDescriptorProto{}
87 s.serviceNames = make([]string, 0, len(serviceInfo))
88 processed := map[string]struct{}{}
89 for svc, info := range serviceInfo {
90 s.serviceNames = append(s.serviceNames, svc)
91 fdenc, ok := parseMetadata(info.Metadata)
92 if !ok {
93 continue
94 }
95 fd, err := decodeFileDesc(fdenc)
96 if err != nil {
97 continue
98 }
99 s.processFile(fd, processed)
100 }
101 sort.Strings(s.serviceNames)
102 })
103
104 return s.serviceNames, s.symbols
105}
106
107func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
108 filename := fd.GetName()
109 if _, ok := processed[filename]; ok {
110 return
111 }
112 processed[filename] = struct{}{}
113
114 prefix := fd.GetPackage()
115
116 for _, msg := range fd.MessageType {
117 s.processMessage(fd, prefix, msg)
118 }
119 for _, en := range fd.EnumType {
120 s.processEnum(fd, prefix, en)
121 }
122 for _, ext := range fd.Extension {
123 s.processField(fd, prefix, ext)
124 }
125 for _, svc := range fd.Service {
126 svcName := fqn(prefix, svc.GetName())
127 s.symbols[svcName] = fd
128 for _, meth := range svc.Method {
129 name := fqn(svcName, meth.GetName())
130 s.symbols[name] = fd
131 }
132 }
133
134 for _, dep := range fd.Dependency {
135 fdenc := proto.FileDescriptor(dep)
136 fdDep, err := decodeFileDesc(fdenc)
137 if err != nil {
138 continue
139 }
140 s.processFile(fdDep, processed)
141 }
142}
143
144func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
145 msgName := fqn(prefix, msg.GetName())
146 s.symbols[msgName] = fd
147
148 for _, nested := range msg.NestedType {
149 s.processMessage(fd, msgName, nested)
150 }
151 for _, en := range msg.EnumType {
152 s.processEnum(fd, msgName, en)
153 }
154 for _, ext := range msg.Extension {
155 s.processField(fd, msgName, ext)
156 }
157 for _, fld := range msg.Field {
158 s.processField(fd, msgName, fld)
159 }
160 for _, oneof := range msg.OneofDecl {
161 oneofName := fqn(msgName, oneof.GetName())
162 s.symbols[oneofName] = fd
163 }
164}
165
166func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
167 enName := fqn(prefix, en.GetName())
168 s.symbols[enName] = fd
169
170 for _, val := range en.Value {
171 valName := fqn(enName, val.GetName())
172 s.symbols[valName] = fd
173 }
174}
175
176func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
177 fldName := fqn(prefix, fld.GetName())
178 s.symbols[fldName] = fd
179}
180
// fqn joins a scope prefix and a simple name with a dot to form a
// fully-qualified protobuf name. An empty prefix (for example, a file with no
// package) yields name unchanged.
func fqn(prefix, name string) string {
	if len(prefix) > 0 {
		return prefix + "." + name
	}
	return name
}
187
188// fileDescForType gets the file descriptor for the given type.
189// The given type should be a proto message.
190func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
191 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
192 if !ok {
193 return nil, fmt.Errorf("failed to create message from type: %v", st)
194 }
195 enc, _ := m.Descriptor()
196
197 return decodeFileDesc(enc)
198}
199
200// decodeFileDesc does decompression and unmarshalling on the given
201// file descriptor byte slice.
202func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
203 raw, err := decompress(enc)
204 if err != nil {
205 return nil, fmt.Errorf("failed to decompress enc: %v", err)
206 }
207
208 fd := new(dpb.FileDescriptorProto)
209 if err := proto.Unmarshal(raw, fd); err != nil {
210 return nil, fmt.Errorf("bad descriptor: %v", err)
211 }
212 return fd, nil
213}
214
// decompress gunzips b and returns the uncompressed bytes. It returns an
// error if b is not a valid, complete gzip stream (including empty input).
func decompress(b []byte) ([]byte, error) {
	r, err := gzip.NewReader(bytes.NewReader(b))
	if err != nil {
		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
	}
	// compress/gzip makes closing the Reader the caller's responsibility.
	// Close does not touch the underlying bytes.Reader. Any read or checksum
	// failure is already surfaced by ReadAll, so Close's error is ignored.
	defer r.Close()

	out, err := ioutil.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
	}
	return out, nil
}
227
228func typeForName(name string) (reflect.Type, error) {
229 pt := proto.MessageType(name)
230 if pt == nil {
231 return nil, fmt.Errorf("unknown type: %q", name)
232 }
233 st := pt.Elem()
234
235 return st, nil
236}
237
238func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
239 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
240 if !ok {
241 return nil, fmt.Errorf("failed to create message from type: %v", st)
242 }
243
244 var extDesc *proto.ExtensionDesc
245 for id, desc := range proto.RegisteredExtensions(m) {
246 if id == ext {
247 extDesc = desc
248 break
249 }
250 }
251
252 if extDesc == nil {
253 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
254 }
255
256 return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
257}
258
259func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
260 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
261 if !ok {
262 return nil, fmt.Errorf("failed to create message from type: %v", st)
263 }
264
265 exts := proto.RegisteredExtensions(m)
266 out := make([]int32, 0, len(exts))
267 for id := range exts {
268 out = append(out, id)
269 }
270 return out, nil
271}
272
273// fileDescEncodingByFilename finds the file descriptor for given filename,
274// does marshalling on it and returns the marshalled result.
275func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
276 enc := proto.FileDescriptor(name)
277 if enc == nil {
278 return nil, fmt.Errorf("unknown file: %v", name)
279 }
280 fd, err := decodeFileDesc(enc)
281 if err != nil {
282 return nil, err
283 }
284 return proto.Marshal(fd)
285}
286
287// parseMetadata finds the file descriptor bytes specified meta.
288// For SupportPackageIsVersion4, m is the name of the proto file, we
289// call proto.FileDescriptor to get the byte slice.
290// For SupportPackageIsVersion3, m is a byte slice itself.
291func parseMetadata(meta interface{}) ([]byte, bool) {
292 // Check if meta is the file name.
293 if fileNameForMeta, ok := meta.(string); ok {
294 return proto.FileDescriptor(fileNameForMeta), true
295 }
296
297 // Check if meta is the byte slice.
298 if enc, ok := meta.([]byte); ok {
299 return enc, true
300 }
301
302 return nil, false
303}
304
305// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
306// does marshalling on it and returns the marshalled result.
307// The given symbol can be a type, a service or a method.
308func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
309 _, symbols := s.getSymbols()
310 fd := symbols[name]
311 if fd == nil {
312 // Check if it's a type name that was not present in the
313 // transitive dependencies of the registered services.
314 if st, err := typeForName(name); err == nil {
315 fd, err = s.fileDescForType(st)
316 if err != nil {
317 return nil, err
318 }
319 }
320 }
321
322 if fd == nil {
323 return nil, fmt.Errorf("unknown symbol: %v", name)
324 }
325
326 return proto.Marshal(fd)
327}
328
329// fileDescEncodingContainingExtension finds the file descriptor containing given extension,
330// does marshalling on it and returns the marshalled result.
331func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
332 st, err := typeForName(typeName)
333 if err != nil {
334 return nil, err
335 }
336 fd, err := fileDescContainingExtension(st, extNum)
337 if err != nil {
338 return nil, err
339 }
340 return proto.Marshal(fd)
341}
342
343// allExtensionNumbersForTypeName returns all extension numbers for the given type.
344func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
345 st, err := typeForName(name)
346 if err != nil {
347 return nil, err
348 }
349 extNums, err := s.allExtensionNumbersForType(st)
350 if err != nil {
351 return nil, err
352 }
353 return extNums, nil
354}
355
356// ServerReflectionInfo is the reflection service handler.
357func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
358 for {
359 in, err := stream.Recv()
360 if err == io.EOF {
361 return nil
362 }
363 if err != nil {
364 return err
365 }
366
367 out := &rpb.ServerReflectionResponse{
368 ValidHost: in.Host,
369 OriginalRequest: in,
370 }
371 switch req := in.MessageRequest.(type) {
372 case *rpb.ServerReflectionRequest_FileByFilename:
373 b, err := s.fileDescEncodingByFilename(req.FileByFilename)
374 if err != nil {
375 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
376 ErrorResponse: &rpb.ErrorResponse{
377 ErrorCode: int32(codes.NotFound),
378 ErrorMessage: err.Error(),
379 },
380 }
381 } else {
382 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
383 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
384 }
385 }
386 case *rpb.ServerReflectionRequest_FileContainingSymbol:
387 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
388 if err != nil {
389 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
390 ErrorResponse: &rpb.ErrorResponse{
391 ErrorCode: int32(codes.NotFound),
392 ErrorMessage: err.Error(),
393 },
394 }
395 } else {
396 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
397 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
398 }
399 }
400 case *rpb.ServerReflectionRequest_FileContainingExtension:
401 typeName := req.FileContainingExtension.ContainingType
402 extNum := req.FileContainingExtension.ExtensionNumber
403 b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
404 if err != nil {
405 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
406 ErrorResponse: &rpb.ErrorResponse{
407 ErrorCode: int32(codes.NotFound),
408 ErrorMessage: err.Error(),
409 },
410 }
411 } else {
412 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
413 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
414 }
415 }
416 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
417 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
418 if err != nil {
419 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
420 ErrorResponse: &rpb.ErrorResponse{
421 ErrorCode: int32(codes.NotFound),
422 ErrorMessage: err.Error(),
423 },
424 }
425 } else {
426 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
427 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
428 BaseTypeName: req.AllExtensionNumbersOfType,
429 ExtensionNumber: extNums,
430 },
431 }
432 }
433 case *rpb.ServerReflectionRequest_ListServices:
434 svcNames, _ := s.getSymbols()
435 serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
436 for i, n := range svcNames {
437 serviceResponses[i] = &rpb.ServiceResponse{
438 Name: n,
439 }
440 }
441 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
442 ListServicesResponse: &rpb.ListServiceResponse{
443 Service: serviceResponses,
444 },
445 }
446 default:
447 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
448 }
449
450 if err := stream.Send(out); err != nil {
451 return err
452 }
453 }
454}