David K. Bainbridge | bd6b288 | 2021-08-26 13:31:02 +0000 | [diff] [blame] | 1 | // Copyright 2019 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package proto |
| 6 | |
| 7 | import ( |
| 8 | "errors" |
| 9 | "fmt" |
| 10 | |
| 11 | "google.golang.org/protobuf/encoding/prototext" |
| 12 | "google.golang.org/protobuf/encoding/protowire" |
| 13 | "google.golang.org/protobuf/runtime/protoimpl" |
| 14 | ) |
| 15 | |
// Wire type identifiers of the protobuf wire format, listed in numeric order.
const (
	WireVarint     = 0 // varint-encoded integer
	WireFixed64    = 1 // 64-bit little-endian integer
	WireBytes      = 2 // length-prefixed bytes
	WireStartGroup = 3 // start group marker
	WireEndGroup   = 4 // end group marker
	WireFixed32    = 5 // 32-bit little-endian integer
)
| 24 | |
| 25 | // EncodeVarint returns the varint encoded bytes of v. |
| 26 | func EncodeVarint(v uint64) []byte { |
| 27 | return protowire.AppendVarint(nil, v) |
| 28 | } |
| 29 | |
| 30 | // SizeVarint returns the length of the varint encoded bytes of v. |
| 31 | // This is equal to len(EncodeVarint(v)). |
| 32 | func SizeVarint(v uint64) int { |
| 33 | return protowire.SizeVarint(v) |
| 34 | } |
| 35 | |
| 36 | // DecodeVarint parses a varint encoded integer from b, |
| 37 | // returning the integer value and the length of the varint. |
| 38 | // It returns (0, 0) if there is a parse error. |
| 39 | func DecodeVarint(b []byte) (uint64, int) { |
| 40 | v, n := protowire.ConsumeVarint(b) |
| 41 | if n < 0 { |
| 42 | return 0, 0 |
| 43 | } |
| 44 | return v, n |
| 45 | } |
| 46 | |
// Buffer holds protobuf wire-format data for encoding and decoding.
// A Buffer can be reused across invocations to reduce memory usage.
type Buffer struct {
	// buf is the underlying byte slice; encoders append to it.
	buf []byte
	// idx is the read offset into buf; buf[idx:] is the unread portion.
	idx int
	// deterministic is forwarded to the marshaler to request
	// deterministic serialization.
	deterministic bool
}
| 54 | |
| 55 | // NewBuffer allocates a new Buffer initialized with buf, |
| 56 | // where the contents of buf are considered the unread portion of the buffer. |
| 57 | func NewBuffer(buf []byte) *Buffer { |
| 58 | return &Buffer{buf: buf} |
| 59 | } |
| 60 | |
| 61 | // SetDeterministic specifies whether to use deterministic serialization. |
| 62 | // |
| 63 | // Deterministic serialization guarantees that for a given binary, equal |
| 64 | // messages will always be serialized to the same bytes. This implies: |
| 65 | // |
| 66 | // - Repeated serialization of a message will return the same bytes. |
| 67 | // - Different processes of the same binary (which may be executing on |
| 68 | // different machines) will serialize equal messages to the same bytes. |
| 69 | // |
| 70 | // Note that the deterministic serialization is NOT canonical across |
| 71 | // languages. It is not guaranteed to remain stable over time. It is unstable |
| 72 | // across different builds with schema changes due to unknown fields. |
| 73 | // Users who need canonical serialization (e.g., persistent storage in a |
| 74 | // canonical form, fingerprinting, etc.) should define their own |
| 75 | // canonicalization specification and implement their own serializer rather |
| 76 | // than relying on this API. |
| 77 | // |
| 78 | // If deterministic serialization is requested, map entries will be sorted |
| 79 | // by keys in lexographical order. This is an implementation detail and |
| 80 | // subject to change. |
| 81 | func (b *Buffer) SetDeterministic(deterministic bool) { |
| 82 | b.deterministic = deterministic |
| 83 | } |
| 84 | |
| 85 | // SetBuf sets buf as the internal buffer, |
| 86 | // where the contents of buf are considered the unread portion of the buffer. |
| 87 | func (b *Buffer) SetBuf(buf []byte) { |
| 88 | b.buf = buf |
| 89 | b.idx = 0 |
| 90 | } |
| 91 | |
| 92 | // Reset clears the internal buffer of all written and unread data. |
| 93 | func (b *Buffer) Reset() { |
| 94 | b.buf = b.buf[:0] |
| 95 | b.idx = 0 |
| 96 | } |
| 97 | |
| 98 | // Bytes returns the internal buffer. |
| 99 | func (b *Buffer) Bytes() []byte { |
| 100 | return b.buf |
| 101 | } |
| 102 | |
| 103 | // Unread returns the unread portion of the buffer. |
| 104 | func (b *Buffer) Unread() []byte { |
| 105 | return b.buf[b.idx:] |
| 106 | } |
| 107 | |
| 108 | // Marshal appends the wire-format encoding of m to the buffer. |
| 109 | func (b *Buffer) Marshal(m Message) error { |
| 110 | var err error |
| 111 | b.buf, err = marshalAppend(b.buf, m, b.deterministic) |
| 112 | return err |
| 113 | } |
| 114 | |
| 115 | // Unmarshal parses the wire-format message in the buffer and |
| 116 | // places the decoded results in m. |
| 117 | // It does not reset m before unmarshaling. |
| 118 | func (b *Buffer) Unmarshal(m Message) error { |
| 119 | err := UnmarshalMerge(b.Unread(), m) |
| 120 | b.idx = len(b.buf) |
| 121 | return err |
| 122 | } |
| 123 | |
| 124 | type unknownFields struct{ XXX_unrecognized protoimpl.UnknownFields } |
| 125 | |
| 126 | func (m *unknownFields) String() string { panic("not implemented") } |
| 127 | func (m *unknownFields) Reset() { panic("not implemented") } |
| 128 | func (m *unknownFields) ProtoMessage() { panic("not implemented") } |
| 129 | |
| 130 | // DebugPrint dumps the encoded bytes of b with a header and footer including s |
| 131 | // to stdout. This is only intended for debugging. |
| 132 | func (*Buffer) DebugPrint(s string, b []byte) { |
| 133 | m := MessageReflect(new(unknownFields)) |
| 134 | m.SetUnknown(b) |
| 135 | b, _ = prototext.MarshalOptions{AllowPartial: true, Indent: "\t"}.Marshal(m.Interface()) |
| 136 | fmt.Printf("==== %s ====\n%s==== %s ====\n", s, b, s) |
| 137 | } |
| 138 | |
| 139 | // EncodeVarint appends an unsigned varint encoding to the buffer. |
| 140 | func (b *Buffer) EncodeVarint(v uint64) error { |
| 141 | b.buf = protowire.AppendVarint(b.buf, v) |
| 142 | return nil |
| 143 | } |
| 144 | |
| 145 | // EncodeZigzag32 appends a 32-bit zig-zag varint encoding to the buffer. |
| 146 | func (b *Buffer) EncodeZigzag32(v uint64) error { |
| 147 | return b.EncodeVarint(uint64((uint32(v) << 1) ^ uint32((int32(v) >> 31)))) |
| 148 | } |
| 149 | |
| 150 | // EncodeZigzag64 appends a 64-bit zig-zag varint encoding to the buffer. |
| 151 | func (b *Buffer) EncodeZigzag64(v uint64) error { |
| 152 | return b.EncodeVarint(uint64((uint64(v) << 1) ^ uint64((int64(v) >> 63)))) |
| 153 | } |
| 154 | |
| 155 | // EncodeFixed32 appends a 32-bit little-endian integer to the buffer. |
| 156 | func (b *Buffer) EncodeFixed32(v uint64) error { |
| 157 | b.buf = protowire.AppendFixed32(b.buf, uint32(v)) |
| 158 | return nil |
| 159 | } |
| 160 | |
| 161 | // EncodeFixed64 appends a 64-bit little-endian integer to the buffer. |
| 162 | func (b *Buffer) EncodeFixed64(v uint64) error { |
| 163 | b.buf = protowire.AppendFixed64(b.buf, uint64(v)) |
| 164 | return nil |
| 165 | } |
| 166 | |
| 167 | // EncodeRawBytes appends a length-prefixed raw bytes to the buffer. |
| 168 | func (b *Buffer) EncodeRawBytes(v []byte) error { |
| 169 | b.buf = protowire.AppendBytes(b.buf, v) |
| 170 | return nil |
| 171 | } |
| 172 | |
| 173 | // EncodeStringBytes appends a length-prefixed raw bytes to the buffer. |
| 174 | // It does not validate whether v contains valid UTF-8. |
| 175 | func (b *Buffer) EncodeStringBytes(v string) error { |
| 176 | b.buf = protowire.AppendString(b.buf, v) |
| 177 | return nil |
| 178 | } |
| 179 | |
| 180 | // EncodeMessage appends a length-prefixed encoded message to the buffer. |
| 181 | func (b *Buffer) EncodeMessage(m Message) error { |
| 182 | var err error |
| 183 | b.buf = protowire.AppendVarint(b.buf, uint64(Size(m))) |
| 184 | b.buf, err = marshalAppend(b.buf, m, b.deterministic) |
| 185 | return err |
| 186 | } |
| 187 | |
| 188 | // DecodeVarint consumes an encoded unsigned varint from the buffer. |
| 189 | func (b *Buffer) DecodeVarint() (uint64, error) { |
| 190 | v, n := protowire.ConsumeVarint(b.buf[b.idx:]) |
| 191 | if n < 0 { |
| 192 | return 0, protowire.ParseError(n) |
| 193 | } |
| 194 | b.idx += n |
| 195 | return uint64(v), nil |
| 196 | } |
| 197 | |
| 198 | // DecodeZigzag32 consumes an encoded 32-bit zig-zag varint from the buffer. |
| 199 | func (b *Buffer) DecodeZigzag32() (uint64, error) { |
| 200 | v, err := b.DecodeVarint() |
| 201 | if err != nil { |
| 202 | return 0, err |
| 203 | } |
| 204 | return uint64((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)), nil |
| 205 | } |
| 206 | |
| 207 | // DecodeZigzag64 consumes an encoded 64-bit zig-zag varint from the buffer. |
| 208 | func (b *Buffer) DecodeZigzag64() (uint64, error) { |
| 209 | v, err := b.DecodeVarint() |
| 210 | if err != nil { |
| 211 | return 0, err |
| 212 | } |
| 213 | return uint64((uint64(v) >> 1) ^ uint64((int64(v&1)<<63)>>63)), nil |
| 214 | } |
| 215 | |
| 216 | // DecodeFixed32 consumes a 32-bit little-endian integer from the buffer. |
| 217 | func (b *Buffer) DecodeFixed32() (uint64, error) { |
| 218 | v, n := protowire.ConsumeFixed32(b.buf[b.idx:]) |
| 219 | if n < 0 { |
| 220 | return 0, protowire.ParseError(n) |
| 221 | } |
| 222 | b.idx += n |
| 223 | return uint64(v), nil |
| 224 | } |
| 225 | |
| 226 | // DecodeFixed64 consumes a 64-bit little-endian integer from the buffer. |
| 227 | func (b *Buffer) DecodeFixed64() (uint64, error) { |
| 228 | v, n := protowire.ConsumeFixed64(b.buf[b.idx:]) |
| 229 | if n < 0 { |
| 230 | return 0, protowire.ParseError(n) |
| 231 | } |
| 232 | b.idx += n |
| 233 | return uint64(v), nil |
| 234 | } |
| 235 | |
| 236 | // DecodeRawBytes consumes a length-prefixed raw bytes from the buffer. |
| 237 | // If alloc is specified, it returns a copy the raw bytes |
| 238 | // rather than a sub-slice of the buffer. |
| 239 | func (b *Buffer) DecodeRawBytes(alloc bool) ([]byte, error) { |
| 240 | v, n := protowire.ConsumeBytes(b.buf[b.idx:]) |
| 241 | if n < 0 { |
| 242 | return nil, protowire.ParseError(n) |
| 243 | } |
| 244 | b.idx += n |
| 245 | if alloc { |
| 246 | v = append([]byte(nil), v...) |
| 247 | } |
| 248 | return v, nil |
| 249 | } |
| 250 | |
| 251 | // DecodeStringBytes consumes a length-prefixed raw bytes from the buffer. |
| 252 | // It does not validate whether the raw bytes contain valid UTF-8. |
| 253 | func (b *Buffer) DecodeStringBytes() (string, error) { |
| 254 | v, n := protowire.ConsumeString(b.buf[b.idx:]) |
| 255 | if n < 0 { |
| 256 | return "", protowire.ParseError(n) |
| 257 | } |
| 258 | b.idx += n |
| 259 | return v, nil |
| 260 | } |
| 261 | |
| 262 | // DecodeMessage consumes a length-prefixed message from the buffer. |
| 263 | // It does not reset m before unmarshaling. |
| 264 | func (b *Buffer) DecodeMessage(m Message) error { |
| 265 | v, err := b.DecodeRawBytes(false) |
| 266 | if err != nil { |
| 267 | return err |
| 268 | } |
| 269 | return UnmarshalMerge(v, m) |
| 270 | } |
| 271 | |
| 272 | // DecodeGroup consumes a message group from the buffer. |
| 273 | // It assumes that the start group marker has already been consumed and |
| 274 | // consumes all bytes until (and including the end group marker). |
| 275 | // It does not reset m before unmarshaling. |
| 276 | func (b *Buffer) DecodeGroup(m Message) error { |
| 277 | v, n, err := consumeGroup(b.buf[b.idx:]) |
| 278 | if err != nil { |
| 279 | return err |
| 280 | } |
| 281 | b.idx += n |
| 282 | return UnmarshalMerge(v, m) |
| 283 | } |
| 284 | |
| 285 | // consumeGroup parses b until it finds an end group marker, returning |
| 286 | // the raw bytes of the message (excluding the end group marker) and the |
| 287 | // the total length of the message (including the end group marker). |
| 288 | func consumeGroup(b []byte) ([]byte, int, error) { |
| 289 | b0 := b |
| 290 | depth := 1 // assume this follows a start group marker |
| 291 | for { |
| 292 | _, wtyp, tagLen := protowire.ConsumeTag(b) |
| 293 | if tagLen < 0 { |
| 294 | return nil, 0, protowire.ParseError(tagLen) |
| 295 | } |
| 296 | b = b[tagLen:] |
| 297 | |
| 298 | var valLen int |
| 299 | switch wtyp { |
| 300 | case protowire.VarintType: |
| 301 | _, valLen = protowire.ConsumeVarint(b) |
| 302 | case protowire.Fixed32Type: |
| 303 | _, valLen = protowire.ConsumeFixed32(b) |
| 304 | case protowire.Fixed64Type: |
| 305 | _, valLen = protowire.ConsumeFixed64(b) |
| 306 | case protowire.BytesType: |
| 307 | _, valLen = protowire.ConsumeBytes(b) |
| 308 | case protowire.StartGroupType: |
| 309 | depth++ |
| 310 | case protowire.EndGroupType: |
| 311 | depth-- |
| 312 | default: |
| 313 | return nil, 0, errors.New("proto: cannot parse reserved wire type") |
| 314 | } |
| 315 | if valLen < 0 { |
| 316 | return nil, 0, protowire.ParseError(valLen) |
| 317 | } |
| 318 | b = b[valLen:] |
| 319 | |
| 320 | if depth == 0 { |
| 321 | return b0[:len(b0)-len(b)-tagLen], len(b0) - len(b), nil |
| 322 | } |
| 323 | } |
| 324 | } |