blob: a25f680f82c9a38949ee8f4928f99f266cbe9d89 [file] [log] [blame]
khenaidooefff76e2021-12-15 16:51:30 -05001package codec
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "math"
8
9 "github.com/golang/protobuf/proto"
10)
11
var (
	// ErrOverflow is returned when a varint-encoded integer is too large
	// to be represented in 64 bits (no terminating byte within ten bytes).
	ErrOverflow = errors.New("proto: integer overflow")

	// ErrBadWireType is returned when a field's wire type is not valid
	// for the requested decoding operation.
	ErrBadWireType = errors.New("proto: bad wiretype")
)
18
19func (cb *Buffer) decodeVarintSlow() (x uint64, err error) {
20 i := cb.index
21 l := len(cb.buf)
22
23 for shift := uint(0); shift < 64; shift += 7 {
24 if i >= l {
25 err = io.ErrUnexpectedEOF
26 return
27 }
28 b := cb.buf[i]
29 i++
30 x |= (uint64(b) & 0x7F) << shift
31 if b < 0x80 {
32 cb.index = i
33 return
34 }
35 }
36
37 // The number is too large to represent in a 64-bit value.
38 err = ErrOverflow
39 return
40}
41
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
//
// On success cb.index is advanced past the varint. On error cb.index is
// left unchanged: io.ErrUnexpectedEOF means the buffer ended before the
// varint did, and ErrOverflow means no terminating byte (high bit clear)
// appeared within ten bytes.
func (cb *Buffer) DecodeVarint() (uint64, error) {
	i := cb.index
	buf := cb.buf

	if i >= len(buf) {
		return 0, io.ErrUnexpectedEOF
	} else if buf[i] < 0x80 {
		// Single-byte varint: the value is the byte itself.
		cb.index++
		return uint64(buf[i]), nil
	} else if len(buf)-i < 10 {
		// Fewer than the maximum ten varint bytes remain, so the
		// unrolled path below could index past the end; use the
		// bounds-checked decoder instead.
		return cb.decodeVarintSlow()
	}

	// Unrolled fast path: at least ten bytes are available, so no
	// per-byte check against len(buf) is needed.
	//
	// Each step adds the whole byte shifted into place. When the byte's
	// continuation bit (0x80) is set, that bit was added as well, so it is
	// subtracted back out before the next step. The net result equals
	// OR-ing the low seven bits of every byte at its shift.
	var b uint64
	// we already checked the first byte: its continuation bit is set,
	// so drop it up front.
	x := uint64(buf[i]) - 0x80
	i++

	b = uint64(buf[i])
	i++
	x += b << 7
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 7

	b = uint64(buf[i])
	i++
	x += b << 14
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 14

	b = uint64(buf[i])
	i++
	x += b << 21
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 21

	b = uint64(buf[i])
	i++
	x += b << 28
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 28

	b = uint64(buf[i])
	i++
	x += b << 35
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 35

	b = uint64(buf[i])
	i++
	x += b << 42
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 42

	b = uint64(buf[i])
	i++
	x += b << 49
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 49

	b = uint64(buf[i])
	i++
	x += b << 56
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 56

	// Tenth byte: only its lowest bit survives the shift by 63.
	b = uint64(buf[i])
	i++
	x += b << 63
	if b&0x80 == 0 {
		goto done
	}
	// x -= 0x80 << 63 // Always zero.

	// A continuation bit on the tenth byte means the value needs more
	// than 64 bits; cb.index is not advanced.
	return 0, ErrOverflow

done:
	cb.index = i
	return x, nil
}
142
143// DecodeTagAndWireType decodes a field tag and wire type from input.
144// This reads a varint and then extracts the two fields from the varint
145// value read.
146func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) {
147 var v uint64
148 v, err = cb.DecodeVarint()
149 if err != nil {
150 return
151 }
152 // low 7 bits is wire type
153 wireType = int8(v & 7)
154 // rest is int32 tag number
155 v = v >> 3
156 if v > math.MaxInt32 {
157 err = fmt.Errorf("tag number out of range: %d", v)
158 return
159 }
160 tag = int32(v)
161 return
162}
163
164// DecodeFixed64 reads a 64-bit integer from the Buffer.
165// This is the format for the
166// fixed64, sfixed64, and double protocol buffer types.
167func (cb *Buffer) DecodeFixed64() (x uint64, err error) {
168 // x, err already 0
169 i := cb.index + 8
170 if i < 0 || i > len(cb.buf) {
171 err = io.ErrUnexpectedEOF
172 return
173 }
174 cb.index = i
175
176 x = uint64(cb.buf[i-8])
177 x |= uint64(cb.buf[i-7]) << 8
178 x |= uint64(cb.buf[i-6]) << 16
179 x |= uint64(cb.buf[i-5]) << 24
180 x |= uint64(cb.buf[i-4]) << 32
181 x |= uint64(cb.buf[i-3]) << 40
182 x |= uint64(cb.buf[i-2]) << 48
183 x |= uint64(cb.buf[i-1]) << 56
184 return
185}
186
187// DecodeFixed32 reads a 32-bit integer from the Buffer.
188// This is the format for the
189// fixed32, sfixed32, and float protocol buffer types.
190func (cb *Buffer) DecodeFixed32() (x uint64, err error) {
191 // x, err already 0
192 i := cb.index + 4
193 if i < 0 || i > len(cb.buf) {
194 err = io.ErrUnexpectedEOF
195 return
196 }
197 cb.index = i
198
199 x = uint64(cb.buf[i-4])
200 x |= uint64(cb.buf[i-3]) << 8
201 x |= uint64(cb.buf[i-2]) << 16
202 x |= uint64(cb.buf[i-1]) << 24
203 return
204}
205
206// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
207// This is the format used for the bytes protocol buffer
208// type and for embedded messages.
209func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
210 n, err := cb.DecodeVarint()
211 if err != nil {
212 return nil, err
213 }
214
215 nb := int(n)
216 if nb < 0 {
217 return nil, fmt.Errorf("proto: bad byte length %d", nb)
218 }
219 end := cb.index + nb
220 if end < cb.index || end > len(cb.buf) {
221 return nil, io.ErrUnexpectedEOF
222 }
223
224 if !alloc {
225 buf = cb.buf[cb.index:end]
226 cb.index = end
227 return
228 }
229
230 buf = make([]byte, nb)
231 copy(buf, cb.buf[cb.index:])
232 cb.index = end
233 return
234}
235
236// ReadGroup reads the input until a "group end" tag is found
237// and returns the data up to that point. Subsequent reads from
238// the buffer will read data after the group end tag. If alloc
239// is true, the data is copied to a new slice before being returned.
240// Otherwise, the returned slice is a view into the buffer's
241// underlying byte slice.
242//
243// This function correctly handles nested groups: if a "group start"
244// tag is found, then that group's end tag will be included in the
245// returned data.
246func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) {
247 var groupEnd, dataEnd int
248 groupEnd, dataEnd, err := cb.findGroupEnd()
249 if err != nil {
250 return nil, err
251 }
252 var results []byte
253 if !alloc {
254 results = cb.buf[cb.index:dataEnd]
255 } else {
256 results = make([]byte, dataEnd-cb.index)
257 copy(results, cb.buf[cb.index:])
258 }
259 cb.index = groupEnd
260 return results, nil
261}
262
263// SkipGroup is like ReadGroup, except that it discards the
264// data and just advances the buffer to point to the input
265// right *after* the "group end" tag.
266func (cb *Buffer) SkipGroup() error {
267 groupEnd, _, err := cb.findGroupEnd()
268 if err != nil {
269 return err
270 }
271 cb.index = groupEnd
272 return nil
273}
274
275// SkipField attempts to skip the value of a field with the given wire
276// type. When consuming a protobuf-encoded stream, it can be called immediately
277// after DecodeTagAndWireType to discard the subsequent data for the field.
278func (cb *Buffer) SkipField(wireType int8) error {
279 switch wireType {
280 case proto.WireFixed32:
281 if err := cb.Skip(4); err != nil {
282 return err
283 }
284 case proto.WireFixed64:
285 if err := cb.Skip(8); err != nil {
286 return err
287 }
288 case proto.WireVarint:
289 // skip varint by finding last byte (has high bit unset)
290 i := cb.index
291 limit := i + 10 // varint cannot be >10 bytes
292 for {
293 if i >= limit {
294 return ErrOverflow
295 }
296 if i >= len(cb.buf) {
297 return io.ErrUnexpectedEOF
298 }
299 if cb.buf[i]&0x80 == 0 {
300 break
301 }
302 i++
303 }
304 // TODO: This would only overflow if buffer length was MaxInt and we
305 // read the last byte. This is not a real/feasible concern on 64-bit
306 // systems. Something to worry about for 32-bit systems? Do we care?
307 cb.index = i + 1
308 case proto.WireBytes:
309 l, err := cb.DecodeVarint()
310 if err != nil {
311 return err
312 }
313 if err := cb.Skip(int(l)); err != nil {
314 return err
315 }
316 case proto.WireStartGroup:
317 if err := cb.SkipGroup(); err != nil {
318 return err
319 }
320 default:
321 return ErrBadWireType
322 }
323 return nil
324}
325
326func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) {
327 start := cb.index
328 defer func() {
329 cb.index = start
330 }()
331 for {
332 fieldStart := cb.index
333 // read a field tag
334 _, wireType, err := cb.DecodeTagAndWireType()
335 if err != nil {
336 return 0, 0, err
337 }
338 if wireType == proto.WireEndGroup {
339 return cb.index, fieldStart, nil
340 }
341 // skip past the field's data
342 if err := cb.SkipField(wireType); err != nil {
343 return 0, 0, err
344 }
345 }
346}