blob: 02f8a321fd90f179ece406e14369d0e1b8b8076f [file] [log] [blame]
khenaidooefff76e2021-12-15 16:51:30 -05001package codec
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "math"
8
9 "github.com/golang/protobuf/proto"
10 "github.com/golang/protobuf/protoc-gen-go/descriptor"
11
12 "github.com/jhump/protoreflect/desc"
13)
14
15var varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{}
16var fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{}
17var fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{}
18
19func init() {
20 varintTypes[descriptor.FieldDescriptorProto_TYPE_BOOL] = true
21 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT32] = true
22 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT64] = true
23 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT32] = true
24 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT64] = true
25 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT32] = true
26 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT64] = true
27 varintTypes[descriptor.FieldDescriptorProto_TYPE_ENUM] = true
28
29 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FIXED32] = true
30 fixed32Types[descriptor.FieldDescriptorProto_TYPE_SFIXED32] = true
31 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FLOAT] = true
32
33 fixed64Types[descriptor.FieldDescriptorProto_TYPE_FIXED64] = true
34 fixed64Types[descriptor.FieldDescriptorProto_TYPE_SFIXED64] = true
35 fixed64Types[descriptor.FieldDescriptorProto_TYPE_DOUBLE] = true
36}
37
// ErrWireTypeEndGroup is the sentinel error DecodeFieldValue returns when the
// tag and wire type it has just read form an end-group marker rather than the
// start of a field value.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
41
42// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
43// decode a message value.
44//
45// Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
46// implements this interface.
47type MessageFactory interface {
48 NewMessage(md *desc.MessageDescriptor) proto.Message
49}
50
// UnknownField holds a field that was read from the binary wire format of a
// message but whose tag number was not recognized. It retains enough of the
// encoded data that re-serializing the message does not drop the unrecognized
// content.
type UnknownField struct {
	// Tag is the tag number of the unrecognized field.
	Tag int32

	// Encoding is the wire type the field was encoded with. It determines
	// which of Contents or Value carries the data:
	//   - proto.WireBytes or proto.WireStartGroup: the raw bytes are in
	//     Contents.
	//   - proto.WireFixed32: the data is in the least significant 32 bits of
	//     Value.
	//   - anything else: the data occupies all 64 bits of Value.
	Encoding int8
	// Contents is the raw payload of a length-delimited or group-encoded field.
	Contents []byte
	// Value is the numeric payload of a varint or fixed-width field.
	Value uint64
}
68
// DecodeZigZag32 converts a zig-zag encoded value back into the signed 32-bit
// integer it represents. Only the low 32 bits of v are considered.
func DecodeZigZag32(v uint64) int32 {
	low := uint32(v)
	// The upper 31 bits carry the magnitude; the lowest bit carries the sign.
	magnitude := int32(low >> 1)
	// 0 for even inputs, -1 (all bits set) for odd inputs.
	sign := -int32(low & 1)
	return magnitude ^ sign
}
74
// DecodeZigZag64 converts a zig-zag encoded value back into the signed 64-bit
// integer it represents.
func DecodeZigZag64(v uint64) int64 {
	// The upper 63 bits carry the magnitude; the lowest bit carries the sign.
	magnitude := int64(v >> 1)
	// 0 for even inputs, -1 (all bits set) for odd inputs.
	sign := -int64(v & 1)
	return magnitude ^ sign
}
80
81// DecodeFieldValue will read a field value from the buffer and return its
82// value and the corresponding field descriptor. The given function is used
83// to lookup a field descriptor by tag number. The given factory is used to
84// instantiate a message if the field value is (or contains) a message value.
85//
86// On error, the field descriptor and value are typically nil. However, if the
87// error returned is ErrWireTypeEndGroup, the returned value will indicate any
88// tag number encoded in the end-group marker.
89//
90// If the field descriptor returned is nil, that means that the given function
91// returned nil. This is expected to happen for unrecognized tag numbers. In
92// that case, no error is returned, and the value will be an UnknownField.
93func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
94 if cb.EOF() {
95 return nil, nil, io.EOF
96 }
97 tagNumber, wireType, err := cb.DecodeTagAndWireType()
98 if err != nil {
99 return nil, nil, err
100 }
101 if wireType == proto.WireEndGroup {
102 return nil, tagNumber, ErrWireTypeEndGroup
103 }
104 fd := fieldFinder(tagNumber)
105 if fd == nil {
106 val, err := cb.decodeUnknownField(tagNumber, wireType)
107 return nil, val, err
108 }
109 val, err := cb.decodeKnownField(fd, wireType, fact)
110 return fd, val, err
111}
112
113// DecodeScalarField extracts a properly-typed value from v. The returned value's
114// type depends on the given field descriptor type. It will be the same type as
115// generated structs use for the field descriptor's type. Enum types will return
116// an int32. If the given field type uses length-delimited encoding (nested
117// messages, bytes, and strings), an error is returned.
118func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
119 switch fd.GetType() {
120 case descriptor.FieldDescriptorProto_TYPE_BOOL:
121 return v != 0, nil
122 case descriptor.FieldDescriptorProto_TYPE_UINT32,
123 descriptor.FieldDescriptorProto_TYPE_FIXED32:
124 if v > math.MaxUint32 {
125 return nil, ErrOverflow
126 }
127 return uint32(v), nil
128
129 case descriptor.FieldDescriptorProto_TYPE_INT32,
130 descriptor.FieldDescriptorProto_TYPE_ENUM:
131 s := int64(v)
132 if s > math.MaxInt32 || s < math.MinInt32 {
133 return nil, ErrOverflow
134 }
135 return int32(s), nil
136
137 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
138 if v > math.MaxUint32 {
139 return nil, ErrOverflow
140 }
141 return int32(v), nil
142
143 case descriptor.FieldDescriptorProto_TYPE_SINT32:
144 if v > math.MaxUint32 {
145 return nil, ErrOverflow
146 }
147 return DecodeZigZag32(v), nil
148
149 case descriptor.FieldDescriptorProto_TYPE_UINT64,
150 descriptor.FieldDescriptorProto_TYPE_FIXED64:
151 return v, nil
152
153 case descriptor.FieldDescriptorProto_TYPE_INT64,
154 descriptor.FieldDescriptorProto_TYPE_SFIXED64:
155 return int64(v), nil
156
157 case descriptor.FieldDescriptorProto_TYPE_SINT64:
158 return DecodeZigZag64(v), nil
159
160 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
161 if v > math.MaxUint32 {
162 return nil, ErrOverflow
163 }
164 return math.Float32frombits(uint32(v)), nil
165
166 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
167 return math.Float64frombits(v), nil
168
169 default:
170 // bytes, string, message, and group cannot be represented as a simple numeric value
171 return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
172 }
173}
174
175// DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
176// returned value's type will usually be []byte, string, or, for nested messages,
177// the type returned from the given message factory. However, since repeated
178// scalar fields can be length-delimited, when they used packed encoding, it can
179// also return an []interface{}, where each element is a scalar value. Furthermore,
180// it could return a scalar type, not in a slice, if the given field descriptor is
181// not repeated. This is to support cases where a field is changed from optional
182// to repeated. New code may emit a packed repeated representation, but old code
183// still expects a single scalar value. In this case, if the actual data in bytes
184// contains multiple values, only the last value is returned.
185func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
186 switch {
187 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
188 return bytes, nil
189
190 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
191 return string(bytes), nil
192
193 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
194 fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
195 msg := mf.NewMessage(fd.GetMessageType())
196 err := proto.Unmarshal(bytes, msg)
197 if err != nil {
198 return nil, err
199 } else {
200 return msg, nil
201 }
202
203 default:
204 // even if the field is not repeated or not packed, we still parse it as such for
205 // backwards compatibility (e.g. message we are de-serializing could have been both
206 // repeated and packed at the time of serialization)
207 packedBuf := NewBuffer(bytes)
208 var slice []interface{}
209 var val interface{}
210 for !packedBuf.EOF() {
211 var v uint64
212 var err error
213 if varintTypes[fd.GetType()] {
214 v, err = packedBuf.DecodeVarint()
215 } else if fixed32Types[fd.GetType()] {
216 v, err = packedBuf.DecodeFixed32()
217 } else if fixed64Types[fd.GetType()] {
218 v, err = packedBuf.DecodeFixed64()
219 } else {
220 return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
221 }
222 if err != nil {
223 return nil, err
224 }
225 val, err = DecodeScalarField(fd, v)
226 if err != nil {
227 return nil, err
228 }
229 if fd.IsRepeated() {
230 slice = append(slice, val)
231 }
232 }
233 if fd.IsRepeated() {
234 return slice, nil
235 } else {
236 // if not a repeated field, last value wins
237 return val, nil
238 }
239 }
240}
241
242func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
243 var val interface{}
244 var err error
245 switch encoding {
246 case proto.WireFixed32:
247 var num uint64
248 num, err = b.DecodeFixed32()
249 if err == nil {
250 val, err = DecodeScalarField(fd, num)
251 }
252 case proto.WireFixed64:
253 var num uint64
254 num, err = b.DecodeFixed64()
255 if err == nil {
256 val, err = DecodeScalarField(fd, num)
257 }
258 case proto.WireVarint:
259 var num uint64
260 num, err = b.DecodeVarint()
261 if err == nil {
262 val, err = DecodeScalarField(fd, num)
263 }
264
265 case proto.WireBytes:
266 alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
267 var raw []byte
268 raw, err = b.DecodeRawBytes(alloc)
269 if err == nil {
270 val, err = DecodeLengthDelimitedField(fd, raw, fact)
271 }
272
273 case proto.WireStartGroup:
274 if fd.GetMessageType() == nil {
275 return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
276 }
277 msg := fact.NewMessage(fd.GetMessageType())
278 var data []byte
279 data, err = b.ReadGroup(false)
280 if err == nil {
281 err = proto.Unmarshal(data, msg)
282 if err == nil {
283 val = msg
284 }
285 }
286
287 default:
288 return nil, ErrBadWireType
289 }
290 if err != nil {
291 return nil, err
292 }
293
294 return val, nil
295}
296
297func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
298 u := UnknownField{Tag: tagNumber, Encoding: encoding}
299 var err error
300 switch encoding {
301 case proto.WireFixed32:
302 u.Value, err = b.DecodeFixed32()
303 case proto.WireFixed64:
304 u.Value, err = b.DecodeFixed64()
305 case proto.WireVarint:
306 u.Value, err = b.DecodeVarint()
307 case proto.WireBytes:
308 u.Contents, err = b.DecodeRawBytes(true)
309 case proto.WireStartGroup:
310 u.Contents, err = b.ReadGroup(true)
311 default:
312 err = ErrBadWireType
313 }
314 if err != nil {
315 return nil, err
316 }
317 return u, nil
318}