| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Helper code for parsing a protocol buffer |
| |
| package protolazy |
| |
| import ( |
| "errors" |
| "fmt" |
| "io" |
| |
| "google.golang.org/protobuf/encoding/protowire" |
| ) |
| |
// BufferReader pairs an encoded protobuf message with a read cursor into it.
type BufferReader struct {
	// Buf holds the encoded bytes being read.
	Buf []byte
	// Pos is the index in Buf of the next byte to be read.
	Pos int
}
| |
| // NewBufferReader creates a new BufferRead from a protobuf |
| func NewBufferReader(buf []byte) BufferReader { |
| return BufferReader{Buf: buf, Pos: 0} |
| } |
| |
// Sentinel errors used by the decoding helpers.
var (
	// errOutOfBounds is not referenced in the visible part of this file;
	// presumably it reports an access outside the buffer (verify against callers).
	errOutOfBounds = errors.New("protobuf decoding: out of bounds")

	// errOverflow is returned when a varint runs past the maximum number of
	// bytes a decoder accepts.
	errOverflow = errors.New("proto: integer overflow")
)
| |
| func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) { |
| i := b.Pos |
| l := len(b.Buf) |
| |
| for shift := uint(0); shift < 64; shift += 7 { |
| if i >= l { |
| err = io.ErrUnexpectedEOF |
| return |
| } |
| v := b.Buf[i] |
| i++ |
| x |= (uint64(v) & 0x7F) << shift |
| if v < 0x80 { |
| b.Pos = i |
| return |
| } |
| } |
| |
| // The number is too large to represent in a 64-bit value. |
| err = errOverflow |
| return |
| } |
| |
| // decodeVarint decodes a varint at the current position |
| func (b *BufferReader) DecodeVarint() (x uint64, err error) { |
| i := b.Pos |
| buf := b.Buf |
| |
| if i >= len(buf) { |
| return 0, io.ErrUnexpectedEOF |
| } else if buf[i] < 0x80 { |
| b.Pos++ |
| return uint64(buf[i]), nil |
| } else if len(buf)-i < 10 { |
| return b.DecodeVarintSlow() |
| } |
| |
| var v uint64 |
| // we already checked the first byte |
| x = uint64(buf[i]) & 127 |
| i++ |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 7 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 14 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 21 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 28 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 35 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 42 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 49 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 56 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint64(buf[i]) |
| i++ |
| x |= (v & 127) << 63 |
| if v < 128 { |
| goto done |
| } |
| |
| return 0, errOverflow |
| |
| done: |
| b.Pos = i |
| return |
| } |
| |
| // decodeVarint32 decodes a varint32 at the current position |
| func (b *BufferReader) DecodeVarint32() (x uint32, err error) { |
| i := b.Pos |
| buf := b.Buf |
| |
| if i >= len(buf) { |
| return 0, io.ErrUnexpectedEOF |
| } else if buf[i] < 0x80 { |
| b.Pos++ |
| return uint32(buf[i]), nil |
| } else if len(buf)-i < 5 { |
| v, err := b.DecodeVarintSlow() |
| return uint32(v), err |
| } |
| |
| var v uint32 |
| // we already checked the first byte |
| x = uint32(buf[i]) & 127 |
| i++ |
| |
| v = uint32(buf[i]) |
| i++ |
| x |= (v & 127) << 7 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint32(buf[i]) |
| i++ |
| x |= (v & 127) << 14 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint32(buf[i]) |
| i++ |
| x |= (v & 127) << 21 |
| if v < 128 { |
| goto done |
| } |
| |
| v = uint32(buf[i]) |
| i++ |
| x |= (v & 127) << 28 |
| if v < 128 { |
| goto done |
| } |
| |
| return 0, errOverflow |
| |
| done: |
| b.Pos = i |
| return |
| } |
| |
| // skipValue skips a value in the protobuf, based on the specified tag |
| func (b *BufferReader) SkipValue(tag uint32) (err error) { |
| wireType := tag & 0x7 |
| switch protowire.Type(wireType) { |
| case protowire.VarintType: |
| err = b.SkipVarint() |
| case protowire.Fixed64Type: |
| err = b.SkipFixed64() |
| case protowire.BytesType: |
| var n uint32 |
| n, err = b.DecodeVarint32() |
| if err == nil { |
| err = b.Skip(int(n)) |
| } |
| case protowire.StartGroupType: |
| err = b.SkipGroup(tag) |
| case protowire.Fixed32Type: |
| err = b.SkipFixed32() |
| default: |
| err = fmt.Errorf("Unexpected wire type (%d)", wireType) |
| } |
| return |
| } |
| |
| // skipGroup skips a group with the specified tag. It executes efficiently using a tag stack |
| func (b *BufferReader) SkipGroup(tag uint32) (err error) { |
| tagStack := make([]uint32, 0, 16) |
| tagStack = append(tagStack, tag) |
| var n uint32 |
| for len(tagStack) > 0 { |
| tag, err = b.DecodeVarint32() |
| if err != nil { |
| return err |
| } |
| switch protowire.Type(tag & 0x7) { |
| case protowire.VarintType: |
| err = b.SkipVarint() |
| case protowire.Fixed64Type: |
| err = b.Skip(8) |
| case protowire.BytesType: |
| n, err = b.DecodeVarint32() |
| if err == nil { |
| err = b.Skip(int(n)) |
| } |
| case protowire.StartGroupType: |
| tagStack = append(tagStack, tag) |
| case protowire.Fixed32Type: |
| err = b.SkipFixed32() |
| case protowire.EndGroupType: |
| if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) { |
| tagStack = tagStack[:len(tagStack)-1] |
| } else { |
| err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d", |
| protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos) |
| } |
| } |
| if err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| // skipVarint effiently skips a varint |
| func (b *BufferReader) SkipVarint() (err error) { |
| i := b.Pos |
| |
| if len(b.Buf)-i < 10 { |
| // Use DecodeVarintSlow() to check for buffer overflow, but ignore result |
| if _, err := b.DecodeVarintSlow(); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| i++ |
| |
| if b.Buf[i] < 0x80 { |
| goto out |
| } |
| return errOverflow |
| |
| out: |
| b.Pos = i + 1 |
| return nil |
| } |
| |
| // skip skips the specified number of bytes |
| func (b *BufferReader) Skip(n int) (err error) { |
| if len(b.Buf) < b.Pos+n { |
| return io.ErrUnexpectedEOF |
| } |
| b.Pos += n |
| return |
| } |
| |
| // skipFixed64 skips a fixed64 |
| func (b *BufferReader) SkipFixed64() (err error) { |
| return b.Skip(8) |
| } |
| |
| // skipFixed32 skips a fixed32 |
| func (b *BufferReader) SkipFixed32() (err error) { |
| return b.Skip(4) |
| } |
| |
| // skipBytes skips a set of bytes |
| func (b *BufferReader) SkipBytes() (err error) { |
| n, err := b.DecodeVarint32() |
| if err != nil { |
| return err |
| } |
| return b.Skip(int(n)) |
| } |
| |
| // Done returns whether we are at the end of the protobuf |
| func (b *BufferReader) Done() bool { |
| return b.Pos == len(b.Buf) |
| } |
| |
| // Remaining returns how many bytes remain |
| func (b *BufferReader) Remaining() int { |
| return len(b.Buf) - b.Pos |
| } |