blob: 68768278f202d95768414a95aea6b4be47ad1093 [file] [log] [blame]
khenaidooefff76e2021-12-15 16:51:30 -05001package dynamic
2
3import (
4 "fmt"
5 "reflect"
6 "sync"
7
8 "github.com/golang/protobuf/proto"
9
10 "github.com/jhump/protoreflect/desc"
11)
12
13// ExtensionRegistry is a registry of known extension fields. This is used to parse
14// extension fields encountered when de-serializing a dynamic message.
15type ExtensionRegistry struct {
16 includeDefault bool
17 mu sync.RWMutex
18 exts map[string]map[int32]*desc.FieldDescriptor
19}
20
21// NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions,
22// which are those that are statically linked into the current program (e.g. registered by
23// protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the
24// registry will override any default extensions that are for the same extendee and have the
25// same tag number and/or name.
26func NewExtensionRegistryWithDefaults() *ExtensionRegistry {
27 return &ExtensionRegistry{includeDefault: true}
28}
29
30// AddExtensionDesc adds the given extensions to the registry.
31func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error {
32 flds := make([]*desc.FieldDescriptor, len(exts))
33 for i, ext := range exts {
34 fd, err := desc.LoadFieldDescriptorForExtension(ext)
35 if err != nil {
36 return err
37 }
38 flds[i] = fd
39 }
40 r.mu.Lock()
41 defer r.mu.Unlock()
42 if r.exts == nil {
43 r.exts = map[string]map[int32]*desc.FieldDescriptor{}
44 }
45 for _, fd := range flds {
46 r.putExtensionLocked(fd)
47 }
48 return nil
49}
50
51// AddExtension adds the given extensions to the registry. The given extensions
52// will overwrite any previously added extensions that are for the same extendee
53// message and same extension tag number.
54func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error {
55 for _, ext := range exts {
56 if !ext.IsExtension() {
57 return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName())
58 }
59 }
60 r.mu.Lock()
61 defer r.mu.Unlock()
62 if r.exts == nil {
63 r.exts = map[string]map[int32]*desc.FieldDescriptor{}
64 }
65 for _, ext := range exts {
66 r.putExtensionLocked(ext)
67 }
68 return nil
69}
70
71// AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor.
72func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) {
73 r.mu.Lock()
74 defer r.mu.Unlock()
75 r.addExtensionsFromFileLocked(fd, false, nil)
76}
77
78// AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file
79// descriptor and also recursively adds all extensions defined in that file's dependencies. This adds
80// extensions from the entire transitive closure for the given file.
81func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) {
82 r.mu.Lock()
83 defer r.mu.Unlock()
84 already := map[*desc.FileDescriptor]struct{}{}
85 r.addExtensionsFromFileLocked(fd, true, already)
86}
87
88func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) {
89 if _, ok := alreadySeen[fd]; ok {
90 return
91 }
92
93 if r.exts == nil {
94 r.exts = map[string]map[int32]*desc.FieldDescriptor{}
95 }
96 for _, ext := range fd.GetExtensions() {
97 r.putExtensionLocked(ext)
98 }
99 for _, msg := range fd.GetMessageTypes() {
100 r.addExtensionsFromMessageLocked(msg)
101 }
102
103 if recursive {
104 alreadySeen[fd] = struct{}{}
105 for _, dep := range fd.GetDependencies() {
106 r.addExtensionsFromFileLocked(dep, recursive, alreadySeen)
107 }
108 }
109}
110
111func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) {
112 for _, ext := range md.GetNestedExtensions() {
113 r.putExtensionLocked(ext)
114 }
115 for _, msg := range md.GetNestedMessageTypes() {
116 r.addExtensionsFromMessageLocked(msg)
117 }
118}
119
120func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) {
121 msgName := fd.GetOwner().GetFullyQualifiedName()
122 m := r.exts[msgName]
123 if m == nil {
124 m = map[int32]*desc.FieldDescriptor{}
125 r.exts[msgName] = m
126 }
127 m[fd.GetNumber()] = fd
128}
129
130// FindExtension queries for the extension field with the given extendee name (must be a fully-qualified
131// message name) and tag number. If no extension is known, nil is returned.
132func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor {
133 if r == nil {
134 return nil
135 }
136 r.mu.RLock()
137 defer r.mu.RUnlock()
138 fd := r.exts[messageName][tagNumber]
139 if fd == nil && r.includeDefault {
140 ext := getDefaultExtensions(messageName)[tagNumber]
141 if ext != nil {
142 fd, _ = desc.LoadFieldDescriptorForExtension(ext)
143 }
144 }
145 return fd
146}
147
148// FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified
149// message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil
150// is returned.
151func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor {
152 if r == nil {
153 return nil
154 }
155 r.mu.RLock()
156 defer r.mu.RUnlock()
157 for _, fd := range r.exts[messageName] {
158 if fd.GetFullyQualifiedName() == fieldName {
159 return fd
160 }
161 }
162 if r.includeDefault {
163 for _, ext := range getDefaultExtensions(messageName) {
164 fd, _ := desc.LoadFieldDescriptorForExtension(ext)
165 if fd.GetFullyQualifiedName() == fieldName {
166 return fd
167 }
168 }
169 }
170 return nil
171}
172
173// FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified
174// message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned.
175// The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last
176// component uses the field's JSON name (if present).
177func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor {
178 if r == nil {
179 return nil
180 }
181 r.mu.RLock()
182 defer r.mu.RUnlock()
183 for _, fd := range r.exts[messageName] {
184 if fd.GetFullyQualifiedJSONName() == fieldName {
185 return fd
186 }
187 }
188 if r.includeDefault {
189 for _, ext := range getDefaultExtensions(messageName) {
190 fd, _ := desc.LoadFieldDescriptorForExtension(ext)
191 if fd.GetFullyQualifiedJSONName() == fieldName {
192 return fd
193 }
194 }
195 }
196 return nil
197}
198
199func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc {
200 t := proto.MessageType(messageName)
201 if t != nil {
202 msg := reflect.Zero(t).Interface().(proto.Message)
203 return proto.RegisteredExtensions(msg)
204 }
205 return nil
206}
207
208// AllExtensionsForType returns all known extension fields for the given extendee name (must be a
209// fully-qualified message name).
210func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor {
211 if r == nil {
212 return []*desc.FieldDescriptor(nil)
213 }
214 r.mu.RLock()
215 defer r.mu.RUnlock()
216 flds := r.exts[messageName]
217 var ret []*desc.FieldDescriptor
218 if r.includeDefault {
219 exts := getDefaultExtensions(messageName)
220 if len(exts) > 0 || len(flds) > 0 {
221 ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds))
222 }
223 for tag, ext := range exts {
224 if _, ok := flds[tag]; ok {
225 // skip default extension and use the one explicitly registered instead
226 continue
227 }
228 fd, _ := desc.LoadFieldDescriptorForExtension(ext)
229 if fd != nil {
230 ret = append(ret, fd)
231 }
232 }
233 } else if len(flds) > 0 {
234 ret = make([]*desc.FieldDescriptor, 0, len(flds))
235 }
236
237 for _, ext := range flds {
238 ret = append(ret, ext)
239 }
240 return ret
241}