blob: 60e82caa9a2d30ac3d1abbd2b9ab6049f33a1a36 [file] [log] [blame]
khenaidoof3333552021-12-15 16:52:31 -05001// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package jsonpb
6
7import (
8 "encoding/json"
9 "errors"
10 "fmt"
11 "io"
12 "math"
13 "reflect"
14 "strconv"
15 "strings"
16 "time"
17
18 "github.com/golang/protobuf/proto"
19 "google.golang.org/protobuf/encoding/protojson"
20 protoV2 "google.golang.org/protobuf/proto"
21 "google.golang.org/protobuf/reflect/protoreflect"
22 "google.golang.org/protobuf/reflect/protoregistry"
23)
24
// wrapJSONUnmarshalV2 selects the decoding backend used by
// (*Unmarshaler).UnmarshalNext. When true, decoding is delegated to protojson;
// it is set to false here, so the hand-written legacy decoder in this file
// is used.
const wrapJSONUnmarshalV2 bool = false
26
27// UnmarshalNext unmarshals the next JSON object from d into m.
28func UnmarshalNext(d *json.Decoder, m proto.Message) error {
29 return new(Unmarshaler).UnmarshalNext(d, m)
30}
31
32// Unmarshal unmarshals a JSON object from r into m.
33func Unmarshal(r io.Reader, m proto.Message) error {
34 return new(Unmarshaler).Unmarshal(r, m)
35}
36
37// UnmarshalString unmarshals a JSON object from s into m.
38func UnmarshalString(s string, m proto.Message) error {
39 return new(Unmarshaler).Unmarshal(strings.NewReader(s), m)
40}
41
42// Unmarshaler is a configurable object for converting from a JSON
43// representation to a protocol buffer object.
44type Unmarshaler struct {
45 // AllowUnknownFields specifies whether to allow messages to contain
46 // unknown JSON fields, as opposed to failing to unmarshal.
47 AllowUnknownFields bool
48
49 // AnyResolver is used to resolve the google.protobuf.Any well-known type.
50 // If unset, the global registry is used by default.
51 AnyResolver AnyResolver
52}
53
54// JSONPBUnmarshaler is implemented by protobuf messages that customize the way
55// they are unmarshaled from JSON. Messages that implement this should also
56// implement JSONPBMarshaler so that the custom format can be produced.
57//
58// The JSON unmarshaling must follow the JSON to proto specification:
59// https://developers.google.com/protocol-buffers/docs/proto3#json
60//
61// Deprecated: Custom types should implement protobuf reflection instead.
62type JSONPBUnmarshaler interface {
63 UnmarshalJSONPB(*Unmarshaler, []byte) error
64}
65
66// Unmarshal unmarshals a JSON object from r into m.
67func (u *Unmarshaler) Unmarshal(r io.Reader, m proto.Message) error {
68 return u.UnmarshalNext(json.NewDecoder(r), m)
69}
70
71// UnmarshalNext unmarshals the next JSON object from d into m.
72func (u *Unmarshaler) UnmarshalNext(d *json.Decoder, m proto.Message) error {
73 if m == nil {
74 return errors.New("invalid nil message")
75 }
76
77 // Parse the next JSON object from the stream.
78 raw := json.RawMessage{}
79 if err := d.Decode(&raw); err != nil {
80 return err
81 }
82
83 // Check for custom unmarshalers first since they may not properly
84 // implement protobuf reflection that the logic below relies on.
85 if jsu, ok := m.(JSONPBUnmarshaler); ok {
86 return jsu.UnmarshalJSONPB(u, raw)
87 }
88
89 mr := proto.MessageReflect(m)
90
91 // NOTE: For historical reasons, a top-level null is treated as a noop.
92 // This is incorrect, but kept for compatibility.
93 if string(raw) == "null" && mr.Descriptor().FullName() != "google.protobuf.Value" {
94 return nil
95 }
96
97 if wrapJSONUnmarshalV2 {
98 // NOTE: If input message is non-empty, we need to preserve merge semantics
99 // of the old jsonpb implementation. These semantics are not supported by
100 // the protobuf JSON specification.
101 isEmpty := true
102 mr.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
103 isEmpty = false // at least one iteration implies non-empty
104 return false
105 })
106 if !isEmpty {
107 // Perform unmarshaling into a newly allocated, empty message.
108 mr = mr.New()
109
110 // Use a defer to copy all unmarshaled fields into the original message.
111 dst := proto.MessageReflect(m)
112 defer mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
113 dst.Set(fd, v)
114 return true
115 })
116 }
117
118 // Unmarshal using the v2 JSON unmarshaler.
119 opts := protojson.UnmarshalOptions{
120 DiscardUnknown: u.AllowUnknownFields,
121 }
122 if u.AnyResolver != nil {
123 opts.Resolver = anyResolver{u.AnyResolver}
124 }
125 return opts.Unmarshal(raw, mr.Interface())
126 } else {
127 if err := u.unmarshalMessage(mr, raw); err != nil {
128 return err
129 }
130 return protoV2.CheckInitialized(mr.Interface())
131 }
132}
133
134func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error {
135 md := m.Descriptor()
136 fds := md.Fields()
137
138 if jsu, ok := proto.MessageV1(m.Interface()).(JSONPBUnmarshaler); ok {
139 return jsu.UnmarshalJSONPB(u, in)
140 }
141
142 if string(in) == "null" && md.FullName() != "google.protobuf.Value" {
143 return nil
144 }
145
146 switch wellKnownType(md.FullName()) {
147 case "Any":
148 var jsonObject map[string]json.RawMessage
149 if err := json.Unmarshal(in, &jsonObject); err != nil {
150 return err
151 }
152
153 rawTypeURL, ok := jsonObject["@type"]
154 if !ok {
155 return errors.New("Any JSON doesn't have '@type'")
156 }
157 typeURL, err := unquoteString(string(rawTypeURL))
158 if err != nil {
159 return fmt.Errorf("can't unmarshal Any's '@type': %q", rawTypeURL)
160 }
161 m.Set(fds.ByNumber(1), protoreflect.ValueOfString(typeURL))
162
163 var m2 protoreflect.Message
164 if u.AnyResolver != nil {
165 mi, err := u.AnyResolver.Resolve(typeURL)
166 if err != nil {
167 return err
168 }
169 m2 = proto.MessageReflect(mi)
170 } else {
171 mt, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
172 if err != nil {
173 if err == protoregistry.NotFound {
174 return fmt.Errorf("could not resolve Any message type: %v", typeURL)
175 }
176 return err
177 }
178 m2 = mt.New()
179 }
180
181 if wellKnownType(m2.Descriptor().FullName()) != "" {
182 rawValue, ok := jsonObject["value"]
183 if !ok {
184 return errors.New("Any JSON doesn't have 'value'")
185 }
186 if err := u.unmarshalMessage(m2, rawValue); err != nil {
187 return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
188 }
189 } else {
190 delete(jsonObject, "@type")
191 rawJSON, err := json.Marshal(jsonObject)
192 if err != nil {
193 return fmt.Errorf("can't generate JSON for Any's nested proto to be unmarshaled: %v", err)
194 }
195 if err = u.unmarshalMessage(m2, rawJSON); err != nil {
196 return fmt.Errorf("can't unmarshal Any nested proto %v: %v", typeURL, err)
197 }
198 }
199
200 rawWire, err := protoV2.Marshal(m2.Interface())
201 if err != nil {
202 return fmt.Errorf("can't marshal proto %v into Any.Value: %v", typeURL, err)
203 }
204 m.Set(fds.ByNumber(2), protoreflect.ValueOfBytes(rawWire))
205 return nil
206 case "BoolValue", "BytesValue", "StringValue",
207 "Int32Value", "UInt32Value", "FloatValue",
208 "Int64Value", "UInt64Value", "DoubleValue":
209 fd := fds.ByNumber(1)
210 v, err := u.unmarshalValue(m.NewField(fd), in, fd)
211 if err != nil {
212 return err
213 }
214 m.Set(fd, v)
215 return nil
216 case "Duration":
217 v, err := unquoteString(string(in))
218 if err != nil {
219 return err
220 }
221 d, err := time.ParseDuration(v)
222 if err != nil {
223 return fmt.Errorf("bad Duration: %v", err)
224 }
225
226 sec := d.Nanoseconds() / 1e9
227 nsec := d.Nanoseconds() % 1e9
228 m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(int64(sec)))
229 m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
230 return nil
231 case "Timestamp":
232 v, err := unquoteString(string(in))
233 if err != nil {
234 return err
235 }
236 t, err := time.Parse(time.RFC3339Nano, v)
237 if err != nil {
238 return fmt.Errorf("bad Timestamp: %v", err)
239 }
240
241 sec := t.Unix()
242 nsec := t.Nanosecond()
243 m.Set(fds.ByNumber(1), protoreflect.ValueOfInt64(int64(sec)))
244 m.Set(fds.ByNumber(2), protoreflect.ValueOfInt32(int32(nsec)))
245 return nil
246 case "Value":
247 switch {
248 case string(in) == "null":
249 m.Set(fds.ByNumber(1), protoreflect.ValueOfEnum(0))
250 case string(in) == "true":
251 m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(true))
252 case string(in) == "false":
253 m.Set(fds.ByNumber(4), protoreflect.ValueOfBool(false))
254 case hasPrefixAndSuffix('"', in, '"'):
255 s, err := unquoteString(string(in))
256 if err != nil {
257 return fmt.Errorf("unrecognized type for Value %q", in)
258 }
259 m.Set(fds.ByNumber(3), protoreflect.ValueOfString(s))
260 case hasPrefixAndSuffix('[', in, ']'):
261 v := m.Mutable(fds.ByNumber(6))
262 return u.unmarshalMessage(v.Message(), in)
263 case hasPrefixAndSuffix('{', in, '}'):
264 v := m.Mutable(fds.ByNumber(5))
265 return u.unmarshalMessage(v.Message(), in)
266 default:
267 f, err := strconv.ParseFloat(string(in), 0)
268 if err != nil {
269 return fmt.Errorf("unrecognized type for Value %q", in)
270 }
271 m.Set(fds.ByNumber(2), protoreflect.ValueOfFloat64(f))
272 }
273 return nil
274 case "ListValue":
275 var jsonArray []json.RawMessage
276 if err := json.Unmarshal(in, &jsonArray); err != nil {
277 return fmt.Errorf("bad ListValue: %v", err)
278 }
279
280 lv := m.Mutable(fds.ByNumber(1)).List()
281 for _, raw := range jsonArray {
282 ve := lv.NewElement()
283 if err := u.unmarshalMessage(ve.Message(), raw); err != nil {
284 return err
285 }
286 lv.Append(ve)
287 }
288 return nil
289 case "Struct":
290 var jsonObject map[string]json.RawMessage
291 if err := json.Unmarshal(in, &jsonObject); err != nil {
292 return fmt.Errorf("bad StructValue: %v", err)
293 }
294
295 mv := m.Mutable(fds.ByNumber(1)).Map()
296 for key, raw := range jsonObject {
297 kv := protoreflect.ValueOf(key).MapKey()
298 vv := mv.NewValue()
299 if err := u.unmarshalMessage(vv.Message(), raw); err != nil {
300 return fmt.Errorf("bad value in StructValue for key %q: %v", key, err)
301 }
302 mv.Set(kv, vv)
303 }
304 return nil
305 }
306
307 var jsonObject map[string]json.RawMessage
308 if err := json.Unmarshal(in, &jsonObject); err != nil {
309 return err
310 }
311
312 // Handle known fields.
313 for i := 0; i < fds.Len(); i++ {
314 fd := fds.Get(i)
315 if fd.IsWeak() && fd.Message().IsPlaceholder() {
316 continue // weak reference is not linked in
317 }
318
319 // Search for any raw JSON value associated with this field.
320 var raw json.RawMessage
321 name := string(fd.Name())
322 if fd.Kind() == protoreflect.GroupKind {
323 name = string(fd.Message().Name())
324 }
325 if v, ok := jsonObject[name]; ok {
326 delete(jsonObject, name)
327 raw = v
328 }
329 name = string(fd.JSONName())
330 if v, ok := jsonObject[name]; ok {
331 delete(jsonObject, name)
332 raw = v
333 }
334
335 field := m.NewField(fd)
336 // Unmarshal the field value.
337 if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) {
338 continue
339 }
340 v, err := u.unmarshalValue(field, raw, fd)
341 if err != nil {
342 return err
343 }
344 m.Set(fd, v)
345 }
346
347 // Handle extension fields.
348 for name, raw := range jsonObject {
349 if !strings.HasPrefix(name, "[") || !strings.HasSuffix(name, "]") {
350 continue
351 }
352
353 // Resolve the extension field by name.
354 xname := protoreflect.FullName(name[len("[") : len(name)-len("]")])
355 xt, _ := protoregistry.GlobalTypes.FindExtensionByName(xname)
356 if xt == nil && isMessageSet(md) {
357 xt, _ = protoregistry.GlobalTypes.FindExtensionByName(xname.Append("message_set_extension"))
358 }
359 if xt == nil {
360 continue
361 }
362 delete(jsonObject, name)
363 fd := xt.TypeDescriptor()
364 if fd.ContainingMessage().FullName() != m.Descriptor().FullName() {
365 return fmt.Errorf("extension field %q does not extend message %q", xname, m.Descriptor().FullName())
366 }
367
368 field := m.NewField(fd)
369 // Unmarshal the field value.
370 if raw == nil || (string(raw) == "null" && !isSingularWellKnownValue(fd) && !isSingularJSONPBUnmarshaler(field, fd)) {
371 continue
372 }
373 v, err := u.unmarshalValue(field, raw, fd)
374 if err != nil {
375 return err
376 }
377 m.Set(fd, v)
378 }
379
380 if !u.AllowUnknownFields && len(jsonObject) > 0 {
381 for name := range jsonObject {
382 return fmt.Errorf("unknown field %q in %v", name, md.FullName())
383 }
384 }
385 return nil
386}
387
388func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool {
389 if md := fd.Message(); md != nil {
390 return md.FullName() == "google.protobuf.Value" && fd.Cardinality() != protoreflect.Repeated
391 }
392 return false
393}
394
395func isSingularJSONPBUnmarshaler(v protoreflect.Value, fd protoreflect.FieldDescriptor) bool {
396 if fd.Message() != nil && fd.Cardinality() != protoreflect.Repeated {
397 _, ok := proto.MessageV1(v.Interface()).(JSONPBUnmarshaler)
398 return ok
399 }
400 return false
401}
402
403func (u *Unmarshaler) unmarshalValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
404 switch {
405 case fd.IsList():
406 var jsonArray []json.RawMessage
407 if err := json.Unmarshal(in, &jsonArray); err != nil {
408 return v, err
409 }
410 lv := v.List()
411 for _, raw := range jsonArray {
412 ve, err := u.unmarshalSingularValue(lv.NewElement(), raw, fd)
413 if err != nil {
414 return v, err
415 }
416 lv.Append(ve)
417 }
418 return v, nil
419 case fd.IsMap():
420 var jsonObject map[string]json.RawMessage
421 if err := json.Unmarshal(in, &jsonObject); err != nil {
422 return v, err
423 }
424 kfd := fd.MapKey()
425 vfd := fd.MapValue()
426 mv := v.Map()
427 for key, raw := range jsonObject {
428 var kv protoreflect.MapKey
429 if kfd.Kind() == protoreflect.StringKind {
430 kv = protoreflect.ValueOf(key).MapKey()
431 } else {
432 v, err := u.unmarshalSingularValue(kfd.Default(), []byte(key), kfd)
433 if err != nil {
434 return v, err
435 }
436 kv = v.MapKey()
437 }
438
439 vv, err := u.unmarshalSingularValue(mv.NewValue(), raw, vfd)
440 if err != nil {
441 return v, err
442 }
443 mv.Set(kv, vv)
444 }
445 return v, nil
446 default:
447 return u.unmarshalSingularValue(v, in, fd)
448 }
449}
450
// nonFinite maps the JSON string tokens for the non-finite floating point
// values to their float64 equivalents. The keys include the surrounding
// double quotes because they are matched against raw, still-quoted JSON.
var nonFinite = map[string]float64{
	"\"NaN\"":       math.NaN(),
	"\"Infinity\"":  math.Inf(1),
	"\"-Infinity\"": math.Inf(-1),
}
456
457func (u *Unmarshaler) unmarshalSingularValue(v protoreflect.Value, in []byte, fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
458 switch fd.Kind() {
459 case protoreflect.BoolKind:
460 return unmarshalValue(in, new(bool))
461 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
462 return unmarshalValue(trimQuote(in), new(int32))
463 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
464 return unmarshalValue(trimQuote(in), new(int64))
465 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
466 return unmarshalValue(trimQuote(in), new(uint32))
467 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
468 return unmarshalValue(trimQuote(in), new(uint64))
469 case protoreflect.FloatKind:
470 if f, ok := nonFinite[string(in)]; ok {
471 return protoreflect.ValueOfFloat32(float32(f)), nil
472 }
473 return unmarshalValue(trimQuote(in), new(float32))
474 case protoreflect.DoubleKind:
475 if f, ok := nonFinite[string(in)]; ok {
476 return protoreflect.ValueOfFloat64(float64(f)), nil
477 }
478 return unmarshalValue(trimQuote(in), new(float64))
479 case protoreflect.StringKind:
480 return unmarshalValue(in, new(string))
481 case protoreflect.BytesKind:
482 return unmarshalValue(in, new([]byte))
483 case protoreflect.EnumKind:
484 if hasPrefixAndSuffix('"', in, '"') {
485 vd := fd.Enum().Values().ByName(protoreflect.Name(trimQuote(in)))
486 if vd == nil {
487 return v, fmt.Errorf("unknown value %q for enum %s", in, fd.Enum().FullName())
488 }
489 return protoreflect.ValueOfEnum(vd.Number()), nil
490 }
491 return unmarshalValue(in, new(protoreflect.EnumNumber))
492 case protoreflect.MessageKind, protoreflect.GroupKind:
493 err := u.unmarshalMessage(v.Message(), in)
494 return v, err
495 default:
496 panic(fmt.Sprintf("invalid kind %v", fd.Kind()))
497 }
498}
499
500func unmarshalValue(in []byte, v interface{}) (protoreflect.Value, error) {
501 err := json.Unmarshal(in, v)
502 return protoreflect.ValueOf(reflect.ValueOf(v).Elem().Interface()), err
503}
504
// unquoteString decodes in, a JSON string literal including its surrounding
// quotes, into the Go string it denotes, processing escape sequences. It
// returns an error if in is not valid JSON or does not decode to a string.
// JSON null decodes to the empty string without error.
func unquoteString(in string) (string, error) {
	var decoded string
	if err := json.Unmarshal([]byte(in), &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
509
// hasPrefixAndSuffix reports whether in is at least two bytes long, starts
// with prefix and ends with suffix. The length requirement keeps a single
// byte such as `"` from counting as both the opening and closing delimiter.
func hasPrefixAndSuffix(prefix byte, in []byte, suffix byte) bool {
	n := len(in)
	return n >= 2 && in[0] == prefix && in[n-1] == suffix
}
516
// trimQuote removes one pair of surrounding double quotes from in, if
// present, and returns the inner bytes as a sub-slice of in. Unlike
// unquoteString it does not interpret escape sequences. That is technically
// incorrect, but it matches the behavior of the legacy implementation.
func trimQuote(in []byte) []byte {
	n := len(in)
	if n < 2 || in[0] != '"' || in[n-1] != '"' {
		return in
	}
	return in[1 : n-1]
}