blob: 4a05830d1f613cec09b3bcf851697feb5bfd4337 [file] [log] [blame]
khenaidoof3333552021-12-15 16:52:31 -05001package desc
2
3import (
4 "fmt"
5 "reflect"
6 "sync"
7
8 "github.com/golang/protobuf/proto"
9 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
10
11 "github.com/jhump/protoreflect/internal"
12)
13
14var (
15 cacheMu sync.RWMutex
16 filesCache = map[string]*FileDescriptor{}
17 messagesCache = map[string]*MessageDescriptor{}
18 enumCache = map[reflect.Type]*EnumDescriptor{}
19)
20
21// LoadFileDescriptor creates a file descriptor using the bytes returned by
22// proto.FileDescriptor. Descriptors are cached so that they do not need to be
23// re-processed if the same file is fetched again later.
24func LoadFileDescriptor(file string) (*FileDescriptor, error) {
25 return loadFileDescriptor(file, nil)
26}
27
28func loadFileDescriptor(file string, r *ImportResolver) (*FileDescriptor, error) {
29 f := getFileFromCache(file)
30 if f != nil {
31 return f, nil
32 }
33 cacheMu.Lock()
34 defer cacheMu.Unlock()
35 return loadFileDescriptorLocked(file, r)
36}
37
38func loadFileDescriptorLocked(file string, r *ImportResolver) (*FileDescriptor, error) {
39 f := filesCache[file]
40 if f != nil {
41 return f, nil
42 }
43 fd, err := internal.LoadFileDescriptor(file)
44 if err != nil {
45 return nil, err
46 }
47
48 f, err = toFileDescriptorLocked(fd, r)
49 if err != nil {
50 return nil, err
51 }
52 putCacheLocked(file, f)
53 return f, nil
54}
55
56func toFileDescriptorLocked(fd *dpb.FileDescriptorProto, r *ImportResolver) (*FileDescriptor, error) {
57 deps := make([]*FileDescriptor, len(fd.GetDependency()))
58 for i, dep := range fd.GetDependency() {
59 resolvedDep := r.ResolveImport(fd.GetName(), dep)
60 var err error
61 deps[i], err = loadFileDescriptorLocked(resolvedDep, r)
62 if _, ok := err.(internal.ErrNoSuchFile); ok && resolvedDep != dep {
63 // try original path
64 deps[i], err = loadFileDescriptorLocked(dep, r)
65 }
66 if err != nil {
67 return nil, err
68 }
69 }
70 return CreateFileDescriptor(fd, deps...)
71}
72
73func getFileFromCache(file string) *FileDescriptor {
74 cacheMu.RLock()
75 defer cacheMu.RUnlock()
76 return filesCache[file]
77}
78
79func putCacheLocked(filename string, fd *FileDescriptor) {
80 filesCache[filename] = fd
81 putMessageCacheLocked(fd.messages)
82}
83
84func putMessageCacheLocked(mds []*MessageDescriptor) {
85 for _, md := range mds {
86 messagesCache[md.fqn] = md
87 putMessageCacheLocked(md.nested)
88 }
89}
90
91// interface implemented by generated messages, which all have a Descriptor() method in
92// addition to the methods of proto.Message
93type protoMessage interface {
94 proto.Message
95 Descriptor() ([]byte, []int)
96}
97
98// LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by
99// Message.Descriptor() for the given message type. If the given type is not recognized,
100// then a nil descriptor is returned.
101func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
102 return loadMessageDescriptor(message, nil)
103}
104
105func loadMessageDescriptor(message string, r *ImportResolver) (*MessageDescriptor, error) {
106 m := getMessageFromCache(message)
107 if m != nil {
108 return m, nil
109 }
110
111 pt := proto.MessageType(message)
112 if pt == nil {
113 return nil, nil
114 }
115 msg, err := messageFromType(pt)
116 if err != nil {
117 return nil, err
118 }
119
120 cacheMu.Lock()
121 defer cacheMu.Unlock()
122 return loadMessageDescriptorForTypeLocked(message, msg, r)
123}
124
125// LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned
126// by message.Descriptor() for the given message type. If the given type is not recognized,
127// then a nil descriptor is returned.
128func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
129 return loadMessageDescriptorForType(messageType, nil)
130}
131
132func loadMessageDescriptorForType(messageType reflect.Type, r *ImportResolver) (*MessageDescriptor, error) {
133 m, err := messageFromType(messageType)
134 if err != nil {
135 return nil, err
136 }
137 return loadMessageDescriptorForMessage(m, r)
138}
139
140// LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto
141// returned by message.Descriptor(). If the given type is not recognized, then a nil
142// descriptor is returned.
143func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
144 return loadMessageDescriptorForMessage(message, nil)
145}
146
147func loadMessageDescriptorForMessage(message proto.Message, r *ImportResolver) (*MessageDescriptor, error) {
148 // efficiently handle dynamic messages
149 type descriptorable interface {
150 GetMessageDescriptor() *MessageDescriptor
151 }
152 if d, ok := message.(descriptorable); ok {
153 return d.GetMessageDescriptor(), nil
154 }
155
156 name := proto.MessageName(message)
157 if name == "" {
158 return nil, nil
159 }
160 m := getMessageFromCache(name)
161 if m != nil {
162 return m, nil
163 }
164
165 cacheMu.Lock()
166 defer cacheMu.Unlock()
167 return loadMessageDescriptorForTypeLocked(name, message.(protoMessage), nil)
168}
169
170func messageFromType(mt reflect.Type) (protoMessage, error) {
171 if mt.Kind() != reflect.Ptr {
172 mt = reflect.PtrTo(mt)
173 }
174 m, ok := reflect.Zero(mt).Interface().(protoMessage)
175 if !ok {
176 return nil, fmt.Errorf("failed to create message from type: %v", mt)
177 }
178 return m, nil
179}
180
181func loadMessageDescriptorForTypeLocked(name string, message protoMessage, r *ImportResolver) (*MessageDescriptor, error) {
182 m := messagesCache[name]
183 if m != nil {
184 return m, nil
185 }
186
187 fdb, _ := message.Descriptor()
188 fd, err := internal.DecodeFileDescriptor(name, fdb)
189 if err != nil {
190 return nil, err
191 }
192
193 f, err := toFileDescriptorLocked(fd, r)
194 if err != nil {
195 return nil, err
196 }
197 putCacheLocked(fd.GetName(), f)
198 return f.FindSymbol(name).(*MessageDescriptor), nil
199}
200
201func getMessageFromCache(message string) *MessageDescriptor {
202 cacheMu.RLock()
203 defer cacheMu.RUnlock()
204 return messagesCache[message]
205}
206
// protoEnum is implemented by all generated enums. EnumDescriptor returns the
// encoded file descriptor that defines the enum, plus the index path used (by
// findEnum) to locate the enum within that file.
type protoEnum interface {
	EnumDescriptor() (encodedFile []byte, path []int)
}
211
// NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name,
// because it would not be useful: protoc-gen-go does not expose the name anywhere
// in generated code, nor does it register it in a way that makes it accessible to
// reflection code. This also means we have to cache enum descriptors differently:
// we can only cache them as they are requested, rather than caching all enum
// types whenever a file descriptor is cached. That is because we need to know the
// generated Go type of each enum, and we don't know it at the time of caching file
// descriptors.
219
220// LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned
221// by enum.EnumDescriptor() for the given enum type.
222func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
223 return loadEnumDescriptorForType(enumType, nil)
224}
225
226func loadEnumDescriptorForType(enumType reflect.Type, r *ImportResolver) (*EnumDescriptor, error) {
227 // we cache descriptors using non-pointer type
228 if enumType.Kind() == reflect.Ptr {
229 enumType = enumType.Elem()
230 }
231 e := getEnumFromCache(enumType)
232 if e != nil {
233 return e, nil
234 }
235 enum, err := enumFromType(enumType)
236 if err != nil {
237 return nil, err
238 }
239
240 cacheMu.Lock()
241 defer cacheMu.Unlock()
242 return loadEnumDescriptorForTypeLocked(enumType, enum, r)
243}
244
245// LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto
246// returned by enum.EnumDescriptor().
247func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
248 return loadEnumDescriptorForEnum(enum, nil)
249}
250
251func loadEnumDescriptorForEnum(enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
252 et := reflect.TypeOf(enum)
253 // we cache descriptors using non-pointer type
254 if et.Kind() == reflect.Ptr {
255 et = et.Elem()
256 enum = reflect.Zero(et).Interface().(protoEnum)
257 }
258 e := getEnumFromCache(et)
259 if e != nil {
260 return e, nil
261 }
262
263 cacheMu.Lock()
264 defer cacheMu.Unlock()
265 return loadEnumDescriptorForTypeLocked(et, enum, r)
266}
267
268func enumFromType(et reflect.Type) (protoEnum, error) {
269 if et.Kind() != reflect.Int32 {
270 et = reflect.PtrTo(et)
271 }
272 e, ok := reflect.Zero(et).Interface().(protoEnum)
273 if !ok {
274 return nil, fmt.Errorf("failed to create enum from type: %v", et)
275 }
276 return e, nil
277}
278
279func loadEnumDescriptorForTypeLocked(et reflect.Type, enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
280 e := enumCache[et]
281 if e != nil {
282 return e, nil
283 }
284
285 fdb, path := enum.EnumDescriptor()
286 name := fmt.Sprintf("%v", et)
287 fd, err := internal.DecodeFileDescriptor(name, fdb)
288 if err != nil {
289 return nil, err
290 }
291 // see if we already have cached "rich" descriptor
292 f, ok := filesCache[fd.GetName()]
293 if !ok {
294 f, err = toFileDescriptorLocked(fd, r)
295 if err != nil {
296 return nil, err
297 }
298 putCacheLocked(fd.GetName(), f)
299 }
300
301 ed := findEnum(f, path)
302 enumCache[et] = ed
303 return ed, nil
304}
305
306func getEnumFromCache(et reflect.Type) *EnumDescriptor {
307 cacheMu.RLock()
308 defer cacheMu.RUnlock()
309 return enumCache[et]
310}
311
312func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
313 if len(path) == 1 {
314 return fd.GetEnumTypes()[path[0]]
315 }
316 md := fd.GetMessageTypes()[path[0]]
317 for _, i := range path[1 : len(path)-1] {
318 md = md.GetNestedMessageTypes()[i]
319 }
320 return md.GetNestedEnumTypes()[path[len(path)-1]]
321}
322
323// LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given
324// extension description.
325func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
326 return loadFieldDescriptorForExtension(ext, nil)
327}
328
329func loadFieldDescriptorForExtension(ext *proto.ExtensionDesc, r *ImportResolver) (*FieldDescriptor, error) {
330 file, err := loadFileDescriptor(ext.Filename, r)
331 if err != nil {
332 return nil, err
333 }
334 field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor)
335 // make sure descriptor agrees with attributes of the ExtensionDesc
336 if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) ||
337 field.GetNumber() != ext.Field {
338 return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
339 }
340 return field, nil
341}