blob: 938b4d9dcb56584ff85881f53c2a47c3dffb41e4 [file] [log] [blame]
Scott Baker4a35a702019-11-26 08:17:33 -08001package codec
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "math"
8
9 "github.com/golang/protobuf/proto"
10 "github.com/golang/protobuf/protoc-gen-go/descriptor"
11
12 "github.com/jhump/protoreflect/desc"
13)
14
// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and wire-type
// it reads indicate an end-group marker. When DecodeFieldValue returns this
// error, the value it returns alongside is the tag number read from the marker.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
18
// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
// decode a message value.
//
// Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
// implements this interface.
type MessageFactory interface {
	// NewMessage returns a message for the given message descriptor. The
	// decoding functions in this file pass the returned message straight to
	// proto.Unmarshal, so it must be a valid unmarshal target.
	NewMessage(md *desc.MessageDescriptor) proto.Message
}
27
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but was not a recognized field number. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// The tag number for the unrecognized field.
	Tag int32

	// Encoding indicates how the unknown field was encoded on the wire. If it
	// is proto.WireBytes or proto.WireStartGroup then Contents will be set to
	// the raw bytes. If it is proto.WireFixed32 then the data is in the least
	// significant 32 bits of Value. Otherwise, the data is in all 64 bits of
	// Value.
	Encoding int8
	// Contents holds the raw bytes of a length-delimited or group-encoded
	// field. It is left unset for the numeric encodings.
	Contents []byte
	// Value holds the numeric payload of a varint, fixed32, or fixed64 field.
	// It is left as zero for length-delimited and group encodings.
	Value uint64
}
45
46// DecodeFieldValue will read a field value from the buffer and return its
47// value and the corresponding field descriptor. The given function is used
48// to lookup a field descriptor by tag number. The given factory is used to
49// instantiate a message if the field value is (or contains) a message value.
50//
51// On error, the field descriptor and value are typically nil. However, if the
52// error returned is ErrWireTypeEndGroup, the returned value will indicate any
53// tag number encoded in the end-group marker.
54//
55// If the field descriptor returned is nil, that means that the given function
56// returned nil. This is expected to happen for unrecognized tag numbers. In
57// that case, no error is returned, and the value will be an UnknownField.
58func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
59 if cb.EOF() {
60 return nil, nil, io.EOF
61 }
62 tagNumber, wireType, err := cb.DecodeTagAndWireType()
63 if err != nil {
64 return nil, nil, err
65 }
66 if wireType == proto.WireEndGroup {
67 return nil, tagNumber, ErrWireTypeEndGroup
68 }
69 fd := fieldFinder(tagNumber)
70 if fd == nil {
71 val, err := cb.decodeUnknownField(tagNumber, wireType)
72 return nil, val, err
73 }
74 val, err := cb.decodeKnownField(fd, wireType, fact)
75 return fd, val, err
76}
77
78// DecodeScalarField extracts a properly-typed value from v. The returned value's
79// type depends on the given field descriptor type. It will be the same type as
80// generated structs use for the field descriptor's type. Enum types will return
81// an int32. If the given field type uses length-delimited encoding (nested
82// messages, bytes, and strings), an error is returned.
83func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
84 switch fd.GetType() {
85 case descriptor.FieldDescriptorProto_TYPE_BOOL:
86 return v != 0, nil
87 case descriptor.FieldDescriptorProto_TYPE_UINT32,
88 descriptor.FieldDescriptorProto_TYPE_FIXED32:
89 if v > math.MaxUint32 {
90 return nil, ErrOverflow
91 }
92 return uint32(v), nil
93
94 case descriptor.FieldDescriptorProto_TYPE_INT32,
95 descriptor.FieldDescriptorProto_TYPE_ENUM:
96 s := int64(v)
97 if s > math.MaxInt32 || s < math.MinInt32 {
98 return nil, ErrOverflow
99 }
100 return int32(s), nil
101
102 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
103 if v > math.MaxUint32 {
104 return nil, ErrOverflow
105 }
106 return int32(v), nil
107
108 case descriptor.FieldDescriptorProto_TYPE_SINT32:
109 if v > math.MaxUint32 {
110 return nil, ErrOverflow
111 }
112 return DecodeZigZag32(v), nil
113
114 case descriptor.FieldDescriptorProto_TYPE_UINT64,
115 descriptor.FieldDescriptorProto_TYPE_FIXED64:
116 return v, nil
117
118 case descriptor.FieldDescriptorProto_TYPE_INT64,
119 descriptor.FieldDescriptorProto_TYPE_SFIXED64:
120 return int64(v), nil
121
122 case descriptor.FieldDescriptorProto_TYPE_SINT64:
123 return DecodeZigZag64(v), nil
124
125 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
126 if v > math.MaxUint32 {
127 return nil, ErrOverflow
128 }
129 return math.Float32frombits(uint32(v)), nil
130
131 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
132 return math.Float64frombits(v), nil
133
134 default:
135 // bytes, string, message, and group cannot be represented as a simple numeric value
136 return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
137 }
138}
139
140// DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
141// returned value's type will usually be []byte, string, or, for nested messages,
142// the type returned from the given message factory. However, since repeated
143// scalar fields can be length-delimited, when they used packed encoding, it can
144// also return an []interface{}, where each element is a scalar value. Furthermore,
145// it could return a scalar type, not in a slice, if the given field descriptor is
146// not repeated. This is to support cases where a field is changed from optional
147// to repeated. New code may emit a packed repeated representation, but old code
148// still expects a single scalar value. In this case, if the actual data in bytes
149// contains multiple values, only the last value is returned.
150func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
151 switch {
152 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
153 return bytes, nil
154
155 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
156 return string(bytes), nil
157
158 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
159 fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
160 msg := mf.NewMessage(fd.GetMessageType())
161 err := proto.Unmarshal(bytes, msg)
162 if err != nil {
163 return nil, err
164 } else {
165 return msg, nil
166 }
167
168 default:
169 // even if the field is not repeated or not packed, we still parse it as such for
170 // backwards compatibility (e.g. message we are de-serializing could have been both
171 // repeated and packed at the time of serialization)
172 packedBuf := NewBuffer(bytes)
173 var slice []interface{}
174 var val interface{}
175 for !packedBuf.EOF() {
176 var v uint64
177 var err error
178 if varintTypes[fd.GetType()] {
179 v, err = packedBuf.DecodeVarint()
180 } else if fixed32Types[fd.GetType()] {
181 v, err = packedBuf.DecodeFixed32()
182 } else if fixed64Types[fd.GetType()] {
183 v, err = packedBuf.DecodeFixed64()
184 } else {
185 return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
186 }
187 if err != nil {
188 return nil, err
189 }
190 val, err = DecodeScalarField(fd, v)
191 if err != nil {
192 return nil, err
193 }
194 if fd.IsRepeated() {
195 slice = append(slice, val)
196 }
197 }
198 if fd.IsRepeated() {
199 return slice, nil
200 } else {
201 // if not a repeated field, last value wins
202 return val, nil
203 }
204 }
205}
206
207func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
208 var val interface{}
209 var err error
210 switch encoding {
211 case proto.WireFixed32:
212 var num uint64
213 num, err = b.DecodeFixed32()
214 if err == nil {
215 val, err = DecodeScalarField(fd, num)
216 }
217 case proto.WireFixed64:
218 var num uint64
219 num, err = b.DecodeFixed64()
220 if err == nil {
221 val, err = DecodeScalarField(fd, num)
222 }
223 case proto.WireVarint:
224 var num uint64
225 num, err = b.DecodeVarint()
226 if err == nil {
227 val, err = DecodeScalarField(fd, num)
228 }
229
230 case proto.WireBytes:
231 alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
232 var raw []byte
233 raw, err = b.DecodeRawBytes(alloc)
234 if err == nil {
235 val, err = DecodeLengthDelimitedField(fd, raw, fact)
236 }
237
238 case proto.WireStartGroup:
239 if fd.GetMessageType() == nil {
240 return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
241 }
242 msg := fact.NewMessage(fd.GetMessageType())
243 var data []byte
244 data, err = b.ReadGroup(false)
245 if err == nil {
246 err = proto.Unmarshal(data, msg)
247 if err == nil {
248 val = msg
249 }
250 }
251
252 default:
253 return nil, ErrBadWireType
254 }
255 if err != nil {
256 return nil, err
257 }
258
259 return val, nil
260}
261
262func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
263 u := UnknownField{Tag: tagNumber, Encoding: encoding}
264 var err error
265 switch encoding {
266 case proto.WireFixed32:
267 u.Value, err = b.DecodeFixed32()
268 case proto.WireFixed64:
269 u.Value, err = b.DecodeFixed64()
270 case proto.WireVarint:
271 u.Value, err = b.DecodeVarint()
272 case proto.WireBytes:
273 u.Contents, err = b.DecodeRawBytes(true)
274 case proto.WireStartGroup:
275 u.Contents, err = b.ReadGroup(true)
276 default:
277 err = ErrBadWireType
278 }
279 if err != nil {
280 return nil, err
281 }
282 return u, nil
283}