blob: 82e5cab4aadd15bbf9d0b069942bf897eae0ac77 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
// Helper code for parsing a protocol buffer.
6
7package protolazy
8
9import (
10 "errors"
11 "fmt"
12 "io"
13
14 "google.golang.org/protobuf/encoding/protowire"
15)
16
// BufferReader walks an encoded protobuf message held in memory.
type BufferReader struct {
	// Buf is the encoded message being read.
	Buf []byte
	// Pos is the offset in Buf of the next byte to be read.
	Pos int
}
22
23// NewBufferReader creates a new BufferRead from a protobuf
24func NewBufferReader(buf []byte) BufferReader {
25 return BufferReader{Buf: buf, Pos: 0}
26}
27
// Sentinel errors used by the BufferReader methods.
var (
	// errOutOfBounds signals an access outside the buffer.
	errOutOfBounds = errors.New("protobuf decoding: out of bounds")
	// errOverflow signals a varint whose encoding is longer than the decoder
	// accepts (10 bytes for 64-bit values, 5 bytes on the 32-bit fast path).
	errOverflow = errors.New("proto: integer overflow")
)
30
31func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
32 i := b.Pos
33 l := len(b.Buf)
34
35 for shift := uint(0); shift < 64; shift += 7 {
36 if i >= l {
37 err = io.ErrUnexpectedEOF
38 return
39 }
40 v := b.Buf[i]
41 i++
42 x |= (uint64(v) & 0x7F) << shift
43 if v < 0x80 {
44 b.Pos = i
45 return
46 }
47 }
48
49 // The number is too large to represent in a 64-bit value.
50 err = errOverflow
51 return
52}
53
54// decodeVarint decodes a varint at the current position
55func (b *BufferReader) DecodeVarint() (x uint64, err error) {
56 i := b.Pos
57 buf := b.Buf
58
59 if i >= len(buf) {
60 return 0, io.ErrUnexpectedEOF
61 } else if buf[i] < 0x80 {
62 b.Pos++
63 return uint64(buf[i]), nil
64 } else if len(buf)-i < 10 {
65 return b.DecodeVarintSlow()
66 }
67
68 var v uint64
69 // we already checked the first byte
70 x = uint64(buf[i]) & 127
71 i++
72
73 v = uint64(buf[i])
74 i++
75 x |= (v & 127) << 7
76 if v < 128 {
77 goto done
78 }
79
80 v = uint64(buf[i])
81 i++
82 x |= (v & 127) << 14
83 if v < 128 {
84 goto done
85 }
86
87 v = uint64(buf[i])
88 i++
89 x |= (v & 127) << 21
90 if v < 128 {
91 goto done
92 }
93
94 v = uint64(buf[i])
95 i++
96 x |= (v & 127) << 28
97 if v < 128 {
98 goto done
99 }
100
101 v = uint64(buf[i])
102 i++
103 x |= (v & 127) << 35
104 if v < 128 {
105 goto done
106 }
107
108 v = uint64(buf[i])
109 i++
110 x |= (v & 127) << 42
111 if v < 128 {
112 goto done
113 }
114
115 v = uint64(buf[i])
116 i++
117 x |= (v & 127) << 49
118 if v < 128 {
119 goto done
120 }
121
122 v = uint64(buf[i])
123 i++
124 x |= (v & 127) << 56
125 if v < 128 {
126 goto done
127 }
128
129 v = uint64(buf[i])
130 i++
131 x |= (v & 127) << 63
132 if v < 128 {
133 goto done
134 }
135
136 return 0, errOverflow
137
138done:
139 b.Pos = i
140 return
141}
142
143// decodeVarint32 decodes a varint32 at the current position
144func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
145 i := b.Pos
146 buf := b.Buf
147
148 if i >= len(buf) {
149 return 0, io.ErrUnexpectedEOF
150 } else if buf[i] < 0x80 {
151 b.Pos++
152 return uint32(buf[i]), nil
153 } else if len(buf)-i < 5 {
154 v, err := b.DecodeVarintSlow()
155 return uint32(v), err
156 }
157
158 var v uint32
159 // we already checked the first byte
160 x = uint32(buf[i]) & 127
161 i++
162
163 v = uint32(buf[i])
164 i++
165 x |= (v & 127) << 7
166 if v < 128 {
167 goto done
168 }
169
170 v = uint32(buf[i])
171 i++
172 x |= (v & 127) << 14
173 if v < 128 {
174 goto done
175 }
176
177 v = uint32(buf[i])
178 i++
179 x |= (v & 127) << 21
180 if v < 128 {
181 goto done
182 }
183
184 v = uint32(buf[i])
185 i++
186 x |= (v & 127) << 28
187 if v < 128 {
188 goto done
189 }
190
191 return 0, errOverflow
192
193done:
194 b.Pos = i
195 return
196}
197
198// skipValue skips a value in the protobuf, based on the specified tag
199func (b *BufferReader) SkipValue(tag uint32) (err error) {
200 wireType := tag & 0x7
201 switch protowire.Type(wireType) {
202 case protowire.VarintType:
203 err = b.SkipVarint()
204 case protowire.Fixed64Type:
205 err = b.SkipFixed64()
206 case protowire.BytesType:
207 var n uint32
208 n, err = b.DecodeVarint32()
209 if err == nil {
210 err = b.Skip(int(n))
211 }
212 case protowire.StartGroupType:
213 err = b.SkipGroup(tag)
214 case protowire.Fixed32Type:
215 err = b.SkipFixed32()
216 default:
217 err = fmt.Errorf("Unexpected wire type (%d)", wireType)
218 }
219 return
220}
221
222// skipGroup skips a group with the specified tag. It executes efficiently using a tag stack
223func (b *BufferReader) SkipGroup(tag uint32) (err error) {
224 tagStack := make([]uint32, 0, 16)
225 tagStack = append(tagStack, tag)
226 var n uint32
227 for len(tagStack) > 0 {
228 tag, err = b.DecodeVarint32()
229 if err != nil {
230 return err
231 }
232 switch protowire.Type(tag & 0x7) {
233 case protowire.VarintType:
234 err = b.SkipVarint()
235 case protowire.Fixed64Type:
236 err = b.Skip(8)
237 case protowire.BytesType:
238 n, err = b.DecodeVarint32()
239 if err == nil {
240 err = b.Skip(int(n))
241 }
242 case protowire.StartGroupType:
243 tagStack = append(tagStack, tag)
244 case protowire.Fixed32Type:
245 err = b.SkipFixed32()
246 case protowire.EndGroupType:
247 if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
248 tagStack = tagStack[:len(tagStack)-1]
249 } else {
250 err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
251 protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
252 }
253 }
254 if err != nil {
255 return err
256 }
257 }
258 return nil
259}
260
261// skipVarint effiently skips a varint
262func (b *BufferReader) SkipVarint() (err error) {
263 i := b.Pos
264
265 if len(b.Buf)-i < 10 {
266 // Use DecodeVarintSlow() to check for buffer overflow, but ignore result
267 if _, err := b.DecodeVarintSlow(); err != nil {
268 return err
269 }
270 return nil
271 }
272
273 if b.Buf[i] < 0x80 {
274 goto out
275 }
276 i++
277
278 if b.Buf[i] < 0x80 {
279 goto out
280 }
281 i++
282
283 if b.Buf[i] < 0x80 {
284 goto out
285 }
286 i++
287
288 if b.Buf[i] < 0x80 {
289 goto out
290 }
291 i++
292
293 if b.Buf[i] < 0x80 {
294 goto out
295 }
296 i++
297
298 if b.Buf[i] < 0x80 {
299 goto out
300 }
301 i++
302
303 if b.Buf[i] < 0x80 {
304 goto out
305 }
306 i++
307
308 if b.Buf[i] < 0x80 {
309 goto out
310 }
311 i++
312
313 if b.Buf[i] < 0x80 {
314 goto out
315 }
316 i++
317
318 if b.Buf[i] < 0x80 {
319 goto out
320 }
321 return errOverflow
322
323out:
324 b.Pos = i + 1
325 return nil
326}
327
328// skip skips the specified number of bytes
329func (b *BufferReader) Skip(n int) (err error) {
330 if len(b.Buf) < b.Pos+n {
331 return io.ErrUnexpectedEOF
332 }
333 b.Pos += n
334 return
335}
336
337// skipFixed64 skips a fixed64
338func (b *BufferReader) SkipFixed64() (err error) {
339 return b.Skip(8)
340}
341
342// skipFixed32 skips a fixed32
343func (b *BufferReader) SkipFixed32() (err error) {
344 return b.Skip(4)
345}
346
347// skipBytes skips a set of bytes
348func (b *BufferReader) SkipBytes() (err error) {
349 n, err := b.DecodeVarint32()
350 if err != nil {
351 return err
352 }
353 return b.Skip(int(n))
354}
355
356// Done returns whether we are at the end of the protobuf
357func (b *BufferReader) Done() bool {
358 return b.Pos == len(b.Buf)
359}
360
361// Remaining returns how many bytes remain
362func (b *BufferReader) Remaining() int {
363 return len(b.Buf) - b.Pos
364}