blob: 71c755c15595f27aed6bc60c8d06d219975e59f1 [file] [log] [blame]
amit.ghosh258d14c2020-10-02 15:13:38 +02001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package protojson
6
7import (
8 "encoding/base64"
9 "fmt"
10 "math"
11 "strconv"
12 "strings"
13
14 "google.golang.org/protobuf/internal/encoding/json"
15 "google.golang.org/protobuf/internal/encoding/messageset"
16 "google.golang.org/protobuf/internal/errors"
17 "google.golang.org/protobuf/internal/flags"
18 "google.golang.org/protobuf/internal/pragma"
19 "google.golang.org/protobuf/internal/set"
20 "google.golang.org/protobuf/proto"
21 pref "google.golang.org/protobuf/reflect/protoreflect"
22 "google.golang.org/protobuf/reflect/protoregistry"
23)
24
25// Unmarshal reads the given []byte into the given proto.Message.
26func Unmarshal(b []byte, m proto.Message) error {
27 return UnmarshalOptions{}.Unmarshal(b, m)
28}
29
30// UnmarshalOptions is a configurable JSON format parser.
31type UnmarshalOptions struct {
32 pragma.NoUnkeyedLiterals
33
34 // If AllowPartial is set, input for messages that will result in missing
35 // required fields will not return an error.
36 AllowPartial bool
37
38 // If DiscardUnknown is set, unknown fields are ignored.
39 DiscardUnknown bool
40
41 // Resolver is used for looking up types when unmarshaling
42 // google.protobuf.Any messages or extension fields.
43 // If nil, this defaults to using protoregistry.GlobalTypes.
44 Resolver interface {
45 protoregistry.MessageTypeResolver
46 protoregistry.ExtensionTypeResolver
47 }
48}
49
50// Unmarshal reads the given []byte and populates the given proto.Message using
51// options in UnmarshalOptions object. It will clear the message first before
52// setting the fields. If it returns an error, the given message may be
53// partially set.
54func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
55 proto.Reset(m)
56
57 if o.Resolver == nil {
58 o.Resolver = protoregistry.GlobalTypes
59 }
60
61 dec := decoder{json.NewDecoder(b), o}
62 if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil {
63 return err
64 }
65
66 // Check for EOF.
67 tok, err := dec.Read()
68 if err != nil {
69 return err
70 }
71 if tok.Kind() != json.EOF {
72 return dec.unexpectedTokenError(tok)
73 }
74
75 if o.AllowPartial {
76 return nil
77 }
78 return proto.CheckInitialized(m)
79}
80
81type decoder struct {
82 *json.Decoder
83 opts UnmarshalOptions
84}
85
86// newError returns an error object with position info.
87func (d decoder) newError(pos int, f string, x ...interface{}) error {
88 line, column := d.Position(pos)
89 head := fmt.Sprintf("(line %d:%d): ", line, column)
90 return errors.New(head+f, x...)
91}
92
93// unexpectedTokenError returns a syntax error for the given unexpected token.
94func (d decoder) unexpectedTokenError(tok json.Token) error {
95 return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString())
96}
97
98// syntaxError returns a syntax error for given position.
99func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
100 line, column := d.Position(pos)
101 head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
102 return errors.New(head+f, x...)
103}
104
105// unmarshalMessage unmarshals a message into the given protoreflect.Message.
106func (d decoder) unmarshalMessage(m pref.Message, skipTypeURL bool) error {
107 if isCustomType(m.Descriptor().FullName()) {
108 return d.unmarshalCustomType(m)
109 }
110
111 tok, err := d.Read()
112 if err != nil {
113 return err
114 }
115 if tok.Kind() != json.ObjectOpen {
116 return d.unexpectedTokenError(tok)
117 }
118
119 if err := d.unmarshalFields(m, skipTypeURL); err != nil {
120 return err
121 }
122
123 return nil
124}
125
126// unmarshalFields unmarshals the fields into the given protoreflect.Message.
127func (d decoder) unmarshalFields(m pref.Message, skipTypeURL bool) error {
128 messageDesc := m.Descriptor()
129 if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
130 return errors.New("no support for proto1 MessageSets")
131 }
132
133 var seenNums set.Ints
134 var seenOneofs set.Ints
135 fieldDescs := messageDesc.Fields()
136 for {
137 // Read field name.
138 tok, err := d.Read()
139 if err != nil {
140 return err
141 }
142 switch tok.Kind() {
143 default:
144 return d.unexpectedTokenError(tok)
145 case json.ObjectClose:
146 return nil
147 case json.Name:
148 // Continue below.
149 }
150
151 name := tok.Name()
152 // Unmarshaling a non-custom embedded message in Any will contain the
153 // JSON field "@type" which should be skipped because it is not a field
154 // of the embedded message, but simply an artifact of the Any format.
155 if skipTypeURL && name == "@type" {
156 d.Read()
157 continue
158 }
159
160 // Get the FieldDescriptor.
161 var fd pref.FieldDescriptor
162 if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
163 // Only extension names are in [name] format.
164 extName := pref.FullName(name[1 : len(name)-1])
165 extType, err := d.findExtension(extName)
166 if err != nil && err != protoregistry.NotFound {
167 return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
168 }
169 if extType != nil {
170 fd = extType.TypeDescriptor()
171 if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
172 return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
173 }
174 }
175 } else {
176 // The name can either be the JSON name or the proto field name.
177 fd = fieldDescs.ByJSONName(name)
178 if fd == nil {
179 fd = fieldDescs.ByName(pref.Name(name))
180 if fd == nil {
181 // The proto name of a group field is in all lowercase,
182 // while the textual field name is the group message name.
183 gd := fieldDescs.ByName(pref.Name(strings.ToLower(name)))
184 if gd != nil && gd.Kind() == pref.GroupKind && gd.Message().Name() == pref.Name(name) {
185 fd = gd
186 }
187 } else if fd.Kind() == pref.GroupKind && fd.Message().Name() != pref.Name(name) {
188 fd = nil // reset since field name is actually the message name
189 }
190 }
191 }
192 if flags.ProtoLegacy {
193 if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
194 fd = nil // reset since the weak reference is not linked in
195 }
196 }
197
198 if fd == nil {
199 // Field is unknown.
200 if d.opts.DiscardUnknown {
201 if err := d.skipJSONValue(); err != nil {
202 return err
203 }
204 continue
205 }
206 return d.newError(tok.Pos(), "unknown field %v", tok.RawString())
207 }
208
209 // Do not allow duplicate fields.
210 num := uint64(fd.Number())
211 if seenNums.Has(num) {
212 return d.newError(tok.Pos(), "duplicate field %v", tok.RawString())
213 }
214 seenNums.Set(num)
215
216 // No need to set values for JSON null unless the field type is
217 // google.protobuf.Value or google.protobuf.NullValue.
218 if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
219 d.Read()
220 continue
221 }
222
223 switch {
224 case fd.IsList():
225 list := m.Mutable(fd).List()
226 if err := d.unmarshalList(list, fd); err != nil {
227 return err
228 }
229 case fd.IsMap():
230 mmap := m.Mutable(fd).Map()
231 if err := d.unmarshalMap(mmap, fd); err != nil {
232 return err
233 }
234 default:
235 // If field is a oneof, check if it has already been set.
236 if od := fd.ContainingOneof(); od != nil {
237 idx := uint64(od.Index())
238 if seenOneofs.Has(idx) {
239 return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName())
240 }
241 seenOneofs.Set(idx)
242 }
243
244 // Required or optional fields.
245 if err := d.unmarshalSingular(m, fd); err != nil {
246 return err
247 }
248 }
249 }
250}
251
252// findExtension returns protoreflect.ExtensionType from the resolver if found.
253func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
254 xt, err := d.opts.Resolver.FindExtensionByName(xtName)
255 if err == nil {
256 return xt, nil
257 }
258 return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
259}
260
261func isKnownValue(fd pref.FieldDescriptor) bool {
262 md := fd.Message()
263 return md != nil && md.FullName() == "google.protobuf.Value"
264}
265
266func isNullValue(fd pref.FieldDescriptor) bool {
267 ed := fd.Enum()
268 return ed != nil && ed.FullName() == "google.protobuf.NullValue"
269}
270
271// unmarshalSingular unmarshals to the non-repeated field specified
272// by the given FieldDescriptor.
273func (d decoder) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
274 var val pref.Value
275 var err error
276 switch fd.Kind() {
277 case pref.MessageKind, pref.GroupKind:
278 val = m.NewField(fd)
279 err = d.unmarshalMessage(val.Message(), false)
280 default:
281 val, err = d.unmarshalScalar(fd)
282 }
283
284 if err != nil {
285 return err
286 }
287 m.Set(fd, val)
288 return nil
289}
290
291// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
292// the given FieldDescriptor.
293func (d decoder) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) {
294 const b32 int = 32
295 const b64 int = 64
296
297 tok, err := d.Read()
298 if err != nil {
299 return pref.Value{}, err
300 }
301
302 kind := fd.Kind()
303 switch kind {
304 case pref.BoolKind:
305 if tok.Kind() == json.Bool {
306 return pref.ValueOfBool(tok.Bool()), nil
307 }
308
309 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
310 if v, ok := unmarshalInt(tok, b32); ok {
311 return v, nil
312 }
313
314 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
315 if v, ok := unmarshalInt(tok, b64); ok {
316 return v, nil
317 }
318
319 case pref.Uint32Kind, pref.Fixed32Kind:
320 if v, ok := unmarshalUint(tok, b32); ok {
321 return v, nil
322 }
323
324 case pref.Uint64Kind, pref.Fixed64Kind:
325 if v, ok := unmarshalUint(tok, b64); ok {
326 return v, nil
327 }
328
329 case pref.FloatKind:
330 if v, ok := unmarshalFloat(tok, b32); ok {
331 return v, nil
332 }
333
334 case pref.DoubleKind:
335 if v, ok := unmarshalFloat(tok, b64); ok {
336 return v, nil
337 }
338
339 case pref.StringKind:
340 if tok.Kind() == json.String {
341 return pref.ValueOfString(tok.ParsedString()), nil
342 }
343
344 case pref.BytesKind:
345 if v, ok := unmarshalBytes(tok); ok {
346 return v, nil
347 }
348
349 case pref.EnumKind:
350 if v, ok := unmarshalEnum(tok, fd); ok {
351 return v, nil
352 }
353
354 default:
355 panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind))
356 }
357
358 return pref.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString())
359}
360
361func unmarshalInt(tok json.Token, bitSize int) (pref.Value, bool) {
362 switch tok.Kind() {
363 case json.Number:
364 return getInt(tok, bitSize)
365
366 case json.String:
367 // Decode number from string.
368 s := strings.TrimSpace(tok.ParsedString())
369 if len(s) != len(tok.ParsedString()) {
370 return pref.Value{}, false
371 }
372 dec := json.NewDecoder([]byte(s))
373 tok, err := dec.Read()
374 if err != nil {
375 return pref.Value{}, false
376 }
377 return getInt(tok, bitSize)
378 }
379 return pref.Value{}, false
380}
381
382func getInt(tok json.Token, bitSize int) (pref.Value, bool) {
383 n, ok := tok.Int(bitSize)
384 if !ok {
385 return pref.Value{}, false
386 }
387 if bitSize == 32 {
388 return pref.ValueOfInt32(int32(n)), true
389 }
390 return pref.ValueOfInt64(n), true
391}
392
393func unmarshalUint(tok json.Token, bitSize int) (pref.Value, bool) {
394 switch tok.Kind() {
395 case json.Number:
396 return getUint(tok, bitSize)
397
398 case json.String:
399 // Decode number from string.
400 s := strings.TrimSpace(tok.ParsedString())
401 if len(s) != len(tok.ParsedString()) {
402 return pref.Value{}, false
403 }
404 dec := json.NewDecoder([]byte(s))
405 tok, err := dec.Read()
406 if err != nil {
407 return pref.Value{}, false
408 }
409 return getUint(tok, bitSize)
410 }
411 return pref.Value{}, false
412}
413
414func getUint(tok json.Token, bitSize int) (pref.Value, bool) {
415 n, ok := tok.Uint(bitSize)
416 if !ok {
417 return pref.Value{}, false
418 }
419 if bitSize == 32 {
420 return pref.ValueOfUint32(uint32(n)), true
421 }
422 return pref.ValueOfUint64(n), true
423}
424
425func unmarshalFloat(tok json.Token, bitSize int) (pref.Value, bool) {
426 switch tok.Kind() {
427 case json.Number:
428 return getFloat(tok, bitSize)
429
430 case json.String:
431 s := tok.ParsedString()
432 switch s {
433 case "NaN":
434 if bitSize == 32 {
435 return pref.ValueOfFloat32(float32(math.NaN())), true
436 }
437 return pref.ValueOfFloat64(math.NaN()), true
438 case "Infinity":
439 if bitSize == 32 {
440 return pref.ValueOfFloat32(float32(math.Inf(+1))), true
441 }
442 return pref.ValueOfFloat64(math.Inf(+1)), true
443 case "-Infinity":
444 if bitSize == 32 {
445 return pref.ValueOfFloat32(float32(math.Inf(-1))), true
446 }
447 return pref.ValueOfFloat64(math.Inf(-1)), true
448 }
449
450 // Decode number from string.
451 if len(s) != len(strings.TrimSpace(s)) {
452 return pref.Value{}, false
453 }
454 dec := json.NewDecoder([]byte(s))
455 tok, err := dec.Read()
456 if err != nil {
457 return pref.Value{}, false
458 }
459 return getFloat(tok, bitSize)
460 }
461 return pref.Value{}, false
462}
463
464func getFloat(tok json.Token, bitSize int) (pref.Value, bool) {
465 n, ok := tok.Float(bitSize)
466 if !ok {
467 return pref.Value{}, false
468 }
469 if bitSize == 32 {
470 return pref.ValueOfFloat32(float32(n)), true
471 }
472 return pref.ValueOfFloat64(n), true
473}
474
475func unmarshalBytes(tok json.Token) (pref.Value, bool) {
476 if tok.Kind() != json.String {
477 return pref.Value{}, false
478 }
479
480 s := tok.ParsedString()
481 enc := base64.StdEncoding
482 if strings.ContainsAny(s, "-_") {
483 enc = base64.URLEncoding
484 }
485 if len(s)%4 != 0 {
486 enc = enc.WithPadding(base64.NoPadding)
487 }
488 b, err := enc.DecodeString(s)
489 if err != nil {
490 return pref.Value{}, false
491 }
492 return pref.ValueOfBytes(b), true
493}
494
495func unmarshalEnum(tok json.Token, fd pref.FieldDescriptor) (pref.Value, bool) {
496 switch tok.Kind() {
497 case json.String:
498 // Lookup EnumNumber based on name.
499 s := tok.ParsedString()
500 if enumVal := fd.Enum().Values().ByName(pref.Name(s)); enumVal != nil {
501 return pref.ValueOfEnum(enumVal.Number()), true
502 }
503
504 case json.Number:
505 if n, ok := tok.Int(32); ok {
506 return pref.ValueOfEnum(pref.EnumNumber(n)), true
507 }
508
509 case json.Null:
510 // This is only valid for google.protobuf.NullValue.
511 if isNullValue(fd) {
512 return pref.ValueOfEnum(0), true
513 }
514 }
515
516 return pref.Value{}, false
517}
518
519func (d decoder) unmarshalList(list pref.List, fd pref.FieldDescriptor) error {
520 tok, err := d.Read()
521 if err != nil {
522 return err
523 }
524 if tok.Kind() != json.ArrayOpen {
525 return d.unexpectedTokenError(tok)
526 }
527
528 switch fd.Kind() {
529 case pref.MessageKind, pref.GroupKind:
530 for {
531 tok, err := d.Peek()
532 if err != nil {
533 return err
534 }
535
536 if tok.Kind() == json.ArrayClose {
537 d.Read()
538 return nil
539 }
540
541 val := list.NewElement()
542 if err := d.unmarshalMessage(val.Message(), false); err != nil {
543 return err
544 }
545 list.Append(val)
546 }
547 default:
548 for {
549 tok, err := d.Peek()
550 if err != nil {
551 return err
552 }
553
554 if tok.Kind() == json.ArrayClose {
555 d.Read()
556 return nil
557 }
558
559 val, err := d.unmarshalScalar(fd)
560 if err != nil {
561 return err
562 }
563 list.Append(val)
564 }
565 }
566
567 return nil
568}
569
570func (d decoder) unmarshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
571 tok, err := d.Read()
572 if err != nil {
573 return err
574 }
575 if tok.Kind() != json.ObjectOpen {
576 return d.unexpectedTokenError(tok)
577 }
578
579 // Determine ahead whether map entry is a scalar type or a message type in
580 // order to call the appropriate unmarshalMapValue func inside the for loop
581 // below.
582 var unmarshalMapValue func() (pref.Value, error)
583 switch fd.MapValue().Kind() {
584 case pref.MessageKind, pref.GroupKind:
585 unmarshalMapValue = func() (pref.Value, error) {
586 val := mmap.NewValue()
587 if err := d.unmarshalMessage(val.Message(), false); err != nil {
588 return pref.Value{}, err
589 }
590 return val, nil
591 }
592 default:
593 unmarshalMapValue = func() (pref.Value, error) {
594 return d.unmarshalScalar(fd.MapValue())
595 }
596 }
597
598Loop:
599 for {
600 // Read field name.
601 tok, err := d.Read()
602 if err != nil {
603 return err
604 }
605 switch tok.Kind() {
606 default:
607 return d.unexpectedTokenError(tok)
608 case json.ObjectClose:
609 break Loop
610 case json.Name:
611 // Continue.
612 }
613
614 // Unmarshal field name.
615 pkey, err := d.unmarshalMapKey(tok, fd.MapKey())
616 if err != nil {
617 return err
618 }
619
620 // Check for duplicate field name.
621 if mmap.Has(pkey) {
622 return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString())
623 }
624
625 // Read and unmarshal field value.
626 pval, err := unmarshalMapValue()
627 if err != nil {
628 return err
629 }
630
631 mmap.Set(pkey, pval)
632 }
633
634 return nil
635}
636
637// unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey.
638// A map key type is any integral or string type.
639func (d decoder) unmarshalMapKey(tok json.Token, fd pref.FieldDescriptor) (pref.MapKey, error) {
640 const b32 = 32
641 const b64 = 64
642 const base10 = 10
643
644 name := tok.Name()
645 kind := fd.Kind()
646 switch kind {
647 case pref.StringKind:
648 return pref.ValueOfString(name).MapKey(), nil
649
650 case pref.BoolKind:
651 switch name {
652 case "true":
653 return pref.ValueOfBool(true).MapKey(), nil
654 case "false":
655 return pref.ValueOfBool(false).MapKey(), nil
656 }
657
658 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
659 if n, err := strconv.ParseInt(name, base10, b32); err == nil {
660 return pref.ValueOfInt32(int32(n)).MapKey(), nil
661 }
662
663 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
664 if n, err := strconv.ParseInt(name, base10, b64); err == nil {
665 return pref.ValueOfInt64(int64(n)).MapKey(), nil
666 }
667
668 case pref.Uint32Kind, pref.Fixed32Kind:
669 if n, err := strconv.ParseUint(name, base10, b32); err == nil {
670 return pref.ValueOfUint32(uint32(n)).MapKey(), nil
671 }
672
673 case pref.Uint64Kind, pref.Fixed64Kind:
674 if n, err := strconv.ParseUint(name, base10, b64); err == nil {
675 return pref.ValueOfUint64(uint64(n)).MapKey(), nil
676 }
677
678 default:
679 panic(fmt.Sprintf("invalid kind for map key: %v", kind))
680 }
681
682 return pref.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString())
683}