blob: 0f4b8db760aa98977cd66e8c27e4a85c7a7e3c7b [file] [log] [blame]
Andrea Campanella3614a922021-02-25 12:40:42 +01001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 "google.golang.org/protobuf/internal/pragma"
12 pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
15type reflectMessageInfo struct {
16 fields map[pref.FieldNumber]*fieldInfo
17 oneofs map[pref.Name]*oneofInfo
18
19 // denseFields is a subset of fields where:
20 // 0 < fieldDesc.Number() < len(denseFields)
21 // It provides faster access to the fieldInfo, but may be incomplete.
22 denseFields []*fieldInfo
23
24 // rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
25 rangeInfos []interface{} // either *fieldInfo or *oneofInfo
26
27 getUnknown func(pointer) pref.RawFields
28 setUnknown func(pointer, pref.RawFields)
29 extensionMap func(pointer) *extensionMap
30
31 nilMessage atomicNilMessage
32}
33
34// makeReflectFuncs generates the set of functions to support reflection.
35func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
36 mi.makeKnownFieldsFunc(si)
37 mi.makeUnknownFieldsFunc(t, si)
38 mi.makeExtensionFieldsFunc(t, si)
39}
40
41// makeKnownFieldsFunc generates functions for operations that can be performed
42// on each protobuf message field. It takes in a reflect.Type representing the
43// Go struct and matches message fields with struct fields.
44//
45// This code assumes that the struct is well-formed and panics if there are
46// any discrepancies.
47func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
48 mi.fields = map[pref.FieldNumber]*fieldInfo{}
49 md := mi.Desc
50 fds := md.Fields()
51 for i := 0; i < fds.Len(); i++ {
52 fd := fds.Get(i)
53 fs := si.fieldsByNumber[fd.Number()]
54 var fi fieldInfo
55 switch {
56 case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
57 fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
58 case fd.IsMap():
59 fi = fieldInfoForMap(fd, fs, mi.Exporter)
60 case fd.IsList():
61 fi = fieldInfoForList(fd, fs, mi.Exporter)
62 case fd.IsWeak():
63 fi = fieldInfoForWeakMessage(fd, si.weakOffset)
64 case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind:
65 fi = fieldInfoForMessage(fd, fs, mi.Exporter)
66 default:
67 fi = fieldInfoForScalar(fd, fs, mi.Exporter)
68 }
69 mi.fields[fd.Number()] = &fi
70 }
71
72 mi.oneofs = map[pref.Name]*oneofInfo{}
73 for i := 0; i < md.Oneofs().Len(); i++ {
74 od := md.Oneofs().Get(i)
75 mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
76 }
77
78 mi.denseFields = make([]*fieldInfo, fds.Len()*2)
79 for i := 0; i < fds.Len(); i++ {
80 if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
81 mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
82 }
83 }
84
85 for i := 0; i < fds.Len(); {
86 fd := fds.Get(i)
87 if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
88 mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
89 i += od.Fields().Len()
90 } else {
91 mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
92 i++
93 }
94 }
95}
96
97func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
98 mi.getUnknown = func(pointer) pref.RawFields { return nil }
99 mi.setUnknown = func(pointer, pref.RawFields) { return }
100 if si.unknownOffset.IsValid() {
101 mi.getUnknown = func(p pointer) pref.RawFields {
102 if p.IsNil() {
103 return nil
104 }
105 rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
106 return pref.RawFields(*rv.Interface().(*[]byte))
107 }
108 mi.setUnknown = func(p pointer, b pref.RawFields) {
109 if p.IsNil() {
110 panic("invalid SetUnknown on nil Message")
111 }
112 rv := p.Apply(si.unknownOffset).AsValueOf(unknownFieldsType)
113 *rv.Interface().(*[]byte) = []byte(b)
114 }
115 } else {
116 mi.getUnknown = func(pointer) pref.RawFields {
117 return nil
118 }
119 mi.setUnknown = func(p pointer, _ pref.RawFields) {
120 if p.IsNil() {
121 panic("invalid SetUnknown on nil Message")
122 }
123 }
124 }
125}
126
127func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
128 if si.extensionOffset.IsValid() {
129 mi.extensionMap = func(p pointer) *extensionMap {
130 if p.IsNil() {
131 return (*extensionMap)(nil)
132 }
133 v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
134 return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
135 }
136 } else {
137 mi.extensionMap = func(pointer) *extensionMap {
138 return (*extensionMap)(nil)
139 }
140 }
141}
142
143type extensionMap map[int32]ExtensionField
144
145func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
146 if m != nil {
147 for _, x := range *m {
148 xd := x.Type().TypeDescriptor()
149 v := x.Value()
150 if xd.IsList() && v.List().Len() == 0 {
151 continue
152 }
153 if !f(xd, v) {
154 return
155 }
156 }
157 }
158}
159func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
160 if m == nil {
161 return false
162 }
163 xd := xt.TypeDescriptor()
164 x, ok := (*m)[int32(xd.Number())]
165 if !ok {
166 return false
167 }
168 switch {
169 case xd.IsList():
170 return x.Value().List().Len() > 0
171 case xd.IsMap():
172 return x.Value().Map().Len() > 0
173 case xd.Message() != nil:
174 return x.Value().Message().IsValid()
175 }
176 return true
177}
178func (m *extensionMap) Clear(xt pref.ExtensionType) {
179 delete(*m, int32(xt.TypeDescriptor().Number()))
180}
181func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
182 xd := xt.TypeDescriptor()
183 if m != nil {
184 if x, ok := (*m)[int32(xd.Number())]; ok {
185 return x.Value()
186 }
187 }
188 return xt.Zero()
189}
190func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
191 xd := xt.TypeDescriptor()
192 isValid := true
193 switch {
194 case !xt.IsValidValue(v):
195 isValid = false
196 case xd.IsList():
197 isValid = v.List().IsValid()
198 case xd.IsMap():
199 isValid = v.Map().IsValid()
200 case xd.Message() != nil:
201 isValid = v.Message().IsValid()
202 }
203 if !isValid {
204 panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
205 }
206
207 if *m == nil {
208 *m = make(map[int32]ExtensionField)
209 }
210 var x ExtensionField
211 x.Set(xt, v)
212 (*m)[int32(xd.Number())] = x
213}
214func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
215 xd := xt.TypeDescriptor()
216 if xd.Kind() != pref.MessageKind && xd.Kind() != pref.GroupKind && !xd.IsList() && !xd.IsMap() {
217 panic("invalid Mutable on field with non-composite type")
218 }
219 if x, ok := (*m)[int32(xd.Number())]; ok {
220 return x.Value()
221 }
222 v := xt.New()
223 m.Set(xt, v)
224 return v
225}
226
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// Because MessageState is the first field of M, a pointer to the MessageState
// is identical to a pointer to the concrete message value, and the message
// info is reachable through it.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the *MessageInfo for the enclosing message; per the
	// requirements above it must be set atomically.
	// NOTE(review): this layout is reinterpreted via unsafe, so do not reorder
	// or add fields without checking those users.
	atomicMessageInfo *MessageInfo
}
276
277type messageState MessageState
278
279var (
280 _ pref.Message = (*messageState)(nil)
281 _ unwrapper = (*messageState)(nil)
282)
283
284// messageDataType is a tuple of a pointer to the message data and
285// a pointer to the message type. It is a generalized way of providing a
286// reflective view over a message instance. The disadvantage of this approach
287// is the need to allocate this tuple of 16B.
288type messageDataType struct {
289 p pointer
290 mi *MessageInfo
291}
292
293type (
294 messageReflectWrapper messageDataType
295 messageIfaceWrapper messageDataType
296)
297
298var (
299 _ pref.Message = (*messageReflectWrapper)(nil)
300 _ unwrapper = (*messageReflectWrapper)(nil)
301 _ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
302 _ unwrapper = (*messageIfaceWrapper)(nil)
303)
304
305// MessageOf returns a reflective view over a message. The input must be a
306// pointer to a named Go struct. If the provided type has a ProtoReflect method,
307// it must be implemented by calling this method.
308func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
309 // TODO: Switch the input to be an opaque Pointer.
310 if reflect.TypeOf(m) != mi.GoReflectType {
311 panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
312 }
313 p := pointerOfIface(m)
314 if p.IsNil() {
315 return mi.nilMessage.Init(mi)
316 }
317 return &messageReflectWrapper{p, mi}
318}
319
320func (m *messageReflectWrapper) pointer() pointer { return m.p }
321func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
322
323func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
324 return (*messageReflectWrapper)(m)
325}
326func (m *messageIfaceWrapper) protoUnwrap() interface{} {
327 return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
328}
329
330// checkField verifies that the provided field descriptor is valid.
331// Exactly one of the returned values is populated.
332func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
333 var fi *fieldInfo
334 if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
335 fi = mi.denseFields[n]
336 } else {
337 fi = mi.fields[n]
338 }
339 if fi != nil {
340 if fi.fieldDesc != fd {
341 if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
342 panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
343 }
344 panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
345 }
346 return fi, nil
347 }
348
349 if fd.IsExtension() {
350 if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
351 // TODO: Should this be exact containing message descriptor match?
352 panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
353 }
354 if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
355 panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
356 }
357 xtd, ok := fd.(pref.ExtensionTypeDescriptor)
358 if !ok {
359 panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
360 }
361 return nil, xtd.Type()
362 }
363 panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
364}