blob: cd40527ff646203e19747ff7b823fc7750bd427c [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/internal/encoding/messageset"
14 "google.golang.org/protobuf/internal/order"
15 pref "google.golang.org/protobuf/reflect/protoreflect"
16 piface "google.golang.org/protobuf/runtime/protoiface"
17)
18
19// coderMessageInfo contains per-message information used by the fast-path functions.
20// This is a different type from MessageInfo to keep MessageInfo as general-purpose as
21// possible.
22type coderMessageInfo struct {
23 methods piface.Methods
24
25 orderedCoderFields []*coderFieldInfo
26 denseCoderFields []*coderFieldInfo
27 coderFields map[protowire.Number]*coderFieldInfo
28 sizecacheOffset offset
29 unknownOffset offset
30 unknownPtrKind bool
31 extensionOffset offset
32 needsInitCheck bool
33 isMessageSet bool
34 numRequiredFields uint8
35}
36
37type coderFieldInfo struct {
38 funcs pointerCoderFuncs // fast-path per-field functions
39 mi *MessageInfo // field's message
40 ft reflect.Type
41 validation validationInfo // information used by message validation
42 num pref.FieldNumber // field number
43 offset offset // struct field offset
44 wiretag uint64 // field tag (number + wire type)
45 tagsize int // size of the varint-encoded tag
46 isPointer bool // true if IsNil may be called on the struct field
47 isRequired bool // true if field is required
48}
49
50func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
51 mi.sizecacheOffset = invalidOffset
52 mi.unknownOffset = invalidOffset
53 mi.extensionOffset = invalidOffset
54
55 if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
56 mi.sizecacheOffset = si.sizecacheOffset
57 }
58 if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
59 mi.unknownOffset = si.unknownOffset
60 mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
61 }
62 if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
63 mi.extensionOffset = si.extensionOffset
64 }
65
66 mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
67 fields := mi.Desc.Fields()
68 preallocFields := make([]coderFieldInfo, fields.Len())
69 for i := 0; i < fields.Len(); i++ {
70 fd := fields.Get(i)
71
72 fs := si.fieldsByNumber[fd.Number()]
73 isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
74 if isOneof {
75 fs = si.oneofsByName[fd.ContainingOneof().Name()]
76 }
77 ft := fs.Type
78 var wiretag uint64
79 if !fd.IsPacked() {
80 wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
81 } else {
82 wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
83 }
84 var fieldOffset offset
85 var funcs pointerCoderFuncs
86 var childMessage *MessageInfo
87 switch {
88 case ft == nil:
89 // This never occurs for generated message types.
90 // It implies that a hand-crafted type has missing Go fields
91 // for specific protobuf message fields.
92 funcs = pointerCoderFuncs{
93 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
94 return 0
95 },
96 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
97 return nil, nil
98 },
99 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
100 panic("missing Go struct field for " + string(fd.FullName()))
101 },
102 isInit: func(p pointer, f *coderFieldInfo) error {
103 panic("missing Go struct field for " + string(fd.FullName()))
104 },
105 merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
106 panic("missing Go struct field for " + string(fd.FullName()))
107 },
108 }
109 case isOneof:
110 fieldOffset = offsetOf(fs, mi.Exporter)
111 case fd.IsWeak():
112 fieldOffset = si.weakOffset
113 funcs = makeWeakMessageFieldCoder(fd)
114 default:
115 fieldOffset = offsetOf(fs, mi.Exporter)
116 childMessage, funcs = fieldCoder(fd, ft)
117 }
118 cf := &preallocFields[i]
119 *cf = coderFieldInfo{
120 num: fd.Number(),
121 offset: fieldOffset,
122 wiretag: wiretag,
123 ft: ft,
124 tagsize: protowire.SizeVarint(wiretag),
125 funcs: funcs,
126 mi: childMessage,
127 validation: newFieldValidationInfo(mi, si, fd, ft),
128 isPointer: fd.Cardinality() == pref.Repeated || fd.HasPresence(),
129 isRequired: fd.Cardinality() == pref.Required,
130 }
131 mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
132 mi.coderFields[cf.num] = cf
133 }
134 for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
135 if od := oneofs.Get(i); !od.IsSynthetic() {
136 mi.initOneofFieldCoders(od, si)
137 }
138 }
139 if messageset.IsMessageSet(mi.Desc) {
140 if !mi.extensionOffset.IsValid() {
141 panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
142 }
143 if !mi.unknownOffset.IsValid() {
144 panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
145 }
146 mi.isMessageSet = true
147 }
148 sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
149 return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
150 })
151
152 var maxDense pref.FieldNumber
153 for _, cf := range mi.orderedCoderFields {
154 if cf.num >= 16 && cf.num >= 2*maxDense {
155 break
156 }
157 maxDense = cf.num
158 }
159 mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
160 for _, cf := range mi.orderedCoderFields {
161 if int(cf.num) >= len(mi.denseCoderFields) {
162 break
163 }
164 mi.denseCoderFields[cf.num] = cf
165 }
166
167 // To preserve compatibility with historic wire output, marshal oneofs last.
168 if mi.Desc.Oneofs().Len() > 0 {
169 sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
170 fi := fields.ByNumber(mi.orderedCoderFields[i].num)
171 fj := fields.ByNumber(mi.orderedCoderFields[j].num)
172 return order.LegacyFieldOrder(fi, fj)
173 })
174 }
175
176 mi.needsInitCheck = needsInitCheck(mi.Desc)
177 if mi.methods.Marshal == nil && mi.methods.Size == nil {
178 mi.methods.Flags |= piface.SupportMarshalDeterministic
179 mi.methods.Marshal = mi.marshal
180 mi.methods.Size = mi.size
181 }
182 if mi.methods.Unmarshal == nil {
183 mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown
184 mi.methods.Unmarshal = mi.unmarshal
185 }
186 if mi.methods.CheckInitialized == nil {
187 mi.methods.CheckInitialized = mi.checkInitialized
188 }
189 if mi.methods.Merge == nil {
190 mi.methods.Merge = mi.merge
191 }
192}
193
194// getUnknownBytes returns a *[]byte for the unknown fields.
195// It is the caller's responsibility to check whether the pointer is nil.
196// This function is specially designed to be inlineable.
197func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
198 if mi.unknownPtrKind {
199 return *p.Apply(mi.unknownOffset).BytesPtr()
200 } else {
201 return p.Apply(mi.unknownOffset).Bytes()
202 }
203}
204
205// mutableUnknownBytes returns a *[]byte for the unknown fields.
206// The returned pointer is guaranteed to not be nil.
207func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
208 if mi.unknownPtrKind {
209 bp := p.Apply(mi.unknownOffset).BytesPtr()
210 if *bp == nil {
211 *bp = new([]byte)
212 }
213 return *bp
214 } else {
215 return p.Apply(mi.unknownOffset).Bytes()
216 }
217}