blob: b1fbe7cbc6d6c6d817e54d7ea164d0726d6f8f95 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package dynamic
2
3// Binary serialization and de-serialization for dynamic messages
4
5import (
6 "fmt"
7 "io"
8 "math"
9 "reflect"
10 "sort"
11
12 "github.com/golang/protobuf/proto"
13 "github.com/golang/protobuf/protoc-gen-go/descriptor"
14
15 "github.com/jhump/protoreflect/desc"
16)
17
// defaultDeterminism selects whether Marshal and MarshalAppend produce
// deterministic output. Callers that go through proto.Marshal(...) have no way
// to express determinism intent, so this package-level switch stands in for it.
// **This is only used from tests.**
var defaultDeterminism bool
23
24// Marshal serializes this message to bytes, returning an error if the operation
25// fails. The resulting bytes are in the standard protocol buffer binary format.
26func (m *Message) Marshal() ([]byte, error) {
27 var b codedBuffer
28 if err := m.marshal(&b, defaultDeterminism); err != nil {
29 return nil, err
30 }
31 return b.buf, nil
32}
33
34// MarshalAppend behaves exactly the same as Marshal, except instead of allocating a
35// new byte slice to marshal into, it uses the provided byte slice. The backing array
36// for the returned byte slice *may* be the same as the one that was passed in, but
37// it's not guaranteed as a new backing array will automatically be allocated if
38// more bytes need to be written than the provided buffer has capacity for.
39func (m *Message) MarshalAppend(b []byte) ([]byte, error) {
40 codedBuf := codedBuffer{buf: b}
41 if err := m.marshal(&codedBuf, defaultDeterminism); err != nil {
42 return nil, err
43 }
44 return codedBuf.buf, nil
45}
46
47// MarshalDeterministic serializes this message to bytes in a deterministic way,
48// returning an error if the operation fails. This differs from Marshal in that
49// map keys will be sorted before serializing to bytes. The protobuf spec does
50// not define ordering for map entries, so Marshal will use standard Go map
51// iteration order (which will be random). But for cases where determinism is
52// more important than performance, use this method instead.
53func (m *Message) MarshalDeterministic() ([]byte, error) {
54 var b codedBuffer
55 if err := m.marshal(&b, true); err != nil {
56 return nil, err
57 }
58 return b.buf, nil
59}
60
61func (m *Message) marshal(b *codedBuffer, deterministic bool) error {
62 if err := m.marshalKnownFields(b, deterministic); err != nil {
63 return err
64 }
65 return m.marshalUnknownFields(b)
66}
67
68func (m *Message) marshalKnownFields(b *codedBuffer, deterministic bool) error {
69 for _, tag := range m.knownFieldTags() {
70 itag := int32(tag)
71 val := m.values[itag]
72 fd := m.FindFieldDescriptor(itag)
73 if fd == nil {
74 panic(fmt.Sprintf("Couldn't find field for tag %d", itag))
75 }
76 if err := marshalField(itag, fd, val, b, deterministic); err != nil {
77 return err
78 }
79 }
80 return nil
81}
82
83func (m *Message) marshalUnknownFields(b *codedBuffer) error {
84 for _, tag := range m.unknownFieldTags() {
85 itag := int32(tag)
86 sl := m.unknownFields[itag]
87 for _, u := range sl {
88 if err := b.encodeTagAndWireType(itag, u.Encoding); err != nil {
89 return err
90 }
91 switch u.Encoding {
92 case proto.WireBytes:
93 if err := b.encodeRawBytes(u.Contents); err != nil {
94 return err
95 }
96 case proto.WireStartGroup:
97 b.buf = append(b.buf, u.Contents...)
98 if err := b.encodeTagAndWireType(itag, proto.WireEndGroup); err != nil {
99 return err
100 }
101 case proto.WireFixed32:
102 if err := b.encodeFixed32(u.Value); err != nil {
103 return err
104 }
105 case proto.WireFixed64:
106 if err := b.encodeFixed64(u.Value); err != nil {
107 return err
108 }
109 case proto.WireVarint:
110 if err := b.encodeVarint(u.Value); err != nil {
111 return err
112 }
113 default:
114 return proto.ErrInternalBadWireType
115 }
116 }
117 }
118 return nil
119}
120
121func marshalField(tag int32, fd *desc.FieldDescriptor, val interface{}, b *codedBuffer, deterministic bool) error {
122 if fd.IsMap() {
123 mp := val.(map[interface{}]interface{})
124 entryType := fd.GetMessageType()
125 keyType := entryType.FindFieldByNumber(1)
126 valType := entryType.FindFieldByNumber(2)
127 var entryBuffer codedBuffer
128 if deterministic {
129 keys := make([]interface{}, 0, len(mp))
130 for k := range mp {
131 keys = append(keys, k)
132 }
133 sort.Sort(sortable(keys))
134 for _, k := range keys {
135 v := mp[k]
136 entryBuffer.reset()
137 if err := marshalFieldElement(1, keyType, k, &entryBuffer, deterministic); err != nil {
138 return err
139 }
140 if err := marshalFieldElement(2, valType, v, &entryBuffer, deterministic); err != nil {
141 return err
142 }
143 if err := b.encodeTagAndWireType(tag, proto.WireBytes); err != nil {
144 return err
145 }
146 if err := b.encodeRawBytes(entryBuffer.buf); err != nil {
147 return err
148 }
149 }
150 } else {
151 for k, v := range mp {
152 entryBuffer.reset()
153 if err := marshalFieldElement(1, keyType, k, &entryBuffer, deterministic); err != nil {
154 return err
155 }
156 if err := marshalFieldElement(2, valType, v, &entryBuffer, deterministic); err != nil {
157 return err
158 }
159 if err := b.encodeTagAndWireType(tag, proto.WireBytes); err != nil {
160 return err
161 }
162 if err := b.encodeRawBytes(entryBuffer.buf); err != nil {
163 return err
164 }
165 }
166 }
167 return nil
168 } else if fd.IsRepeated() {
169 sl := val.([]interface{})
170 wt, err := getWireType(fd.GetType())
171 if err != nil {
172 return err
173 }
174 if isPacked(fd) && len(sl) > 1 &&
175 (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) {
176 // packed repeated field
177 var packedBuffer codedBuffer
178 for _, v := range sl {
179 if err := marshalFieldValue(fd, v, &packedBuffer, deterministic); err != nil {
180 return err
181 }
182 }
183 if err := b.encodeTagAndWireType(tag, proto.WireBytes); err != nil {
184 return err
185 }
186 return b.encodeRawBytes(packedBuffer.buf)
187 } else {
188 // non-packed repeated field
189 for _, v := range sl {
190 if err := marshalFieldElement(tag, fd, v, b, deterministic); err != nil {
191 return err
192 }
193 }
194 return nil
195 }
196 } else {
197 return marshalFieldElement(tag, fd, val, b, deterministic)
198 }
199}
200
201func isPacked(fd *desc.FieldDescriptor) bool {
202 opts := fd.AsFieldDescriptorProto().GetOptions()
203 // if set, use that value
204 if opts != nil && opts.Packed != nil {
205 return opts.GetPacked()
206 }
207 // if unset: proto2 defaults to false, proto3 to true
208 return fd.GetFile().IsProto3()
209}
210
// sortable implements sort.Interface over a slice of map keys. All elements
// are expected to share a single type: an integer (int32, int64, uint32, or
// uint64), a bool, or a string.
type sortable []interface{}

// Len returns the number of keys.
func (s sortable) Len() int {
	return len(s)
}

// Less reports whether the key at index i sorts before the key at index j.
// Numbers and strings use their natural ordering and, for bools, false sorts
// before true (the same order golang/protobuf uses when marshaling map keys
// deterministically). It panics if the keys have an unsupported type or if the
// two keys do not have the same type.
func (s sortable) Less(i, j int) bool {
	vi := s[i]
	vj := s[j]
	switch reflect.TypeOf(vi).Kind() {
	case reflect.Int32:
		return vi.(int32) < vj.(int32)
	case reflect.Int64:
		return vi.(int64) < vj.(int64)
	case reflect.Uint32:
		return vi.(uint32) < vj.(uint32)
	case reflect.Uint64:
		return vi.(uint64) < vj.(uint64)
	case reflect.String:
		return vi.(string) < vj.(string)
	case reflect.Bool:
		// false < true
		return !vi.(bool) && vj.(bool)
	default:
		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
	}
}

// Swap exchanges the keys at indexes i and j.
func (s sortable) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}
243
244func marshalFieldElement(tag int32, fd *desc.FieldDescriptor, val interface{}, b *codedBuffer, deterministic bool) error {
245 wt, err := getWireType(fd.GetType())
246 if err != nil {
247 return err
248 }
249 if err := b.encodeTagAndWireType(tag, wt); err != nil {
250 return err
251 }
252 if err := marshalFieldValue(fd, val, b, deterministic); err != nil {
253 return err
254 }
255 if wt == proto.WireStartGroup {
256 return b.encodeTagAndWireType(tag, proto.WireEndGroup)
257 }
258 return nil
259}
260
261func marshalFieldValue(fd *desc.FieldDescriptor, val interface{}, b *codedBuffer, deterministic bool) error {
262 switch fd.GetType() {
263 case descriptor.FieldDescriptorProto_TYPE_BOOL:
264 v := val.(bool)
265 if v {
266 return b.encodeVarint(1)
267 } else {
268 return b.encodeVarint(0)
269 }
270
271 case descriptor.FieldDescriptorProto_TYPE_ENUM,
272 descriptor.FieldDescriptorProto_TYPE_INT32:
273 v := val.(int32)
274 return b.encodeVarint(uint64(v))
275
276 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
277 v := val.(int32)
278 return b.encodeFixed32(uint64(v))
279
280 case descriptor.FieldDescriptorProto_TYPE_SINT32:
281 v := val.(int32)
282 return b.encodeVarint(encodeZigZag32(v))
283
284 case descriptor.FieldDescriptorProto_TYPE_UINT32:
285 v := val.(uint32)
286 return b.encodeVarint(uint64(v))
287
288 case descriptor.FieldDescriptorProto_TYPE_FIXED32:
289 v := val.(uint32)
290 return b.encodeFixed32(uint64(v))
291
292 case descriptor.FieldDescriptorProto_TYPE_INT64:
293 v := val.(int64)
294 return b.encodeVarint(uint64(v))
295
296 case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
297 v := val.(int64)
298 return b.encodeFixed64(uint64(v))
299
300 case descriptor.FieldDescriptorProto_TYPE_SINT64:
301 v := val.(int64)
302 return b.encodeVarint(encodeZigZag64(v))
303
304 case descriptor.FieldDescriptorProto_TYPE_UINT64:
305 v := val.(uint64)
306 return b.encodeVarint(v)
307
308 case descriptor.FieldDescriptorProto_TYPE_FIXED64:
309 v := val.(uint64)
310 return b.encodeFixed64(v)
311
312 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
313 v := val.(float64)
314 return b.encodeFixed64(math.Float64bits(v))
315
316 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
317 v := val.(float32)
318 return b.encodeFixed32(uint64(math.Float32bits(v)))
319
320 case descriptor.FieldDescriptorProto_TYPE_BYTES:
321 v := val.([]byte)
322 return b.encodeRawBytes(v)
323
324 case descriptor.FieldDescriptorProto_TYPE_STRING:
325 v := val.(string)
326 return b.encodeRawBytes(([]byte)(v))
327
328 case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
329 m := val.(proto.Message)
330 if bytes, err := proto.Marshal(m); err != nil {
331 return err
332 } else {
333 return b.encodeRawBytes(bytes)
334 }
335
336 case descriptor.FieldDescriptorProto_TYPE_GROUP:
337 // just append the nested message to this buffer
338 dm, ok := val.(*Message)
339 if ok {
340 return dm.marshal(b, deterministic)
341 } else {
342 m := val.(proto.Message)
343 return b.encodeMessage(m)
344 }
345 // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag
346
347 default:
348 return fmt.Errorf("unrecognized field type: %v", fd.GetType())
349 }
350}
351
352func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) {
353 switch t {
354 case descriptor.FieldDescriptorProto_TYPE_ENUM,
355 descriptor.FieldDescriptorProto_TYPE_BOOL,
356 descriptor.FieldDescriptorProto_TYPE_INT32,
357 descriptor.FieldDescriptorProto_TYPE_SINT32,
358 descriptor.FieldDescriptorProto_TYPE_UINT32,
359 descriptor.FieldDescriptorProto_TYPE_INT64,
360 descriptor.FieldDescriptorProto_TYPE_SINT64,
361 descriptor.FieldDescriptorProto_TYPE_UINT64:
362 return proto.WireVarint, nil
363
364 case descriptor.FieldDescriptorProto_TYPE_FIXED32,
365 descriptor.FieldDescriptorProto_TYPE_SFIXED32,
366 descriptor.FieldDescriptorProto_TYPE_FLOAT:
367 return proto.WireFixed32, nil
368
369 case descriptor.FieldDescriptorProto_TYPE_FIXED64,
370 descriptor.FieldDescriptorProto_TYPE_SFIXED64,
371 descriptor.FieldDescriptorProto_TYPE_DOUBLE:
372 return proto.WireFixed64, nil
373
374 case descriptor.FieldDescriptorProto_TYPE_BYTES,
375 descriptor.FieldDescriptorProto_TYPE_STRING,
376 descriptor.FieldDescriptorProto_TYPE_MESSAGE:
377 return proto.WireBytes, nil
378
379 case descriptor.FieldDescriptorProto_TYPE_GROUP:
380 return proto.WireStartGroup, nil
381
382 default:
383 return 0, proto.ErrInternalBadWireType
384 }
385}
386
387// Unmarshal de-serializes the message that is present in the given bytes into
388// this message. It first resets the current message. It returns an error if the
389// given bytes do not contain a valid encoding of this message type.
390func (m *Message) Unmarshal(b []byte) error {
391 m.Reset()
392 if err := m.UnmarshalMerge(b); err != nil {
393 return err
394 }
395 return m.Validate()
396}
397
398// UnmarshalMerge de-serializes the message that is present in the given bytes
399// into this message. Unlike Unmarshal, it does not first reset the message,
400// instead merging the data in the given bytes into the existing data in this
401// message.
402func (m *Message) UnmarshalMerge(b []byte) error {
403 return m.unmarshal(newCodedBuffer(b), false)
404}
405
406func (m *Message) unmarshal(buf *codedBuffer, isGroup bool) error {
407 for !buf.eof() {
408 tagNumber, wireType, err := buf.decodeTagAndWireType()
409 if err != nil {
410 return err
411 }
412 if wireType == proto.WireEndGroup {
413 if isGroup {
414 // finished parsing group
415 return nil
416 } else {
417 return proto.ErrInternalBadWireType
418 }
419 }
420 fd := m.FindFieldDescriptor(tagNumber)
421 if fd == nil {
422 err := m.unmarshalUnknownField(tagNumber, wireType, buf)
423 if err != nil {
424 return err
425 }
426 } else {
427 err := m.unmarshalKnownField(fd, wireType, buf)
428 if err != nil {
429 return err
430 }
431 }
432 }
433 if isGroup {
434 return io.ErrUnexpectedEOF
435 }
436 return nil
437}
438
439func unmarshalSimpleField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
440 switch fd.GetType() {
441 case descriptor.FieldDescriptorProto_TYPE_BOOL:
442 return v != 0, nil
443 case descriptor.FieldDescriptorProto_TYPE_UINT32,
444 descriptor.FieldDescriptorProto_TYPE_FIXED32:
445 if v > math.MaxUint32 {
446 return nil, NumericOverflowError
447 }
448 return uint32(v), nil
449
450 case descriptor.FieldDescriptorProto_TYPE_INT32,
451 descriptor.FieldDescriptorProto_TYPE_ENUM:
452 s := int64(v)
453 if s > math.MaxInt32 || s < math.MinInt32 {
454 return nil, NumericOverflowError
455 }
456 return int32(s), nil
457
458 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
459 if v > math.MaxUint32 {
460 return nil, NumericOverflowError
461 }
462 return int32(v), nil
463
464 case descriptor.FieldDescriptorProto_TYPE_SINT32:
465 if v > math.MaxUint32 {
466 return nil, NumericOverflowError
467 }
468 return decodeZigZag32(v), nil
469
470 case descriptor.FieldDescriptorProto_TYPE_UINT64,
471 descriptor.FieldDescriptorProto_TYPE_FIXED64:
472 return v, nil
473
474 case descriptor.FieldDescriptorProto_TYPE_INT64,
475 descriptor.FieldDescriptorProto_TYPE_SFIXED64:
476 return int64(v), nil
477
478 case descriptor.FieldDescriptorProto_TYPE_SINT64:
479 return decodeZigZag64(v), nil
480
481 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
482 if v > math.MaxUint32 {
483 return nil, NumericOverflowError
484 }
485 return math.Float32frombits(uint32(v)), nil
486
487 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
488 return math.Float64frombits(v), nil
489
490 default:
491 // bytes, string, message, and group cannot be represented as a simple numeric value
492 return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
493 }
494}
495
496func unmarshalLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf *MessageFactory) (interface{}, error) {
497 switch {
498 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
499 return bytes, nil
500
501 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
502 return string(bytes), nil
503
504 case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
505 fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
506 msg := mf.NewMessage(fd.GetMessageType())
507 err := proto.Unmarshal(bytes, msg)
508 if err != nil {
509 return nil, err
510 } else {
511 return msg, nil
512 }
513
514 default:
515 // even if the field is not repeated or not packed, we still parse it as such for
516 // backwards compatibility (e.g. message we are de-serializing could have been both
517 // repeated and packed at the time of serialization)
518 packedBuf := newCodedBuffer(bytes)
519 var slice []interface{}
520 var val interface{}
521 for !packedBuf.eof() {
522 var v uint64
523 var err error
524 if varintTypes[fd.GetType()] {
525 v, err = packedBuf.decodeVarint()
526 } else if fixed32Types[fd.GetType()] {
527 v, err = packedBuf.decodeFixed32()
528 } else if fixed64Types[fd.GetType()] {
529 v, err = packedBuf.decodeFixed64()
530 } else {
531 return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
532 }
533 if err != nil {
534 return nil, err
535 }
536 val, err = unmarshalSimpleField(fd, v)
537 if err != nil {
538 return nil, err
539 }
540 if fd.IsRepeated() {
541 slice = append(slice, val)
542 }
543 }
544 if fd.IsRepeated() {
545 return slice, nil
546 } else {
547 // if not a repeated field, last value wins
548 return val, nil
549 }
550 }
551}
552
553func (m *Message) unmarshalKnownField(fd *desc.FieldDescriptor, encoding int8, b *codedBuffer) error {
554 var val interface{}
555 var err error
556 switch encoding {
557 case proto.WireFixed32:
558 var num uint64
559 num, err = b.decodeFixed32()
560 if err == nil {
561 val, err = unmarshalSimpleField(fd, num)
562 }
563 case proto.WireFixed64:
564 var num uint64
565 num, err = b.decodeFixed64()
566 if err == nil {
567 val, err = unmarshalSimpleField(fd, num)
568 }
569 case proto.WireVarint:
570 var num uint64
571 num, err = b.decodeVarint()
572 if err == nil {
573 val, err = unmarshalSimpleField(fd, num)
574 }
575
576 case proto.WireBytes:
577 if fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES {
578 val, err = b.decodeRawBytes(true) // defensive copy
579 } else if fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING {
580 var raw []byte
581 raw, err = b.decodeRawBytes(true) // defensive copy
582 if err == nil {
583 val = string(raw)
584 }
585 } else {
586 var raw []byte
587 raw, err = b.decodeRawBytes(false)
588 if err == nil {
589 val, err = unmarshalLengthDelimitedField(fd, raw, m.mf)
590 }
591 }
592
593 case proto.WireStartGroup:
594 if fd.GetMessageType() == nil {
595 return fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
596 }
597 msg := m.mf.NewMessage(fd.GetMessageType())
598 if dm, ok := msg.(*Message); ok {
599 err = dm.unmarshal(b, true)
600 if err == nil {
601 val = dm
602 }
603 } else {
604 var groupEnd, dataEnd int
605 groupEnd, dataEnd, err = skipGroup(b)
606 if err == nil {
607 err = proto.Unmarshal(b.buf[b.index:dataEnd], msg)
608 if err == nil {
609 val = msg
610 }
611 b.index = groupEnd
612 }
613 }
614
615 default:
616 return proto.ErrInternalBadWireType
617 }
618 if err != nil {
619 return err
620 }
621
622 return mergeField(m, fd, val)
623}
624
625func (m *Message) unmarshalUnknownField(tagNumber int32, encoding int8, b *codedBuffer) error {
626 u := UnknownField{Encoding: encoding}
627 var err error
628 switch encoding {
629 case proto.WireFixed32:
630 u.Value, err = b.decodeFixed32()
631 case proto.WireFixed64:
632 u.Value, err = b.decodeFixed64()
633 case proto.WireVarint:
634 u.Value, err = b.decodeVarint()
635 case proto.WireBytes:
636 u.Contents, err = b.decodeRawBytes(true)
637 case proto.WireStartGroup:
638 var groupEnd, dataEnd int
639 groupEnd, dataEnd, err = skipGroup(b)
640 if err == nil {
641 u.Contents = make([]byte, dataEnd-b.index)
642 copy(u.Contents, b.buf[b.index:])
643 b.index = groupEnd
644 }
645 default:
646 err = proto.ErrInternalBadWireType
647 }
648 if err != nil {
649 return err
650 }
651 if m.unknownFields == nil {
652 m.unknownFields = map[int32][]UnknownField{}
653 }
654 m.unknownFields[tagNumber] = append(m.unknownFields[tagNumber], u)
655 return nil
656}
657
658func skipGroup(b *codedBuffer) (int, int, error) {
659 bs := b.buf
660 start := b.index
661 defer func() {
662 b.index = start
663 }()
664 for {
665 fieldStart := b.index
666 // read a field tag
667 _, wireType, err := b.decodeTagAndWireType()
668 if err != nil {
669 return 0, 0, err
670 }
671 // skip past the field's data
672 switch wireType {
673 case proto.WireFixed32:
674 if !b.skip(4) {
675 return 0, 0, io.ErrUnexpectedEOF
676 }
677 case proto.WireFixed64:
678 if !b.skip(8) {
679 return 0, 0, io.ErrUnexpectedEOF
680 }
681 case proto.WireVarint:
682 // skip varint by finding last byte (has high bit unset)
683 i := b.index
684 for {
685 if i >= len(bs) {
686 return 0, 0, io.ErrUnexpectedEOF
687 }
688 if bs[i]&0x80 == 0 {
689 break
690 }
691 i++
692 }
693 b.index = i + 1
694 case proto.WireBytes:
695 l, err := b.decodeVarint()
696 if err != nil {
697 return 0, 0, err
698 }
699 if !b.skip(int(l)) {
700 return 0, 0, io.ErrUnexpectedEOF
701 }
702 case proto.WireStartGroup:
703 endIndex, _, err := skipGroup(b)
704 if err != nil {
705 return 0, 0, err
706 }
707 b.index = endIndex
708 case proto.WireEndGroup:
709 return b.index, fieldStart, nil
710 default:
711 return 0, 0, proto.ErrInternalBadWireType
712 }
713 }
714}