blob: 68768278f202d95768414a95aea6b4be47ad1093 [file] [log] [blame]
package dynamic
import (
"fmt"
"reflect"
"sync"
"github.com/golang/protobuf/proto"
"github.com/jhump/protoreflect/desc"
)
// ExtensionRegistry is a registry of known extension fields. This is used to parse
// extension fields encountered when de-serializing a dynamic message.
//
// The zero value is an empty registry that is ready to use: the internal map is
// allocated lazily by the methods that add extensions. The lookup methods also
// accept a nil receiver, in which case they report that nothing is known.
type ExtensionRegistry struct {
// includeDefault, when true, makes lookups fall back to the extensions reported
// by proto.RegisteredExtensions for anything not explicitly added here.
includeDefault bool
// mu guards exts.
mu sync.RWMutex
// exts holds the explicitly added extensions, keyed first by the fully-qualified
// name of the extended message and then by extension tag number.
exts map[string]map[int32]*desc.FieldDescriptor
}
// NewExtensionRegistryWithDefaults returns a new registry that includes all "default"
// extensions: those statically linked into the current program (e.g. registered by
// protoc-generated code via proto.RegisterExtension). Extensions explicitly added to
// the returned registry take precedence over any default extension for the same
// extendee that has the same tag number and/or name.
func NewExtensionRegistryWithDefaults() *ExtensionRegistry {
	reg := new(ExtensionRegistry)
	reg.includeDefault = true
	return reg
}
// AddExtensionDesc adds the given extensions to the registry. Each extension is first
// converted to a field descriptor; if any conversion fails, that error is returned and
// the registry is left unmodified. An added extension replaces any existing entry for
// the same extendee message and tag number.
func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error {
	// Resolve every descriptor before taking the lock so that a failure
	// cannot leave the registry partially updated.
	resolved := make([]*desc.FieldDescriptor, 0, len(exts))
	for _, ed := range exts {
		fd, err := desc.LoadFieldDescriptorForExtension(ed)
		if err != nil {
			return err
		}
		resolved = append(resolved, fd)
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	if r.exts == nil {
		r.exts = make(map[string]map[int32]*desc.FieldDescriptor)
	}
	for i := range resolved {
		r.putExtensionLocked(resolved[i])
	}
	return nil
}
// AddExtension adds the given extensions to the registry. The given extensions
// will overwrite any previously added extensions that are for the same extendee
// message and same extension tag number. If any of the given fields is not an
// extension, an error is returned and the registry is left unmodified.
func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error {
	// Validate everything up front so the update below is all-or-nothing.
	for i := range exts {
		if fd := exts[i]; !fd.IsExtension() {
			return fmt.Errorf("given field is not an extension: %s", fd.GetFullyQualifiedName())
		}
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	if r.exts == nil {
		r.exts = make(map[string]map[int32]*desc.FieldDescriptor)
	}
	for i := range exts {
		r.putExtensionLocked(exts[i])
	}
	return nil
}
// AddExtensionsFromFile adds to the registry all extension fields defined in the given
// file descriptor, including those nested inside its messages. The file's dependencies
// are not examined.
func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) {
	// A single file needs no visited-set, so none is passed.
	const recursive = false

	r.mu.Lock()
	defer r.mu.Unlock()
	r.addExtensionsFromFileLocked(fd, recursive, nil)
}
// AddExtensionsFromFileRecursively adds to the registry all extension fields defined in
// the given file descriptor and also recursively adds all extensions defined in that
// file's dependencies. This adds extensions from the entire transitive closure for the
// given file.
func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) {
	// Tracks files already processed so each one is handled at most once.
	visited := make(map[*desc.FileDescriptor]struct{})

	r.mu.Lock()
	defer r.mu.Unlock()
	r.addExtensionsFromFileLocked(fd, true, visited)
}
// addExtensionsFromFileLocked registers every extension declared in fd, both at file
// scope and nested inside any of its messages. When recursive is true it then does the
// same for each of fd's dependencies, using alreadySeen to process each file at most
// once; alreadySeen must be non-nil in that case because it is written to. The caller
// must hold r.mu for writing.
func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) {
	// Reading a nil map is fine, so this also works for the non-recursive case.
	if _, visited := alreadySeen[fd]; visited {
		return
	}
	if r.exts == nil {
		r.exts = make(map[string]map[int32]*desc.FieldDescriptor)
	}

	for _, topLevel := range fd.GetExtensions() {
		r.putExtensionLocked(topLevel)
	}
	for _, md := range fd.GetMessageTypes() {
		r.addExtensionsFromMessageLocked(md)
	}

	if !recursive {
		return
	}
	alreadySeen[fd] = struct{}{}
	deps := fd.GetDependencies()
	for i := range deps {
		r.addExtensionsFromFileLocked(deps[i], true, alreadySeen)
	}
}
// addExtensionsFromMessageLocked registers the extensions declared inside md and inside
// every message nested within it, at any depth. The caller must hold r.mu for writing
// and r.exts must already be allocated.
func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) {
	// Explicit stack instead of recursion. Children are pushed in reverse so
	// that messages are visited in the same pre-order as a recursive walk.
	pending := []*desc.MessageDescriptor{md}
	for len(pending) > 0 {
		last := len(pending) - 1
		cur := pending[last]
		pending = pending[:last]

		for _, ext := range cur.GetNestedExtensions() {
			r.putExtensionLocked(ext)
		}
		nested := cur.GetNestedMessageTypes()
		for i := len(nested) - 1; i >= 0; i-- {
			pending = append(pending, nested[i])
		}
	}
}
// putExtensionLocked stores fd under the fully-qualified name of its owner message and
// its tag number, replacing any entry already stored for that pair. The caller must
// hold r.mu for writing and r.exts must already be allocated.
func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) {
	extendee := fd.GetOwner().GetFullyQualifiedName()
	byTag := r.exts[extendee]
	if byTag == nil {
		// First extension seen for this extendee.
		byTag = make(map[int32]*desc.FieldDescriptor)
		r.exts[extendee] = byTag
	}
	byTag[fd.GetNumber()] = fd
}
// FindExtension queries for the extension field with the given extendee name (must be a
// fully-qualified message name) and tag number. Explicitly added extensions are checked
// first; if none matches and the registry includes defaults, the default extensions for
// the message are consulted. If no extension is known, or r is nil, nil is returned.
func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	defer r.mu.RUnlock()

	// Indexing a missing (nil) inner map simply yields nil.
	if explicit := r.exts[messageName][tagNumber]; explicit != nil {
		return explicit
	}
	if !r.includeDefault {
		return nil
	}
	ext := getDefaultExtensions(messageName)[tagNumber]
	if ext == nil {
		return nil
	}
	// Best effort: the error is deliberately dropped and whatever descriptor
	// was produced (possibly nil) is handed back to the caller.
	fd, _ := desc.LoadFieldDescriptorForExtension(ext)
	return fd
}
// FindExtensionByName queries for the extension field with the given extendee name (must
// be a fully-qualified message name) and field name (must also be a fully-qualified
// extension name). Explicitly added extensions are checked first; if none matches and
// the registry includes defaults, the default extensions for the message are consulted.
// If no extension is known, or r is nil, nil is returned.
func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	defer r.mu.RUnlock()
	for _, fd := range r.exts[messageName] {
		if fd.GetFullyQualifiedName() == fieldName {
			return fd
		}
	}
	if r.includeDefault {
		for _, ext := range getDefaultExtensions(messageName) {
			// Best effort: a default extension whose descriptor cannot be loaded
			// is skipped rather than dereferenced, since fd may be nil when the
			// load fails.
			fd, _ := desc.LoadFieldDescriptorForExtension(ext)
			if fd == nil {
				continue
			}
			if fd.GetFullyQualifiedName() == fieldName {
				return fd
			}
		}
	}
	return nil
}
// FindExtensionByJSONName queries for the extension field with the given extendee name
// (must be a fully-qualified message name) and JSON field name (must also be a
// fully-qualified name). The fully-qualified JSON name is the same as the extension's
// normal fully-qualified name except that the last component uses the field's JSON name
// (if present). Explicitly added extensions are checked first; if none matches and the
// registry includes defaults, the default extensions for the message are consulted.
// If no extension is known, or r is nil, nil is returned.
func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	defer r.mu.RUnlock()
	for _, fd := range r.exts[messageName] {
		if fd.GetFullyQualifiedJSONName() == fieldName {
			return fd
		}
	}
	if r.includeDefault {
		for _, ext := range getDefaultExtensions(messageName) {
			// Best effort: a default extension whose descriptor cannot be loaded
			// is skipped rather than dereferenced, since fd may be nil when the
			// load fails.
			fd, _ := desc.LoadFieldDescriptorForExtension(ext)
			if fd == nil {
				continue
			}
			if fd.GetFullyQualifiedJSONName() == fieldName {
				return fd
			}
		}
	}
	return nil
}
// getDefaultExtensions returns the extensions that proto.RegisteredExtensions reports
// for the message type registered under messageName, keyed by tag number. It returns
// nil when proto.MessageType yields no type for that name.
func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc {
	msgType := proto.MessageType(messageName)
	if msgType == nil {
		return nil
	}
	// Only a value of the right type is needed for the lookup, so the type's
	// zero value is used.
	zero := reflect.Zero(msgType).Interface().(proto.Message)
	return proto.RegisteredExtensions(zero)
}
// AllExtensionsForType returns all known extension fields for the given extendee name
// (must be a fully-qualified message name). When the registry includes defaults, default
// extensions are returned as well, except those whose tag number is also explicitly
// registered; default extensions whose descriptor cannot be loaded are omitted. The
// order of the result is unspecified. A nil receiver yields a nil slice.
func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor {
	if r == nil {
		return nil
	}
	r.mu.RLock()
	defer r.mu.RUnlock()

	explicit := r.exts[messageName]
	var defaults map[int32]*proto.ExtensionDesc
	if r.includeDefault {
		defaults = getDefaultExtensions(messageName)
	}

	// Only allocate when there is something that could be returned.
	var result []*desc.FieldDescriptor
	if total := len(defaults) + len(explicit); total > 0 {
		result = make([]*desc.FieldDescriptor, 0, total)
	}

	for tag, ext := range defaults {
		if _, overridden := explicit[tag]; overridden {
			// The explicitly registered extension wins over the default one.
			continue
		}
		if fd, _ := desc.LoadFieldDescriptorForExtension(ext); fd != nil {
			result = append(result, fd)
		}
	}
	for _, fd := range explicit {
		result = append(result, fd)
	}
	return result
}