blob: 9488b726131366509be747a136d4ba01bfcdf98a [file] [log] [blame]
David K. Bainbridgee05cf0c2021-08-19 03:16:50 +00001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 "google.golang.org/protobuf/internal/detrand"
12 "google.golang.org/protobuf/internal/pragma"
13 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
// reflectMessageInfo holds the per-message-type tables and accessors used to
// implement protobuf reflection over a generated Go struct.
type reflectMessageInfo struct {
	// fields maps every field number in the message descriptor to the
	// accessors for that field.
	fields map[pref.FieldNumber]*fieldInfo
	// oneofs maps every oneof name in the message descriptor to the
	// accessors for that oneof.
	oneofs map[pref.Name]*oneofInfo

	// fieldTypes contains the zero value of an enum or message field.
	// For lists, it contains the element type.
	// For maps, it contains the entry value type.
	fieldTypes map[pref.FieldNumber]interface{}

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// getUnknown and setUnknown read and write the raw unknown-field bytes
	// of a message instance.
	getUnknown func(pointer) pref.RawFields
	setUnknown func(pointer, pref.RawFields)
	// extensionMap returns the extension map of a message instance; it may
	// return nil (e.g. for a nil message pointer).
	extensionMap func(pointer) *extensionMap

	// nilMessage supplies, via Init, the pref.Message that MessageOf returns
	// for a nil message pointer.
	nilMessage atomicNilMessage
}
39
40// makeReflectFuncs generates the set of functions to support reflection.
41func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
42 mi.makeKnownFieldsFunc(si)
43 mi.makeUnknownFieldsFunc(t, si)
44 mi.makeExtensionFieldsFunc(t, si)
45 mi.makeFieldTypes(si)
46}
47
// makeKnownFieldsFunc generates functions for operations that can be performed
// on each protobuf message field. It takes in the structInfo describing the
// Go struct and matches message fields with struct fields.
//
// It populates mi.fields (one entry per descriptor field), mi.oneofs (one
// entry per descriptor oneof), mi.denseFields and mi.rangeInfos.
//
// This code assumes that the struct is well-formed and panics if there are
// any discrepancies.
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
	mi.fields = map[pref.FieldNumber]*fieldInfo{}
	md := mi.Desc
	fds := md.Fields()
	for i := 0; i < fds.Len(); i++ {
		fd := fds.Get(i)
		fs := si.fieldsByNumber[fd.Number()]
		// Members of a real (non-synthetic) oneof are backed by the oneof's
		// struct field rather than a per-field struct field.
		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
		if isOneof {
			fs = si.oneofsByName[fd.ContainingOneof().Name()]
		}
		// fi is declared inside the loop body, so &fi below is distinct for
		// every field.
		var fi fieldInfo
		switch {
		case fs.Type == nil:
			fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
		case isOneof:
			fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
		case fd.IsMap():
			fi = fieldInfoForMap(fd, fs, mi.Exporter)
		case fd.IsList():
			fi = fieldInfoForList(fd, fs, mi.Exporter)
		case fd.IsWeak():
			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
		case fd.Message() != nil:
			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
		default:
			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
		}
		mi.fields[fd.Number()] = &fi
	}

	mi.oneofs = map[pref.Name]*oneofInfo{}
	for i := 0; i < md.Oneofs().Len(); i++ {
		od := md.Oneofs().Get(i)
		mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
	}

	// The dense table is sized at twice the field count; fields whose numbers
	// do not fit are reachable only through mi.fields.
	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
	for i := 0; i < fds.Len(); i++ {
		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
		}
	}

	// Emit one range entry per ordinary field and one per real oneof. On
	// reaching a oneof member, all of the oneof's fields are skipped at once.
	// NOTE(review): this assumes a oneof's fields appear contiguously in fds,
	// starting at the first member encountered — confirm for descriptors not
	// produced by protoc.
	for i := 0; i < fds.Len(); {
		fd := fds.Get(i)
		if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
			i += od.Fields().Len()
		} else {
			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
			i++
		}
	}

	// Introduce instability to iteration order, but keep it deterministic.
	// At most one adjacent pair of entries is swapped, chosen by detrand.
	if len(mi.rangeInfos) > 1 && detrand.Bool() {
		i := detrand.Intn(len(mi.rangeInfos) - 1)
		mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
	}
}
115
116func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
117 switch {
118 case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
119 // Handle as []byte.
120 mi.getUnknown = func(p pointer) pref.RawFields {
121 if p.IsNil() {
122 return nil
123 }
124 return *p.Apply(mi.unknownOffset).Bytes()
125 }
126 mi.setUnknown = func(p pointer, b pref.RawFields) {
127 if p.IsNil() {
128 panic("invalid SetUnknown on nil Message")
129 }
130 *p.Apply(mi.unknownOffset).Bytes() = b
131 }
132 case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
133 // Handle as *[]byte.
134 mi.getUnknown = func(p pointer) pref.RawFields {
135 if p.IsNil() {
136 return nil
137 }
138 bp := p.Apply(mi.unknownOffset).BytesPtr()
139 if *bp == nil {
140 return nil
141 }
142 return **bp
143 }
144 mi.setUnknown = func(p pointer, b pref.RawFields) {
145 if p.IsNil() {
146 panic("invalid SetUnknown on nil Message")
147 }
148 bp := p.Apply(mi.unknownOffset).BytesPtr()
149 if *bp == nil {
150 *bp = new([]byte)
151 }
152 **bp = b
153 }
154 default:
155 mi.getUnknown = func(pointer) pref.RawFields {
156 return nil
157 }
158 mi.setUnknown = func(p pointer, _ pref.RawFields) {
159 if p.IsNil() {
160 panic("invalid SetUnknown on nil Message")
161 }
162 }
163 }
164}
165
166func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
167 if si.extensionOffset.IsValid() {
168 mi.extensionMap = func(p pointer) *extensionMap {
169 if p.IsNil() {
170 return (*extensionMap)(nil)
171 }
172 v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
173 return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
174 }
175 } else {
176 mi.extensionMap = func(pointer) *extensionMap {
177 return (*extensionMap)(nil)
178 }
179 }
180}
181func (mi *MessageInfo) makeFieldTypes(si structInfo) {
182 md := mi.Desc
183 fds := md.Fields()
184 for i := 0; i < fds.Len(); i++ {
185 var ft reflect.Type
186 fd := fds.Get(i)
187 fs := si.fieldsByNumber[fd.Number()]
188 isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
189 if isOneof {
190 fs = si.oneofsByName[fd.ContainingOneof().Name()]
191 }
192 var isMessage bool
193 switch {
194 case fs.Type == nil:
195 continue // never occurs for officially generated message types
196 case isOneof:
197 if fd.Enum() != nil || fd.Message() != nil {
198 ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
199 }
200 case fd.IsMap():
201 if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
202 ft = fs.Type.Elem()
203 }
204 isMessage = fd.MapValue().Message() != nil
205 case fd.IsList():
206 if fd.Enum() != nil || fd.Message() != nil {
207 ft = fs.Type.Elem()
208 }
209 isMessage = fd.Message() != nil
210 case fd.Enum() != nil:
211 ft = fs.Type
212 if fd.HasPresence() && ft.Kind() == reflect.Ptr {
213 ft = ft.Elem()
214 }
215 case fd.Message() != nil:
216 ft = fs.Type
217 if fd.IsWeak() {
218 ft = nil
219 }
220 isMessage = true
221 }
222 if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
223 ft = reflect.PtrTo(ft) // never occurs for officially generated message types
224 }
225 if ft != nil {
226 if mi.fieldTypes == nil {
227 mi.fieldTypes = make(map[pref.FieldNumber]interface{})
228 }
229 mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
230 }
231 }
232}
233
// extensionMap holds a message's extension fields, keyed by extension field
// number. Its methods accept a *extensionMap; the read-only methods (Range,
// Has, Get) tolerate a nil pointer.
type extensionMap map[int32]ExtensionField
235
236func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
237 if m != nil {
238 for _, x := range *m {
239 xd := x.Type().TypeDescriptor()
240 v := x.Value()
241 if xd.IsList() && v.List().Len() == 0 {
242 continue
243 }
244 if !f(xd, v) {
245 return
246 }
247 }
248 }
249}
250func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
251 if m == nil {
252 return false
253 }
254 xd := xt.TypeDescriptor()
255 x, ok := (*m)[int32(xd.Number())]
256 if !ok {
257 return false
258 }
259 switch {
260 case xd.IsList():
261 return x.Value().List().Len() > 0
262 case xd.IsMap():
263 return x.Value().Map().Len() > 0
264 case xd.Message() != nil:
265 return x.Value().Message().IsValid()
266 }
267 return true
268}
269func (m *extensionMap) Clear(xt pref.ExtensionType) {
270 delete(*m, int32(xt.TypeDescriptor().Number()))
271}
272func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
273 xd := xt.TypeDescriptor()
274 if m != nil {
275 if x, ok := (*m)[int32(xd.Number())]; ok {
276 return x.Value()
277 }
278 }
279 return xt.Zero()
280}
281func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
282 xd := xt.TypeDescriptor()
283 isValid := true
284 switch {
285 case !xt.IsValidValue(v):
286 isValid = false
287 case xd.IsList():
288 isValid = v.List().IsValid()
289 case xd.IsMap():
290 isValid = v.Map().IsValid()
291 case xd.Message() != nil:
292 isValid = v.Message().IsValid()
293 }
294 if !isValid {
295 panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
296 }
297
298 if *m == nil {
299 *m = make(map[int32]ExtensionField)
300 }
301 var x ExtensionField
302 x.Set(xt, v)
303 (*m)[int32(xd.Number())] = x
304}
305func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
306 xd := xt.TypeDescriptor()
307 if xd.Kind() != pref.MessageKind && xd.Kind() != pref.GroupKind && !xd.IsList() && !xd.IsMap() {
308 panic("invalid Mutable on field with non-composite type")
309 }
310 if x, ok := (*m)[int32(xd.Number())]; ok {
311 return x.Value()
312 }
313 v := xt.New()
314 m.Set(xt, v)
315 return v
316}
317
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	// Marker embeds from internal/pragma; as their names suggest, they are
	// presumably meant to forbid unkeyed literals, == comparison and copying.
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the *MessageInfo for this message instance; per
	// the requirements above it must be set atomically.
	atomicMessageInfo *MessageInfo
}
367
// messageState has the same underlying type as MessageState but its own
// method set; it is asserted below to implement pref.Message and unwrapper.
type messageState MessageState
369
370var (
371 _ pref.Message = (*messageState)(nil)
372 _ unwrapper = (*messageState)(nil)
373)
374
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
//
// Field order matters: MessageOf builds values with a positional literal.
type messageDataType struct {
	p  pointer      // the message data
	mi *MessageInfo // type information for the message
}
383
384type (
385 messageReflectWrapper messageDataType
386 messageIfaceWrapper messageDataType
387)
388
389var (
390 _ pref.Message = (*messageReflectWrapper)(nil)
391 _ unwrapper = (*messageReflectWrapper)(nil)
392 _ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
393 _ unwrapper = (*messageIfaceWrapper)(nil)
394)
395
396// MessageOf returns a reflective view over a message. The input must be a
397// pointer to a named Go struct. If the provided type has a ProtoReflect method,
398// it must be implemented by calling this method.
399func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
400 if reflect.TypeOf(m) != mi.GoReflectType {
401 panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
402 }
403 p := pointerOfIface(m)
404 if p.IsNil() {
405 return mi.nilMessage.Init(mi)
406 }
407 return &messageReflectWrapper{p, mi}
408}
409
410func (m *messageReflectWrapper) pointer() pointer { return m.p }
411func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
412
413// Reset implements the v1 proto.Message.Reset method.
414func (m *messageIfaceWrapper) Reset() {
415 if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
416 mr.Reset()
417 return
418 }
419 rv := reflect.ValueOf(m.protoUnwrap())
420 if rv.Kind() == reflect.Ptr && !rv.IsNil() {
421 rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
422 }
423}
424func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
425 return (*messageReflectWrapper)(m)
426}
427func (m *messageIfaceWrapper) protoUnwrap() interface{} {
428 return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
429}
430
431// checkField verifies that the provided field descriptor is valid.
432// Exactly one of the returned values is populated.
433func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
434 var fi *fieldInfo
435 if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
436 fi = mi.denseFields[n]
437 } else {
438 fi = mi.fields[n]
439 }
440 if fi != nil {
441 if fi.fieldDesc != fd {
442 if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
443 panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
444 }
445 panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
446 }
447 return fi, nil
448 }
449
450 if fd.IsExtension() {
451 if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
452 // TODO: Should this be exact containing message descriptor match?
453 panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
454 }
455 if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
456 panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
457 }
458 xtd, ok := fd.(pref.ExtensionTypeDescriptor)
459 if !ok {
460 panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
461 }
462 return nil, xtd.Type()
463 }
464 panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
465}