blob: 43f16c6164ce3ef9978ee9032ecf99f6742907b7 [file] [log] [blame]
Scott Baker105df152020-04-13 15:55:14 -07001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// Package protoregistry provides data structures to register and look up
// protobuf descriptor types.
//
// The Files registry contains file descriptors and provides the ability
// to iterate over the files or look up a specific descriptor within the files.
// Files only contains protobuf descriptors and has no understanding of Go
// type information that may be associated with each descriptor.
//
// The Types registry contains descriptor types for which there is a known
// Go type associated with that descriptor. It provides the ability to iterate
// over the registered types or look up a type by name.
16package protoregistry
17
18import (
19 "fmt"
20 "log"
21 "strings"
22 "sync"
23
24 "google.golang.org/protobuf/internal/errors"
25 "google.golang.org/protobuf/reflect/protoreflect"
26)
27
28// ignoreConflict reports whether to ignore a registration conflict
29// given the descriptor being registered and the error.
30// It is a variable so that the behavior is easily overridden in another file.
31var ignoreConflict = func(d protoreflect.Descriptor, err error) bool {
32 log.Printf(""+
33 "WARNING: %v\n"+
34 "A future release will panic on registration conflicts. See:\n"+
35 "https://developers.google.com/protocol-buffers/docs/reference/go/faq#namespace-conflict\n"+
36 "\n", err)
37 return true
38}
39
40var globalMutex sync.RWMutex
41
42// GlobalFiles is a global registry of file descriptors.
43var GlobalFiles *Files = new(Files)
44
45// GlobalTypes is the registry used by default for type lookups
46// unless a local registry is provided by the user.
47var GlobalTypes *Types = new(Types)
48
49// NotFound is a sentinel error value to indicate that the type was not found.
50//
51// Since registry lookup can happen in the critical performance path, resolvers
52// must return this exact error value, not an error wrapping it.
53var NotFound = errors.New("not found")
54
55// Files is a registry for looking up or iterating over files and the
56// descriptors contained within them.
57// The Find and Range methods are safe for concurrent use.
58type Files struct {
59 // The map of descsByName contains:
60 // EnumDescriptor
61 // EnumValueDescriptor
62 // MessageDescriptor
63 // ExtensionDescriptor
64 // ServiceDescriptor
65 // *packageDescriptor
66 //
67 // Note that files are stored as a slice, since a package may contain
68 // multiple files. Only top-level declarations are registered.
69 // Note that enum values are in the top-level since that are in the same
70 // scope as the parent enum.
71 descsByName map[protoreflect.FullName]interface{}
72 filesByPath map[string]protoreflect.FileDescriptor
73}
74
75type packageDescriptor struct {
76 files []protoreflect.FileDescriptor
77}
78
79// RegisterFile registers the provided file descriptor.
80//
81// If any descriptor within the file conflicts with the descriptor of any
82// previously registered file (e.g., two enums with the same full name),
83// then the file is not registered and an error is returned.
84//
85// It is permitted for multiple files to have the same file path.
86func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
87 if r == GlobalFiles {
88 globalMutex.Lock()
89 defer globalMutex.Unlock()
90 }
91 if r.descsByName == nil {
92 r.descsByName = map[protoreflect.FullName]interface{}{
93 "": &packageDescriptor{},
94 }
95 r.filesByPath = make(map[string]protoreflect.FileDescriptor)
96 }
97 path := file.Path()
98 if prev := r.filesByPath[path]; prev != nil {
99 err := errors.New("file %q is already registered", file.Path())
100 err = amendErrorWithCaller(err, prev, file)
101 if r == GlobalFiles && ignoreConflict(file, err) {
102 err = nil
103 }
104 return err
105 }
106
107 for name := file.Package(); name != ""; name = name.Parent() {
108 switch prev := r.descsByName[name]; prev.(type) {
109 case nil, *packageDescriptor:
110 default:
111 err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
112 err = amendErrorWithCaller(err, prev, file)
113 if r == GlobalFiles && ignoreConflict(file, err) {
114 err = nil
115 }
116 return err
117 }
118 }
119 var err error
120 var hasConflict bool
121 rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
122 if prev := r.descsByName[d.FullName()]; prev != nil {
123 hasConflict = true
124 err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
125 err = amendErrorWithCaller(err, prev, file)
126 if r == GlobalFiles && ignoreConflict(d, err) {
127 err = nil
128 }
129 }
130 })
131 if hasConflict {
132 return err
133 }
134
135 for name := file.Package(); name != ""; name = name.Parent() {
136 if r.descsByName[name] == nil {
137 r.descsByName[name] = &packageDescriptor{}
138 }
139 }
140 p := r.descsByName[file.Package()].(*packageDescriptor)
141 p.files = append(p.files, file)
142 rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
143 r.descsByName[d.FullName()] = d
144 })
145 r.filesByPath[path] = file
146 return nil
147}
148
149// FindDescriptorByName looks up a descriptor by the full name.
150//
151// This returns (nil, NotFound) if not found.
152func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
153 if r == nil {
154 return nil, NotFound
155 }
156 if r == GlobalFiles {
157 globalMutex.RLock()
158 defer globalMutex.RUnlock()
159 }
160 prefix := name
161 suffix := nameSuffix("")
162 for prefix != "" {
163 if d, ok := r.descsByName[prefix]; ok {
164 switch d := d.(type) {
165 case protoreflect.EnumDescriptor:
166 if d.FullName() == name {
167 return d, nil
168 }
169 case protoreflect.EnumValueDescriptor:
170 if d.FullName() == name {
171 return d, nil
172 }
173 case protoreflect.MessageDescriptor:
174 if d.FullName() == name {
175 return d, nil
176 }
177 if d := findDescriptorInMessage(d, suffix); d != nil && d.FullName() == name {
178 return d, nil
179 }
180 case protoreflect.ExtensionDescriptor:
181 if d.FullName() == name {
182 return d, nil
183 }
184 case protoreflect.ServiceDescriptor:
185 if d.FullName() == name {
186 return d, nil
187 }
188 if d := d.Methods().ByName(suffix.Pop()); d != nil && d.FullName() == name {
189 return d, nil
190 }
191 }
192 return nil, NotFound
193 }
194 prefix = prefix.Parent()
195 suffix = nameSuffix(name[len(prefix)+len("."):])
196 }
197 return nil, NotFound
198}
199
200func findDescriptorInMessage(md protoreflect.MessageDescriptor, suffix nameSuffix) protoreflect.Descriptor {
201 name := suffix.Pop()
202 if suffix == "" {
203 if ed := md.Enums().ByName(name); ed != nil {
204 return ed
205 }
206 for i := md.Enums().Len() - 1; i >= 0; i-- {
207 if vd := md.Enums().Get(i).Values().ByName(name); vd != nil {
208 return vd
209 }
210 }
211 if xd := md.Extensions().ByName(name); xd != nil {
212 return xd
213 }
214 if fd := md.Fields().ByName(name); fd != nil {
215 return fd
216 }
217 if od := md.Oneofs().ByName(name); od != nil {
218 return od
219 }
220 }
221 if md := md.Messages().ByName(name); md != nil {
222 if suffix == "" {
223 return md
224 }
225 return findDescriptorInMessage(md, suffix)
226 }
227 return nil
228}
229
230type nameSuffix string
231
232func (s *nameSuffix) Pop() (name protoreflect.Name) {
233 if i := strings.IndexByte(string(*s), '.'); i >= 0 {
234 name, *s = protoreflect.Name((*s)[:i]), (*s)[i+1:]
235 } else {
236 name, *s = protoreflect.Name((*s)), ""
237 }
238 return name
239}
240
241// FindFileByPath looks up a file by the path.
242//
243// This returns (nil, NotFound) if not found.
244func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
245 if r == nil {
246 return nil, NotFound
247 }
248 if r == GlobalFiles {
249 globalMutex.RLock()
250 defer globalMutex.RUnlock()
251 }
252 if fd, ok := r.filesByPath[path]; ok {
253 return fd, nil
254 }
255 return nil, NotFound
256}
257
258// NumFiles reports the number of registered files.
259func (r *Files) NumFiles() int {
260 if r == nil {
261 return 0
262 }
263 if r == GlobalFiles {
264 globalMutex.RLock()
265 defer globalMutex.RUnlock()
266 }
267 return len(r.filesByPath)
268}
269
270// RangeFiles iterates over all registered files while f returns true.
271// The iteration order is undefined.
272func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
273 if r == nil {
274 return
275 }
276 if r == GlobalFiles {
277 globalMutex.RLock()
278 defer globalMutex.RUnlock()
279 }
280 for _, file := range r.filesByPath {
281 if !f(file) {
282 return
283 }
284 }
285}
286
287// NumFilesByPackage reports the number of registered files in a proto package.
288func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
289 if r == nil {
290 return 0
291 }
292 if r == GlobalFiles {
293 globalMutex.RLock()
294 defer globalMutex.RUnlock()
295 }
296 p, ok := r.descsByName[name].(*packageDescriptor)
297 if !ok {
298 return 0
299 }
300 return len(p.files)
301}
302
303// RangeFilesByPackage iterates over all registered files in a given proto package
304// while f returns true. The iteration order is undefined.
305func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
306 if r == nil {
307 return
308 }
309 if r == GlobalFiles {
310 globalMutex.RLock()
311 defer globalMutex.RUnlock()
312 }
313 p, ok := r.descsByName[name].(*packageDescriptor)
314 if !ok {
315 return
316 }
317 for _, file := range p.files {
318 if !f(file) {
319 return
320 }
321 }
322}
323
324// rangeTopLevelDescriptors iterates over all top-level descriptors in a file
325// which will be directly entered into the registry.
326func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflect.Descriptor)) {
327 eds := fd.Enums()
328 for i := eds.Len() - 1; i >= 0; i-- {
329 f(eds.Get(i))
330 vds := eds.Get(i).Values()
331 for i := vds.Len() - 1; i >= 0; i-- {
332 f(vds.Get(i))
333 }
334 }
335 mds := fd.Messages()
336 for i := mds.Len() - 1; i >= 0; i-- {
337 f(mds.Get(i))
338 }
339 xds := fd.Extensions()
340 for i := xds.Len() - 1; i >= 0; i-- {
341 f(xds.Get(i))
342 }
343 sds := fd.Services()
344 for i := sds.Len() - 1; i >= 0; i-- {
345 f(sds.Get(i))
346 }
347}
348
349// MessageTypeResolver is an interface for looking up messages.
350//
351// A compliant implementation must deterministically return the same type
352// if no error is encountered.
353//
354// The Types type implements this interface.
355type MessageTypeResolver interface {
356 // FindMessageByName looks up a message by its full name.
357 // E.g., "google.protobuf.Any"
358 //
359 // This return (nil, NotFound) if not found.
360 FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error)
361
362 // FindMessageByURL looks up a message by a URL identifier.
363 // See documentation on google.protobuf.Any.type_url for the URL format.
364 //
365 // This returns (nil, NotFound) if not found.
366 FindMessageByURL(url string) (protoreflect.MessageType, error)
367}
368
369// ExtensionTypeResolver is an interface for looking up extensions.
370//
371// A compliant implementation must deterministically return the same type
372// if no error is encountered.
373//
374// The Types type implements this interface.
375type ExtensionTypeResolver interface {
376 // FindExtensionByName looks up a extension field by the field's full name.
377 // Note that this is the full name of the field as determined by
378 // where the extension is declared and is unrelated to the full name of the
379 // message being extended.
380 //
381 // This returns (nil, NotFound) if not found.
382 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
383
384 // FindExtensionByNumber looks up a extension field by the field number
385 // within some parent message, identified by full name.
386 //
387 // This returns (nil, NotFound) if not found.
388 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
389}
390
391var (
392 _ MessageTypeResolver = (*Types)(nil)
393 _ ExtensionTypeResolver = (*Types)(nil)
394)
395
396// Types is a registry for looking up or iterating over descriptor types.
397// The Find and Range methods are safe for concurrent use.
398type Types struct {
399 typesByName typesByName
400 extensionsByMessage extensionsByMessage
401
402 numEnums int
403 numMessages int
404 numExtensions int
405}
406
407type (
408 typesByName map[protoreflect.FullName]interface{}
409 extensionsByMessage map[protoreflect.FullName]extensionsByNumber
410 extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType
411)
412
413// RegisterMessage registers the provided message type.
414//
415// If a naming conflict occurs, the type is not registered and an error is returned.
416func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
417 // Under rare circumstances getting the descriptor might recursively
418 // examine the registry, so fetch it before locking.
419 md := mt.Descriptor()
420
421 if r == GlobalTypes {
422 globalMutex.Lock()
423 defer globalMutex.Unlock()
424 }
425
426 if err := r.register("message", md, mt); err != nil {
427 return err
428 }
429 r.numMessages++
430 return nil
431}
432
433// RegisterEnum registers the provided enum type.
434//
435// If a naming conflict occurs, the type is not registered and an error is returned.
436func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
437 // Under rare circumstances getting the descriptor might recursively
438 // examine the registry, so fetch it before locking.
439 ed := et.Descriptor()
440
441 if r == GlobalTypes {
442 globalMutex.Lock()
443 defer globalMutex.Unlock()
444 }
445
446 if err := r.register("enum", ed, et); err != nil {
447 return err
448 }
449 r.numEnums++
450 return nil
451}
452
453// RegisterExtension registers the provided extension type.
454//
455// If a naming conflict occurs, the type is not registered and an error is returned.
456func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
457 // Under rare circumstances getting the descriptor might recursively
458 // examine the registry, so fetch it before locking.
459 //
460 // A known case where this can happen: Fetching the TypeDescriptor for a
461 // legacy ExtensionDesc can consult the global registry.
462 xd := xt.TypeDescriptor()
463
464 if r == GlobalTypes {
465 globalMutex.Lock()
466 defer globalMutex.Unlock()
467 }
468
469 field := xd.Number()
470 message := xd.ContainingMessage().FullName()
471 if prev := r.extensionsByMessage[message][field]; prev != nil {
472 err := errors.New("extension number %d is already registered on message %v", field, message)
473 err = amendErrorWithCaller(err, prev, xt)
474 if !(r == GlobalTypes && ignoreConflict(xd, err)) {
475 return err
476 }
477 }
478
479 if err := r.register("extension", xd, xt); err != nil {
480 return err
481 }
482 if r.extensionsByMessage == nil {
483 r.extensionsByMessage = make(extensionsByMessage)
484 }
485 if r.extensionsByMessage[message] == nil {
486 r.extensionsByMessage[message] = make(extensionsByNumber)
487 }
488 r.extensionsByMessage[message][field] = xt
489 r.numExtensions++
490 return nil
491}
492
493func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
494 name := desc.FullName()
495 prev := r.typesByName[name]
496 if prev != nil {
497 err := errors.New("%v %v is already registered", kind, name)
498 err = amendErrorWithCaller(err, prev, typ)
499 if !(r == GlobalTypes && ignoreConflict(desc, err)) {
500 return err
501 }
502 }
503 if r.typesByName == nil {
504 r.typesByName = make(typesByName)
505 }
506 r.typesByName[name] = typ
507 return nil
508}
509
510// FindEnumByName looks up an enum by its full name.
511// E.g., "google.protobuf.Field.Kind".
512//
513// This returns (nil, NotFound) if not found.
514func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
515 if r == nil {
516 return nil, NotFound
517 }
518 if r == GlobalTypes {
519 globalMutex.RLock()
520 defer globalMutex.RUnlock()
521 }
522 if v := r.typesByName[enum]; v != nil {
523 if et, _ := v.(protoreflect.EnumType); et != nil {
524 return et, nil
525 }
526 return nil, errors.New("found wrong type: got %v, want enum", typeName(v))
527 }
528 return nil, NotFound
529}
530
531// FindMessageByName looks up a message by its full name.
532// E.g., "google.protobuf.Any"
533//
534// This return (nil, NotFound) if not found.
535func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
536 // The full name by itself is a valid URL.
537 return r.FindMessageByURL(string(message))
538}
539
540// FindMessageByURL looks up a message by a URL identifier.
541// See documentation on google.protobuf.Any.type_url for the URL format.
542//
543// This returns (nil, NotFound) if not found.
544func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
545 if r == nil {
546 return nil, NotFound
547 }
548 if r == GlobalTypes {
549 globalMutex.RLock()
550 defer globalMutex.RUnlock()
551 }
552 message := protoreflect.FullName(url)
553 if i := strings.LastIndexByte(url, '/'); i >= 0 {
554 message = message[i+len("/"):]
555 }
556
557 if v := r.typesByName[message]; v != nil {
558 if mt, _ := v.(protoreflect.MessageType); mt != nil {
559 return mt, nil
560 }
561 return nil, errors.New("found wrong type: got %v, want message", typeName(v))
562 }
563 return nil, NotFound
564}
565
566// FindExtensionByName looks up a extension field by the field's full name.
567// Note that this is the full name of the field as determined by
568// where the extension is declared and is unrelated to the full name of the
569// message being extended.
570//
571// This returns (nil, NotFound) if not found.
572func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
573 if r == nil {
574 return nil, NotFound
575 }
576 if r == GlobalTypes {
577 globalMutex.RLock()
578 defer globalMutex.RUnlock()
579 }
580 if v := r.typesByName[field]; v != nil {
581 if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
582 return xt, nil
583 }
584 return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
585 }
586 return nil, NotFound
587}
588
589// FindExtensionByNumber looks up a extension field by the field number
590// within some parent message, identified by full name.
591//
592// This returns (nil, NotFound) if not found.
593func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
594 if r == nil {
595 return nil, NotFound
596 }
597 if r == GlobalTypes {
598 globalMutex.RLock()
599 defer globalMutex.RUnlock()
600 }
601 if xt, ok := r.extensionsByMessage[message][field]; ok {
602 return xt, nil
603 }
604 return nil, NotFound
605}
606
607// NumEnums reports the number of registered enums.
608func (r *Types) NumEnums() int {
609 if r == nil {
610 return 0
611 }
612 if r == GlobalTypes {
613 globalMutex.RLock()
614 defer globalMutex.RUnlock()
615 }
616 return r.numEnums
617}
618
619// RangeEnums iterates over all registered enums while f returns true.
620// Iteration order is undefined.
621func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
622 if r == nil {
623 return
624 }
625 if r == GlobalTypes {
626 globalMutex.RLock()
627 defer globalMutex.RUnlock()
628 }
629 for _, typ := range r.typesByName {
630 if et, ok := typ.(protoreflect.EnumType); ok {
631 if !f(et) {
632 return
633 }
634 }
635 }
636}
637
638// NumMessages reports the number of registered messages.
639func (r *Types) NumMessages() int {
640 if r == nil {
641 return 0
642 }
643 if r == GlobalTypes {
644 globalMutex.RLock()
645 defer globalMutex.RUnlock()
646 }
647 return r.numMessages
648}
649
650// RangeMessages iterates over all registered messages while f returns true.
651// Iteration order is undefined.
652func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
653 if r == nil {
654 return
655 }
656 if r == GlobalTypes {
657 globalMutex.RLock()
658 defer globalMutex.RUnlock()
659 }
660 for _, typ := range r.typesByName {
661 if mt, ok := typ.(protoreflect.MessageType); ok {
662 if !f(mt) {
663 return
664 }
665 }
666 }
667}
668
669// NumExtensions reports the number of registered extensions.
670func (r *Types) NumExtensions() int {
671 if r == nil {
672 return 0
673 }
674 if r == GlobalTypes {
675 globalMutex.RLock()
676 defer globalMutex.RUnlock()
677 }
678 return r.numExtensions
679}
680
681// RangeExtensions iterates over all registered extensions while f returns true.
682// Iteration order is undefined.
683func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
684 if r == nil {
685 return
686 }
687 if r == GlobalTypes {
688 globalMutex.RLock()
689 defer globalMutex.RUnlock()
690 }
691 for _, typ := range r.typesByName {
692 if xt, ok := typ.(protoreflect.ExtensionType); ok {
693 if !f(xt) {
694 return
695 }
696 }
697 }
698}
699
700// NumExtensionsByMessage reports the number of registered extensions for
701// a given message type.
702func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
703 if r == nil {
704 return 0
705 }
706 if r == GlobalTypes {
707 globalMutex.RLock()
708 defer globalMutex.RUnlock()
709 }
710 return len(r.extensionsByMessage[message])
711}
712
713// RangeExtensionsByMessage iterates over all registered extensions filtered
714// by a given message type while f returns true. Iteration order is undefined.
715func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
716 if r == nil {
717 return
718 }
719 if r == GlobalTypes {
720 globalMutex.RLock()
721 defer globalMutex.RUnlock()
722 }
723 for _, xt := range r.extensionsByMessage[message] {
724 if !f(xt) {
725 return
726 }
727 }
728}
729
730func typeName(t interface{}) string {
731 switch t.(type) {
732 case protoreflect.EnumType:
733 return "enum"
734 case protoreflect.MessageType:
735 return "message"
736 case protoreflect.ExtensionType:
737 return "extension"
738 default:
739 return fmt.Sprintf("%T", t)
740 }
741}
742
743func amendErrorWithCaller(err error, prev, curr interface{}) error {
744 prevPkg := goPackage(prev)
745 currPkg := goPackage(curr)
746 if prevPkg == "" || currPkg == "" || prevPkg == currPkg {
747 return err
748 }
749 return errors.New("%s\n\tpreviously from: %q\n\tcurrently from: %q", err, prevPkg, currPkg)
750}
751
752func goPackage(v interface{}) string {
753 switch d := v.(type) {
754 case protoreflect.EnumType:
755 v = d.Descriptor()
756 case protoreflect.MessageType:
757 v = d.Descriptor()
758 case protoreflect.ExtensionType:
759 v = d.TypeDescriptor()
760 }
761 if d, ok := v.(protoreflect.Descriptor); ok {
762 v = d.ParentFile()
763 }
764 if d, ok := v.(interface{ GoPackagePath() string }); ok {
765 return d.GoPackagePath()
766 }
767 return ""
768}