blob: 07da5db3450e82404c8995fbe37faa270d161384 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package protojson
6
7import (
8 "encoding/base64"
9 "fmt"
10 "math"
11 "strconv"
12 "strings"
13
14 "google.golang.org/protobuf/internal/encoding/json"
15 "google.golang.org/protobuf/internal/encoding/messageset"
16 "google.golang.org/protobuf/internal/errors"
17 "google.golang.org/protobuf/internal/flags"
18 "google.golang.org/protobuf/internal/genid"
19 "google.golang.org/protobuf/internal/pragma"
20 "google.golang.org/protobuf/internal/set"
21 "google.golang.org/protobuf/proto"
22 pref "google.golang.org/protobuf/reflect/protoreflect"
23 "google.golang.org/protobuf/reflect/protoregistry"
24)
25
26// Unmarshal reads the given []byte into the given proto.Message.
27// The provided message must be mutable (e.g., a non-nil pointer to a message).
28func Unmarshal(b []byte, m proto.Message) error {
29 return UnmarshalOptions{}.Unmarshal(b, m)
30}
31
// UnmarshalOptions is a configurable JSON format parser.
//
// The zero value is ready to use; the package-level Unmarshal function is
// equivalent to UnmarshalOptions{}.Unmarshal.
type UnmarshalOptions struct {
	// NOTE(review): pragma.NoUnkeyedLiterals presumably prevents callers from
	// building this struct with positional (unkeyed) composite literals.
	pragma.NoUnkeyedLiterals

	// If AllowPartial is set, input for messages that will result in missing
	// required fields will not return an error: the final
	// proto.CheckInitialized step is skipped.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: their values are
	// skipped instead of producing an "unknown field" error. This includes
	// "[name]" extension members whose name the Resolver reports as
	// protoregistry.NotFound.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling
	// google.protobuf.Any messages or extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		protoregistry.MessageTypeResolver
		protoregistry.ExtensionTypeResolver
	}
}
51
52// Unmarshal reads the given []byte and populates the given proto.Message
53// using options in the UnmarshalOptions object.
54// It will clear the message first before setting the fields.
55// If it returns an error, the given message may be partially set.
56// The provided message must be mutable (e.g., a non-nil pointer to a message).
57func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
58 return o.unmarshal(b, m)
59}
60
61// unmarshal is a centralized function that all unmarshal operations go through.
62// For profiling purposes, avoid changing the name of this function or
63// introducing other code paths for unmarshal that do not go through this.
64func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error {
65 proto.Reset(m)
66
67 if o.Resolver == nil {
68 o.Resolver = protoregistry.GlobalTypes
69 }
70
71 dec := decoder{json.NewDecoder(b), o}
72 if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil {
73 return err
74 }
75
76 // Check for EOF.
77 tok, err := dec.Read()
78 if err != nil {
79 return err
80 }
81 if tok.Kind() != json.EOF {
82 return dec.unexpectedTokenError(tok)
83 }
84
85 if o.AllowPartial {
86 return nil
87 }
88 return proto.CheckInitialized(m)
89}
90
// decoder pairs the low-level JSON tokenizer with the UnmarshalOptions in
// effect for a single unmarshal call. The embedded *json.Decoder provides
// the Read, Peek and Position methods used throughout this file.
//
// Field order is significant to any positional composite literal of this type.
type decoder struct {
	*json.Decoder
	opts UnmarshalOptions
}
95
96// newError returns an error object with position info.
97func (d decoder) newError(pos int, f string, x ...interface{}) error {
98 line, column := d.Position(pos)
99 head := fmt.Sprintf("(line %d:%d): ", line, column)
100 return errors.New(head+f, x...)
101}
102
103// unexpectedTokenError returns a syntax error for the given unexpected token.
104func (d decoder) unexpectedTokenError(tok json.Token) error {
105 return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString())
106}
107
108// syntaxError returns a syntax error for given position.
109func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
110 line, column := d.Position(pos)
111 head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
112 return errors.New(head+f, x...)
113}
114
115// unmarshalMessage unmarshals a message into the given protoreflect.Message.
116func (d decoder) unmarshalMessage(m pref.Message, skipTypeURL bool) error {
117 if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil {
118 return unmarshal(d, m)
119 }
120
121 tok, err := d.Read()
122 if err != nil {
123 return err
124 }
125 if tok.Kind() != json.ObjectOpen {
126 return d.unexpectedTokenError(tok)
127 }
128
129 messageDesc := m.Descriptor()
130 if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
131 return errors.New("no support for proto1 MessageSets")
132 }
133
134 var seenNums set.Ints
135 var seenOneofs set.Ints
136 fieldDescs := messageDesc.Fields()
137 for {
138 // Read field name.
139 tok, err := d.Read()
140 if err != nil {
141 return err
142 }
143 switch tok.Kind() {
144 default:
145 return d.unexpectedTokenError(tok)
146 case json.ObjectClose:
147 return nil
148 case json.Name:
149 // Continue below.
150 }
151
152 name := tok.Name()
153 // Unmarshaling a non-custom embedded message in Any will contain the
154 // JSON field "@type" which should be skipped because it is not a field
155 // of the embedded message, but simply an artifact of the Any format.
156 if skipTypeURL && name == "@type" {
157 d.Read()
158 continue
159 }
160
161 // Get the FieldDescriptor.
162 var fd pref.FieldDescriptor
163 if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
164 // Only extension names are in [name] format.
165 extName := pref.FullName(name[1 : len(name)-1])
166 extType, err := d.opts.Resolver.FindExtensionByName(extName)
167 if err != nil && err != protoregistry.NotFound {
168 return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
169 }
170 if extType != nil {
171 fd = extType.TypeDescriptor()
172 if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
173 return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
174 }
175 }
176 } else {
177 // The name can either be the JSON name or the proto field name.
178 fd = fieldDescs.ByJSONName(name)
179 if fd == nil {
180 fd = fieldDescs.ByTextName(name)
181 }
182 }
183 if flags.ProtoLegacy {
184 if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
185 fd = nil // reset since the weak reference is not linked in
186 }
187 }
188
189 if fd == nil {
190 // Field is unknown.
191 if d.opts.DiscardUnknown {
192 if err := d.skipJSONValue(); err != nil {
193 return err
194 }
195 continue
196 }
197 return d.newError(tok.Pos(), "unknown field %v", tok.RawString())
198 }
199
200 // Do not allow duplicate fields.
201 num := uint64(fd.Number())
202 if seenNums.Has(num) {
203 return d.newError(tok.Pos(), "duplicate field %v", tok.RawString())
204 }
205 seenNums.Set(num)
206
207 // No need to set values for JSON null unless the field type is
208 // google.protobuf.Value or google.protobuf.NullValue.
209 if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
210 d.Read()
211 continue
212 }
213
214 switch {
215 case fd.IsList():
216 list := m.Mutable(fd).List()
217 if err := d.unmarshalList(list, fd); err != nil {
218 return err
219 }
220 case fd.IsMap():
221 mmap := m.Mutable(fd).Map()
222 if err := d.unmarshalMap(mmap, fd); err != nil {
223 return err
224 }
225 default:
226 // If field is a oneof, check if it has already been set.
227 if od := fd.ContainingOneof(); od != nil {
228 idx := uint64(od.Index())
229 if seenOneofs.Has(idx) {
230 return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName())
231 }
232 seenOneofs.Set(idx)
233 }
234
235 // Required or optional fields.
236 if err := d.unmarshalSingular(m, fd); err != nil {
237 return err
238 }
239 }
240 }
241}
242
243func isKnownValue(fd pref.FieldDescriptor) bool {
244 md := fd.Message()
245 return md != nil && md.FullName() == genid.Value_message_fullname
246}
247
248func isNullValue(fd pref.FieldDescriptor) bool {
249 ed := fd.Enum()
250 return ed != nil && ed.FullName() == genid.NullValue_enum_fullname
251}
252
253// unmarshalSingular unmarshals to the non-repeated field specified
254// by the given FieldDescriptor.
255func (d decoder) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
256 var val pref.Value
257 var err error
258 switch fd.Kind() {
259 case pref.MessageKind, pref.GroupKind:
260 val = m.NewField(fd)
261 err = d.unmarshalMessage(val.Message(), false)
262 default:
263 val, err = d.unmarshalScalar(fd)
264 }
265
266 if err != nil {
267 return err
268 }
269 m.Set(fd, val)
270 return nil
271}
272
273// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
274// the given FieldDescriptor.
275func (d decoder) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) {
276 const b32 int = 32
277 const b64 int = 64
278
279 tok, err := d.Read()
280 if err != nil {
281 return pref.Value{}, err
282 }
283
284 kind := fd.Kind()
285 switch kind {
286 case pref.BoolKind:
287 if tok.Kind() == json.Bool {
288 return pref.ValueOfBool(tok.Bool()), nil
289 }
290
291 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
292 if v, ok := unmarshalInt(tok, b32); ok {
293 return v, nil
294 }
295
296 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
297 if v, ok := unmarshalInt(tok, b64); ok {
298 return v, nil
299 }
300
301 case pref.Uint32Kind, pref.Fixed32Kind:
302 if v, ok := unmarshalUint(tok, b32); ok {
303 return v, nil
304 }
305
306 case pref.Uint64Kind, pref.Fixed64Kind:
307 if v, ok := unmarshalUint(tok, b64); ok {
308 return v, nil
309 }
310
311 case pref.FloatKind:
312 if v, ok := unmarshalFloat(tok, b32); ok {
313 return v, nil
314 }
315
316 case pref.DoubleKind:
317 if v, ok := unmarshalFloat(tok, b64); ok {
318 return v, nil
319 }
320
321 case pref.StringKind:
322 if tok.Kind() == json.String {
323 return pref.ValueOfString(tok.ParsedString()), nil
324 }
325
326 case pref.BytesKind:
327 if v, ok := unmarshalBytes(tok); ok {
328 return v, nil
329 }
330
331 case pref.EnumKind:
332 if v, ok := unmarshalEnum(tok, fd); ok {
333 return v, nil
334 }
335
336 default:
337 panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind))
338 }
339
340 return pref.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString())
341}
342
343func unmarshalInt(tok json.Token, bitSize int) (pref.Value, bool) {
344 switch tok.Kind() {
345 case json.Number:
346 return getInt(tok, bitSize)
347
348 case json.String:
349 // Decode number from string.
350 s := strings.TrimSpace(tok.ParsedString())
351 if len(s) != len(tok.ParsedString()) {
352 return pref.Value{}, false
353 }
354 dec := json.NewDecoder([]byte(s))
355 tok, err := dec.Read()
356 if err != nil {
357 return pref.Value{}, false
358 }
359 return getInt(tok, bitSize)
360 }
361 return pref.Value{}, false
362}
363
364func getInt(tok json.Token, bitSize int) (pref.Value, bool) {
365 n, ok := tok.Int(bitSize)
366 if !ok {
367 return pref.Value{}, false
368 }
369 if bitSize == 32 {
370 return pref.ValueOfInt32(int32(n)), true
371 }
372 return pref.ValueOfInt64(n), true
373}
374
375func unmarshalUint(tok json.Token, bitSize int) (pref.Value, bool) {
376 switch tok.Kind() {
377 case json.Number:
378 return getUint(tok, bitSize)
379
380 case json.String:
381 // Decode number from string.
382 s := strings.TrimSpace(tok.ParsedString())
383 if len(s) != len(tok.ParsedString()) {
384 return pref.Value{}, false
385 }
386 dec := json.NewDecoder([]byte(s))
387 tok, err := dec.Read()
388 if err != nil {
389 return pref.Value{}, false
390 }
391 return getUint(tok, bitSize)
392 }
393 return pref.Value{}, false
394}
395
396func getUint(tok json.Token, bitSize int) (pref.Value, bool) {
397 n, ok := tok.Uint(bitSize)
398 if !ok {
399 return pref.Value{}, false
400 }
401 if bitSize == 32 {
402 return pref.ValueOfUint32(uint32(n)), true
403 }
404 return pref.ValueOfUint64(n), true
405}
406
407func unmarshalFloat(tok json.Token, bitSize int) (pref.Value, bool) {
408 switch tok.Kind() {
409 case json.Number:
410 return getFloat(tok, bitSize)
411
412 case json.String:
413 s := tok.ParsedString()
414 switch s {
415 case "NaN":
416 if bitSize == 32 {
417 return pref.ValueOfFloat32(float32(math.NaN())), true
418 }
419 return pref.ValueOfFloat64(math.NaN()), true
420 case "Infinity":
421 if bitSize == 32 {
422 return pref.ValueOfFloat32(float32(math.Inf(+1))), true
423 }
424 return pref.ValueOfFloat64(math.Inf(+1)), true
425 case "-Infinity":
426 if bitSize == 32 {
427 return pref.ValueOfFloat32(float32(math.Inf(-1))), true
428 }
429 return pref.ValueOfFloat64(math.Inf(-1)), true
430 }
431
432 // Decode number from string.
433 if len(s) != len(strings.TrimSpace(s)) {
434 return pref.Value{}, false
435 }
436 dec := json.NewDecoder([]byte(s))
437 tok, err := dec.Read()
438 if err != nil {
439 return pref.Value{}, false
440 }
441 return getFloat(tok, bitSize)
442 }
443 return pref.Value{}, false
444}
445
446func getFloat(tok json.Token, bitSize int) (pref.Value, bool) {
447 n, ok := tok.Float(bitSize)
448 if !ok {
449 return pref.Value{}, false
450 }
451 if bitSize == 32 {
452 return pref.ValueOfFloat32(float32(n)), true
453 }
454 return pref.ValueOfFloat64(n), true
455}
456
457func unmarshalBytes(tok json.Token) (pref.Value, bool) {
458 if tok.Kind() != json.String {
459 return pref.Value{}, false
460 }
461
462 s := tok.ParsedString()
463 enc := base64.StdEncoding
464 if strings.ContainsAny(s, "-_") {
465 enc = base64.URLEncoding
466 }
467 if len(s)%4 != 0 {
468 enc = enc.WithPadding(base64.NoPadding)
469 }
470 b, err := enc.DecodeString(s)
471 if err != nil {
472 return pref.Value{}, false
473 }
474 return pref.ValueOfBytes(b), true
475}
476
477func unmarshalEnum(tok json.Token, fd pref.FieldDescriptor) (pref.Value, bool) {
478 switch tok.Kind() {
479 case json.String:
480 // Lookup EnumNumber based on name.
481 s := tok.ParsedString()
482 if enumVal := fd.Enum().Values().ByName(pref.Name(s)); enumVal != nil {
483 return pref.ValueOfEnum(enumVal.Number()), true
484 }
485
486 case json.Number:
487 if n, ok := tok.Int(32); ok {
488 return pref.ValueOfEnum(pref.EnumNumber(n)), true
489 }
490
491 case json.Null:
492 // This is only valid for google.protobuf.NullValue.
493 if isNullValue(fd) {
494 return pref.ValueOfEnum(0), true
495 }
496 }
497
498 return pref.Value{}, false
499}
500
501func (d decoder) unmarshalList(list pref.List, fd pref.FieldDescriptor) error {
502 tok, err := d.Read()
503 if err != nil {
504 return err
505 }
506 if tok.Kind() != json.ArrayOpen {
507 return d.unexpectedTokenError(tok)
508 }
509
510 switch fd.Kind() {
511 case pref.MessageKind, pref.GroupKind:
512 for {
513 tok, err := d.Peek()
514 if err != nil {
515 return err
516 }
517
518 if tok.Kind() == json.ArrayClose {
519 d.Read()
520 return nil
521 }
522
523 val := list.NewElement()
524 if err := d.unmarshalMessage(val.Message(), false); err != nil {
525 return err
526 }
527 list.Append(val)
528 }
529 default:
530 for {
531 tok, err := d.Peek()
532 if err != nil {
533 return err
534 }
535
536 if tok.Kind() == json.ArrayClose {
537 d.Read()
538 return nil
539 }
540
541 val, err := d.unmarshalScalar(fd)
542 if err != nil {
543 return err
544 }
545 list.Append(val)
546 }
547 }
548
549 return nil
550}
551
552func (d decoder) unmarshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
553 tok, err := d.Read()
554 if err != nil {
555 return err
556 }
557 if tok.Kind() != json.ObjectOpen {
558 return d.unexpectedTokenError(tok)
559 }
560
561 // Determine ahead whether map entry is a scalar type or a message type in
562 // order to call the appropriate unmarshalMapValue func inside the for loop
563 // below.
564 var unmarshalMapValue func() (pref.Value, error)
565 switch fd.MapValue().Kind() {
566 case pref.MessageKind, pref.GroupKind:
567 unmarshalMapValue = func() (pref.Value, error) {
568 val := mmap.NewValue()
569 if err := d.unmarshalMessage(val.Message(), false); err != nil {
570 return pref.Value{}, err
571 }
572 return val, nil
573 }
574 default:
575 unmarshalMapValue = func() (pref.Value, error) {
576 return d.unmarshalScalar(fd.MapValue())
577 }
578 }
579
580Loop:
581 for {
582 // Read field name.
583 tok, err := d.Read()
584 if err != nil {
585 return err
586 }
587 switch tok.Kind() {
588 default:
589 return d.unexpectedTokenError(tok)
590 case json.ObjectClose:
591 break Loop
592 case json.Name:
593 // Continue.
594 }
595
596 // Unmarshal field name.
597 pkey, err := d.unmarshalMapKey(tok, fd.MapKey())
598 if err != nil {
599 return err
600 }
601
602 // Check for duplicate field name.
603 if mmap.Has(pkey) {
604 return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString())
605 }
606
607 // Read and unmarshal field value.
608 pval, err := unmarshalMapValue()
609 if err != nil {
610 return err
611 }
612
613 mmap.Set(pkey, pval)
614 }
615
616 return nil
617}
618
619// unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey.
620// A map key type is any integral or string type.
621func (d decoder) unmarshalMapKey(tok json.Token, fd pref.FieldDescriptor) (pref.MapKey, error) {
622 const b32 = 32
623 const b64 = 64
624 const base10 = 10
625
626 name := tok.Name()
627 kind := fd.Kind()
628 switch kind {
629 case pref.StringKind:
630 return pref.ValueOfString(name).MapKey(), nil
631
632 case pref.BoolKind:
633 switch name {
634 case "true":
635 return pref.ValueOfBool(true).MapKey(), nil
636 case "false":
637 return pref.ValueOfBool(false).MapKey(), nil
638 }
639
640 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
641 if n, err := strconv.ParseInt(name, base10, b32); err == nil {
642 return pref.ValueOfInt32(int32(n)).MapKey(), nil
643 }
644
645 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
646 if n, err := strconv.ParseInt(name, base10, b64); err == nil {
647 return pref.ValueOfInt64(int64(n)).MapKey(), nil
648 }
649
650 case pref.Uint32Kind, pref.Fixed32Kind:
651 if n, err := strconv.ParseUint(name, base10, b32); err == nil {
652 return pref.ValueOfUint32(uint32(n)).MapKey(), nil
653 }
654
655 case pref.Uint64Kind, pref.Fixed64Kind:
656 if n, err := strconv.ParseUint(name, base10, b64); err == nil {
657 return pref.ValueOfUint64(uint64(n)).MapKey(), nil
658 }
659
660 default:
661 panic(fmt.Sprintf("invalid kind for map key: %v", kind))
662 }
663
664 return pref.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString())
665}