blob: 2a7e59f66c7fd15738c2d6b5a8ac985a182671a7 [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package codec
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "math"
8
9 "github.com/golang/protobuf/proto"
10 "github.com/golang/protobuf/protoc-gen-go/descriptor"
11)
12
// Sentinel errors reported while decoding the protocol buffer wire format.
var (
	// ErrOverflow is returned when an integer is too large to be
	// represented.
	ErrOverflow = errors.New("proto: integer overflow")

	// ErrBadWireType is returned when the wire type decoded from a buffer
	// is not valid.
	ErrBadWireType = errors.New("proto: bad wiretype")
)
19
20var varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{}
21var fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{}
22var fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{}
23
24func init() {
25 varintTypes[descriptor.FieldDescriptorProto_TYPE_BOOL] = true
26 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT32] = true
27 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT64] = true
28 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT32] = true
29 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT64] = true
30 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT32] = true
31 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT64] = true
32 varintTypes[descriptor.FieldDescriptorProto_TYPE_ENUM] = true
33
34 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FIXED32] = true
35 fixed32Types[descriptor.FieldDescriptorProto_TYPE_SFIXED32] = true
36 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FLOAT] = true
37
38 fixed64Types[descriptor.FieldDescriptorProto_TYPE_FIXED64] = true
39 fixed64Types[descriptor.FieldDescriptorProto_TYPE_SFIXED64] = true
40 fixed64Types[descriptor.FieldDescriptorProto_TYPE_DOUBLE] = true
41}
42
43func (cb *Buffer) decodeVarintSlow() (x uint64, err error) {
44 i := cb.index
45 l := len(cb.buf)
46
47 for shift := uint(0); shift < 64; shift += 7 {
48 if i >= l {
49 err = io.ErrUnexpectedEOF
50 return
51 }
52 b := cb.buf[i]
53 i++
54 x |= (uint64(b) & 0x7F) << shift
55 if b < 0x80 {
56 cb.index = i
57 return
58 }
59 }
60
61 // The number is too large to represent in a 64-bit value.
62 err = ErrOverflow
63 return
64}
65
66// DecodeVarint reads a varint-encoded integer from the Buffer.
67// This is the format for the
68// int32, int64, uint32, uint64, bool, and enum
69// protocol buffer types.
70func (cb *Buffer) DecodeVarint() (uint64, error) {
71 i := cb.index
72 buf := cb.buf
73
74 if i >= len(buf) {
75 return 0, io.ErrUnexpectedEOF
76 } else if buf[i] < 0x80 {
77 cb.index++
78 return uint64(buf[i]), nil
79 } else if len(buf)-i < 10 {
80 return cb.decodeVarintSlow()
81 }
82
83 var b uint64
84 // we already checked the first byte
85 x := uint64(buf[i]) - 0x80
86 i++
87
88 b = uint64(buf[i])
89 i++
90 x += b << 7
91 if b&0x80 == 0 {
92 goto done
93 }
94 x -= 0x80 << 7
95
96 b = uint64(buf[i])
97 i++
98 x += b << 14
99 if b&0x80 == 0 {
100 goto done
101 }
102 x -= 0x80 << 14
103
104 b = uint64(buf[i])
105 i++
106 x += b << 21
107 if b&0x80 == 0 {
108 goto done
109 }
110 x -= 0x80 << 21
111
112 b = uint64(buf[i])
113 i++
114 x += b << 28
115 if b&0x80 == 0 {
116 goto done
117 }
118 x -= 0x80 << 28
119
120 b = uint64(buf[i])
121 i++
122 x += b << 35
123 if b&0x80 == 0 {
124 goto done
125 }
126 x -= 0x80 << 35
127
128 b = uint64(buf[i])
129 i++
130 x += b << 42
131 if b&0x80 == 0 {
132 goto done
133 }
134 x -= 0x80 << 42
135
136 b = uint64(buf[i])
137 i++
138 x += b << 49
139 if b&0x80 == 0 {
140 goto done
141 }
142 x -= 0x80 << 49
143
144 b = uint64(buf[i])
145 i++
146 x += b << 56
147 if b&0x80 == 0 {
148 goto done
149 }
150 x -= 0x80 << 56
151
152 b = uint64(buf[i])
153 i++
154 x += b << 63
155 if b&0x80 == 0 {
156 goto done
157 }
158 // x -= 0x80 << 63 // Always zero.
159
160 return 0, ErrOverflow
161
162done:
163 cb.index = i
164 return x, nil
165}
166
167// DecodeTagAndWireType decodes a field tag and wire type from input.
168// This reads a varint and then extracts the two fields from the varint
169// value read.
170func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) {
171 var v uint64
172 v, err = cb.DecodeVarint()
173 if err != nil {
174 return
175 }
176 // low 7 bits is wire type
177 wireType = int8(v & 7)
178 // rest is int32 tag number
179 v = v >> 3
180 if v > math.MaxInt32 {
181 err = fmt.Errorf("tag number out of range: %d", v)
182 return
183 }
184 tag = int32(v)
185 return
186}
187
188// DecodeFixed64 reads a 64-bit integer from the Buffer.
189// This is the format for the
190// fixed64, sfixed64, and double protocol buffer types.
191func (cb *Buffer) DecodeFixed64() (x uint64, err error) {
192 // x, err already 0
193 i := cb.index + 8
194 if i < 0 || i > len(cb.buf) {
195 err = io.ErrUnexpectedEOF
196 return
197 }
198 cb.index = i
199
200 x = uint64(cb.buf[i-8])
201 x |= uint64(cb.buf[i-7]) << 8
202 x |= uint64(cb.buf[i-6]) << 16
203 x |= uint64(cb.buf[i-5]) << 24
204 x |= uint64(cb.buf[i-4]) << 32
205 x |= uint64(cb.buf[i-3]) << 40
206 x |= uint64(cb.buf[i-2]) << 48
207 x |= uint64(cb.buf[i-1]) << 56
208 return
209}
210
211// DecodeFixed32 reads a 32-bit integer from the Buffer.
212// This is the format for the
213// fixed32, sfixed32, and float protocol buffer types.
214func (cb *Buffer) DecodeFixed32() (x uint64, err error) {
215 // x, err already 0
216 i := cb.index + 4
217 if i < 0 || i > len(cb.buf) {
218 err = io.ErrUnexpectedEOF
219 return
220 }
221 cb.index = i
222
223 x = uint64(cb.buf[i-4])
224 x |= uint64(cb.buf[i-3]) << 8
225 x |= uint64(cb.buf[i-2]) << 16
226 x |= uint64(cb.buf[i-1]) << 24
227 return
228}
229
// DecodeZigZag32 decodes a signed 32-bit integer from the given
// zig-zag encoded value. Only the low 32 bits of v are used.
func DecodeZigZag32(v uint64) int32 {
	u := uint32(v)
	// The low bit selects the sign: negating it yields an all-zero or
	// all-one mask that is XORed onto the magnitude bits.
	return int32(u>>1) ^ -int32(u&1)
}
235
// DecodeZigZag64 decodes a signed 64-bit integer from the given
// zig-zag encoded value.
func DecodeZigZag64(v uint64) int64 {
	// The low bit selects the sign: negating it yields an all-zero or
	// all-one mask that is XORed onto the magnitude bits.
	return int64(v>>1) ^ -int64(v&1)
}
241
242// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
243// This is the format used for the bytes protocol buffer
244// type and for embedded messages.
245func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
246 n, err := cb.DecodeVarint()
247 if err != nil {
248 return nil, err
249 }
250
251 nb := int(n)
252 if nb < 0 {
253 return nil, fmt.Errorf("proto: bad byte length %d", nb)
254 }
255 end := cb.index + nb
256 if end < cb.index || end > len(cb.buf) {
257 return nil, io.ErrUnexpectedEOF
258 }
259
260 if !alloc {
261 buf = cb.buf[cb.index:end]
262 cb.index = end
263 return
264 }
265
266 buf = make([]byte, nb)
267 copy(buf, cb.buf[cb.index:])
268 cb.index = end
269 return
270}
271
272// ReadGroup reads the input until a "group end" tag is found
273// and returns the data up to that point. Subsequent reads from
274// the buffer will read data after the group end tag. If alloc
275// is true, the data is copied to a new slice before being returned.
276// Otherwise, the returned slice is a view into the buffer's
277// underlying byte slice.
278//
279// This function correctly handles nested groups: if a "group start"
280// tag is found, then that group's end tag will be included in the
281// returned data.
282func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) {
283 var groupEnd, dataEnd int
284 groupEnd, dataEnd, err := cb.findGroupEnd()
285 if err != nil {
286 return nil, err
287 }
288 var results []byte
289 if !alloc {
290 results = cb.buf[cb.index:dataEnd]
291 } else {
292 results = make([]byte, dataEnd-cb.index)
293 copy(results, cb.buf[cb.index:])
294 }
295 cb.index = groupEnd
296 return results, nil
297}
298
299// SkipGroup is like ReadGroup, except that it discards the
300// data and just advances the buffer to point to the input
301// right *after* the "group end" tag.
302func (cb *Buffer) SkipGroup() error {
303 groupEnd, _, err := cb.findGroupEnd()
304 if err != nil {
305 return err
306 }
307 cb.index = groupEnd
308 return nil
309}
310
311func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) {
312 bs := cb.buf
313 start := cb.index
314 defer func() {
315 cb.index = start
316 }()
317 for {
318 fieldStart := cb.index
319 // read a field tag
320 _, wireType, err := cb.DecodeTagAndWireType()
321 if err != nil {
322 return 0, 0, err
323 }
324 // skip past the field's data
325 switch wireType {
326 case proto.WireFixed32:
327 if err := cb.Skip(4); err != nil {
328 return 0, 0, err
329 }
330 case proto.WireFixed64:
331 if err := cb.Skip(8); err != nil {
332 return 0, 0, err
333 }
334 case proto.WireVarint:
335 // skip varint by finding last byte (has high bit unset)
336 i := cb.index
337 limit := i + 10 // varint cannot be >10 bytes
338 for {
339 if i >= limit {
340 return 0, 0, ErrOverflow
341 }
342 if i >= len(bs) {
343 return 0, 0, io.ErrUnexpectedEOF
344 }
345 if bs[i]&0x80 == 0 {
346 break
347 }
348 i++
349 }
350 // TODO: This would only overflow if buffer length was MaxInt and we
351 // read the last byte. This is not a real/feasible concern on 64-bit
352 // systems. Something to worry about for 32-bit systems? Do we care?
353 cb.index = i + 1
354 case proto.WireBytes:
355 l, err := cb.DecodeVarint()
356 if err != nil {
357 return 0, 0, err
358 }
359 if err := cb.Skip(int(l)); err != nil {
360 return 0, 0, err
361 }
362 case proto.WireStartGroup:
363 if err := cb.SkipGroup(); err != nil {
364 return 0, 0, err
365 }
366 case proto.WireEndGroup:
367 return cb.index, fieldStart, nil
368 default:
369 return 0, 0, ErrBadWireType
370 }
371 }
372}