blob: e810e6fea129d47402785bf2850cc8505d51e4c9 [file] [log] [blame]
khenaidoo26721882021-08-11 17:42:52 -04001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "errors"
9 "fmt"
10
11 "google.golang.org/protobuf/encoding/prototext"
12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/runtime/protoimpl"
14)
15
// Wire types of the protobuf binary encoding, listed in ascending numeric
// order. The values are fixed by the wire format.
const (
	WireVarint     = 0 // variable-length integer
	WireFixed64    = 1 // 64-bit little-endian value
	WireBytes      = 2 // length-delimited bytes
	WireStartGroup = 3 // start of a group
	WireEndGroup   = 4 // end of a group
	WireFixed32    = 5 // 32-bit little-endian value
)
24
25// EncodeVarint returns the varint encoded bytes of v.
26func EncodeVarint(v uint64) []byte {
27 return protowire.AppendVarint(nil, v)
28}
29
30// SizeVarint returns the length of the varint encoded bytes of v.
31// This is equal to len(EncodeVarint(v)).
32func SizeVarint(v uint64) int {
33 return protowire.SizeVarint(v)
34}
35
36// DecodeVarint parses a varint encoded integer from b,
37// returning the integer value and the length of the varint.
38// It returns (0, 0) if there is a parse error.
39func DecodeVarint(b []byte) (uint64, int) {
40 v, n := protowire.ConsumeVarint(b)
41 if n < 0 {
42 return 0, 0
43 }
44 return v, n
45}
46
// Buffer is a buffer for encoding and decoding the protobuf wire format.
// It may be reused between invocations to reduce memory usage.
type Buffer struct {
	buf           []byte // encoded data; buf[idx:] is the not-yet-read portion
	idx           int    // read offset into buf, advanced by the Decode methods
	deterministic bool   // passed to the marshaler by Marshal and EncodeMessage
}
54
55// NewBuffer allocates a new Buffer initialized with buf,
56// where the contents of buf are considered the unread portion of the buffer.
57func NewBuffer(buf []byte) *Buffer {
58 return &Buffer{buf: buf}
59}
60
61// SetDeterministic specifies whether to use deterministic serialization.
62//
63// Deterministic serialization guarantees that for a given binary, equal
64// messages will always be serialized to the same bytes. This implies:
65//
66// - Repeated serialization of a message will return the same bytes.
67// - Different processes of the same binary (which may be executing on
68// different machines) will serialize equal messages to the same bytes.
69//
70// Note that the deterministic serialization is NOT canonical across
71// languages. It is not guaranteed to remain stable over time. It is unstable
72// across different builds with schema changes due to unknown fields.
73// Users who need canonical serialization (e.g., persistent storage in a
74// canonical form, fingerprinting, etc.) should define their own
75// canonicalization specification and implement their own serializer rather
76// than relying on this API.
77//
78// If deterministic serialization is requested, map entries will be sorted
79// by keys in lexographical order. This is an implementation detail and
80// subject to change.
81func (b *Buffer) SetDeterministic(deterministic bool) {
82 b.deterministic = deterministic
83}
84
85// SetBuf sets buf as the internal buffer,
86// where the contents of buf are considered the unread portion of the buffer.
87func (b *Buffer) SetBuf(buf []byte) {
88 b.buf = buf
89 b.idx = 0
90}
91
92// Reset clears the internal buffer of all written and unread data.
93func (b *Buffer) Reset() {
94 b.buf = b.buf[:0]
95 b.idx = 0
96}
97
98// Bytes returns the internal buffer.
99func (b *Buffer) Bytes() []byte {
100 return b.buf
101}
102
103// Unread returns the unread portion of the buffer.
104func (b *Buffer) Unread() []byte {
105 return b.buf[b.idx:]
106}
107
108// Marshal appends the wire-format encoding of m to the buffer.
109func (b *Buffer) Marshal(m Message) error {
110 var err error
111 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
112 return err
113}
114
115// Unmarshal parses the wire-format message in the buffer and
116// places the decoded results in m.
117// It does not reset m before unmarshaling.
118func (b *Buffer) Unmarshal(m Message) error {
119 err := UnmarshalMerge(b.Unread(), m)
120 b.idx = len(b.buf)
121 return err
122}
123
124type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields }
125
126func (m *unknownFields) String() string { panic("not implemented") }
127func (m *unknownFields) Reset() { panic("not implemented") }
128func (m *unknownFields) ProtoMessage() { panic("not implemented") }
129
130// DebugPrint dumps the encoded bytes of b with a header and footer including s
131// to stdout. This is only intended for debugging.
132func (*Buffer) DebugPrint(s string, b []byte) {
133 m := MessageReflect(new(unknownFields))
134 m.SetUnknown(b)
135 b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface())
136 fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s)
137}
138
139// EncodeVarint appends an unsigned varint encoding to the buffer.
140func (b *Buffer) EncodeVarint(v uint64) error {
141 b.buf = protowire.AppendVarint(b.buf, v)
142 return nil
143}
144
145// EncodeZigzag32 appends a 32-bit zig-zag varint encoding to the buffer.
146func (b *Buffer) EncodeZigzag32(v uint64) error {
147 return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31))))
148}
149
150// EncodeZigzag64 appends a 64-bit zig-zag varint encoding to the buffer.
151func (b *Buffer) EncodeZigzag64(v uint64) error {
152 return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63))))
153}
154
155// EncodeFixed32 appends a 32-bit little-endian integer to the buffer.
156func (b *Buffer) EncodeFixed32(v uint64) error {
157 b.buf = protowire.AppendFixed32(b.buf, uint32(v))
158 return nil
159}
160
161// EncodeFixed64 appends a 64-bit little-endian integer to the buffer.
162func (b *Buffer) EncodeFixed64(v uint64) error {
163 b.buf = protowire.AppendFixed64(b.buf, uint64(v))
164 return nil
165}
166
167// EncodeRawBytes appends a length-prefixed raw bytes to the buffer.
168func (b *Buffer) EncodeRawBytes(v []byte) error {
169 b.buf = protowire.AppendBytes(b.buf, v)
170 return nil
171}
172
173// EncodeStringBytes appends a length-prefixed raw bytes to the buffer.
174// It does not validate whether v contains valid UTF-8.
175func (b *Buffer) EncodeStringBytes(v string) error {
176 b.buf = protowire.AppendString(b.buf, v)
177 return nil
178}
179
180// EncodeMessage appends a length-prefixed encoded message to the buffer.
181func (b *Buffer) EncodeMessage(m Message) error {
182 var err error
183 b.buf = protowire.AppendVarint(b.buf, uint64(Size(m)))
184 b.buf, err = marshalAppend(b.buf, m, b.deterministic)
185 return err
186}
187
188// DecodeVarint consumes an encoded unsigned varint from the buffer.
189func (b *Buffer) DecodeVarint() (uint64, error) {
190 v, n := protowire.ConsumeVarint(b.buf[b.idx:])
191 if n < 0 {
192 return 0, protowire.ParseError(n)
193 }
194 b.idx += n
195 return uint64(v), nil
196}
197
198// DecodeZigzag32 consumes an encoded 32-bit zig-zag varint from the buffer.
199func (b *Buffer) DecodeZigzag32() (uint64, error) {
200 v, err := b.DecodeVarint()
201 if err != nil {
202 return 0, err
203 }
204 return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil
205}
206
207// DecodeZigzag64 consumes an encoded 64-bit zig-zag varint from the buffer.
208func (b *Buffer) DecodeZigzag64() (uint64, error) {
209 v, err := b.DecodeVarint()
210 if err != nil {
211 return 0, err
212 }
213 return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil
214}
215
216// DecodeFixed32 consumes a 32-bit little-endian integer from the buffer.
217func (b *Buffer) DecodeFixed32() (uint64, error) {
218 v, n := protowire.ConsumeFixed32(b.buf[b.idx:])
219 if n < 0 {
220 return 0, protowire.ParseError(n)
221 }
222 b.idx += n
223 return uint64(v), nil
224}
225
226// DecodeFixed64 consumes a 64-bit little-endian integer from the buffer.
227func (b *Buffer) DecodeFixed64() (uint64, error) {
228 v, n := protowire.ConsumeFixed64(b.buf[b.idx:])
229 if n < 0 {
230 return 0, protowire.ParseError(n)
231 }
232 b.idx += n
233 return uint64(v), nil
234}
235
236// DecodeRawBytes consumes a length-prefixed raw bytes from the buffer.
237// If alloc is specified, it returns a copy the raw bytes
238// rather than a sub-slice of the buffer.
239func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) {
240 v, n := protowire.ConsumeBytes(b.buf[b.idx:])
241 if n < 0 {
242 return nil, protowire.ParseError(n)
243 }
244 b.idx += n
245 if alloc {
246 v = append([]byte(nil), v...)
247 }
248 return v, nil
249}
250
251// DecodeStringBytes consumes a length-prefixed raw bytes from the buffer.
252// It does not validate whether the raw bytes contain valid UTF-8.
253func (b *Buffer) DecodeStringBytes() (string, error) {
254 v, n := protowire.ConsumeString(b.buf[b.idx:])
255 if n < 0 {
256 return "", protowire.ParseError(n)
257 }
258 b.idx += n
259 return v, nil
260}
261
262// DecodeMessage consumes a length-prefixed message from the buffer.
263// It does not reset m before unmarshaling.
264func (b *Buffer) DecodeMessage(m Message) error {
265 v, err := b.DecodeRawBytes(false)
266 if err != nil {
267 return err
268 }
269 return UnmarshalMerge(v, m)
270}
271
272// DecodeGroup consumes a message group from the buffer.
273// It assumes that the start group marker has already been consumed and
274// consumes all bytes until (and including the end group marker).
275// It does not reset m before unmarshaling.
276func (b *Buffer) DecodeGroup(m Message) error {
277 v, n, err := consumeGroup(b.buf[b.idx:])
278 if err != nil {
279 return err
280 }
281 b.idx += n
282 return UnmarshalMerge(v, m)
283}
284
285// consumeGroup parses b until it finds an end group marker, returning
286// the raw bytes of the message (excluding the end group marker) and the
287// the total length of the message (including the end group marker).
288func consumeGroup(b []byte) ([]byte, int, error) {
289 b0 := b
290 depth := 1 // assume this follows a start group marker
291 for {
292 _, wtyp, tagLen := protowire.ConsumeTag(b)
293 if tagLen < 0 {
294 return nil, 0, protowire.ParseError(tagLen)
295 }
296 b = b[tagLen:]
297
298 var valLen int
299 switch wtyp {
300 case protowire.VarintType:
301 _, valLen = protowire.ConsumeVarint(b)
302 case protowire.Fixed32Type:
303 _, valLen = protowire.ConsumeFixed32(b)
304 case protowire.Fixed64Type:
305 _, valLen = protowire.ConsumeFixed64(b)
306 case protowire.BytesType:
307 _, valLen = protowire.ConsumeBytes(b)
308 case protowire.StartGroupType:
309 depth++
310 case protowire.EndGroupType:
311 depth--
312 default:
313 return nil, 0, errors.New("proto: cannot parse reserved wire type")
314 }
315 if valLen < 0 {
316 return nil, 0, protowire.ParseError(valLen)
317 }
318 b = b[valLen:]
319
320 if depth == 0 {
321 return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil
322 }
323 }
324}