blob: 128214760ccc16976a1d0b4628f3ca487dadc1c2 [file] [log] [blame]
amit.ghosh258d14c2020-10-02 15:13:38 +02001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
11 "google.golang.org/protobuf/internal/flags"
12 "google.golang.org/protobuf/internal/pragma"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/reflect/protoregistry"
15 "google.golang.org/protobuf/runtime/protoiface"
16)
17
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is ready to use. It resets the destination message before
// decoding, reports an error if required fields are missing, keeps unknown
// fields, and resolves extensions through protoregistry.GlobalTypes.
//
// Example usage:
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// NOTE(review): presumably embedded to forbid unkeyed (positional)
	// literals so new options can be added compatibly — confirm in pragma.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: they are dropped
	// instead of being appended to the message's unknown field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	// It is also handed to a message's fast-path Unmarshal method.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
45
46// Unmarshal parses the wire-format message in b and places the result in m.
47func Unmarshal(b []byte, m Message) error {
48 _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
49 return err
50}
51
52// Unmarshal parses the wire-format message in b and places the result in m.
53func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
54 _, err := o.unmarshal(b, m.ProtoReflect())
55 return err
56}
57
58// UnmarshalState parses a wire-format message and places the result in m.
59//
60// This method permits fine-grained control over the unmarshaler.
61// Most users should use Unmarshal instead.
62func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
63 return o.unmarshal(in.Buf, in.Message)
64}
65
66func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
67 if o.Resolver == nil {
68 o.Resolver = protoregistry.GlobalTypes
69 }
70 if !o.Merge {
71 Reset(m.Interface()) // TODO
72 }
73 allowPartial := o.AllowPartial
74 o.Merge = true
75 o.AllowPartial = true
76 methods := protoMethods(m)
77 if methods != nil && methods.Unmarshal != nil &&
78 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
79 in := protoiface.UnmarshalInput{
80 Message: m,
81 Buf: b,
82 Resolver: o.Resolver,
83 }
84 if o.DiscardUnknown {
85 in.Flags |= protoiface.UnmarshalDiscardUnknown
86 }
87 out, err = methods.Unmarshal(in)
88 } else {
89 err = o.unmarshalMessageSlow(b, m)
90 }
91 if err != nil {
92 return out, err
93 }
94 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
95 return out, nil
96 }
97 return out, checkInitialized(m)
98}
99
100func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
101 _, err := o.unmarshal(b, m)
102 return err
103}
104
// unmarshalMessageSlow decodes the wire-format bytes in b into m using only
// the protoreflect API. unmarshal uses it when m has no fast-path Unmarshal
// method, or when that method cannot honor DiscardUnknown.
//
// MessageSet messages are delegated to unmarshalMessageSet. For all other
// messages, b is consumed one field (tag followed by value) at a time.
//
// A field is treated as unknown in any of these cases:
//   - its number matches neither a declared field nor a resolvable extension;
//   - in legacy mode, it is a weak field whose message type is a placeholder;
//   - a value decoder reports errUnknown, e.g. a map field whose wire type is
//     not length-delimited.
//
// Unknown fields are appended verbatim (tag and value bytes) to m's unknown
// fields unless o.DiscardUnknown is set. An extension lookup failure other
// than protoregistry.NotFound is returned as an error.
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
	md := m.Descriptor()
	// MessageSet wire format is handled by a dedicated decoder.
	if messageset.IsMessageSet(md) {
		return unmarshalMessageSet(b, m, o)
	}
	fields := md.Fields()
	// Consume b one field (tag followed by value) at a time.
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		num, wtyp, tagLen := protowire.ConsumeTag(b)
		if tagLen < 0 {
			return protowire.ParseError(tagLen)
		}
		// Numbers above MaxValidNumber are a hard error, not an unknown field.
		if num > protowire.MaxValidNumber {
			return errors.New("invalid field number")
		}

		// Find the field descriptor for this field number.
		fd := fields.ByNumber(num)
		// Not a declared field: if the number lies in one of the message's
		// extension ranges, ask the Resolver for a matching extension.
		// protoregistry.NotFound is not an error; fd simply stays nil.
		if fd == nil && md.ExtensionRanges().Has(num) {
			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
			if err != nil && err != protoregistry.NotFound {
				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
			}
			if extType != nil {
				fd = extType.TypeDescriptor()
			}
		}
		// Setting err to errUnknown sends the field straight to the
		// unknown-field handling below without attempting to decode it.
		var err error
		if fd == nil {
			err = errUnknown
		} else if flags.ProtoLegacy {
			if fd.IsWeak() && fd.Message().IsPlaceholder() {
				err = errUnknown // weak referent is not linked in
			}
		}

		// Parse the field value.
		var valLen int
		switch {
		case err != nil:
			// Already classified as unknown; nothing to decode.
		case fd.IsList():
			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
		case fd.IsMap():
			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
		default:
			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
		}
		// errUnknown, whether set above or returned by a decoder, means: skip
		// over the value and, unless discarding, keep the raw tag and value
		// bytes as an unknown field. Any other error aborts decoding.
		if err != nil {
			if err != errUnknown {
				return err
			}
			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
			if valLen < 0 {
				return protowire.ParseError(valLen)
			}
			if !o.DiscardUnknown {
				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
			}
		}
		// Advance past this field's tag and value.
		b = b[tagLen+valLen:]
	}
	return nil
}
168
169func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
170 v, n, err := o.unmarshalScalar(b, wtyp, fd)
171 if err != nil {
172 return 0, err
173 }
174 switch fd.Kind() {
175 case protoreflect.GroupKind, protoreflect.MessageKind:
176 m2 := m.Mutable(fd).Message()
177 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
178 return n, err
179 }
180 default:
181 // Non-message scalars replace the previous value.
182 m.Set(fd, v)
183 }
184 return n, nil
185}
186
187func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
188 if wtyp != protowire.BytesType {
189 return 0, errUnknown
190 }
191 b, n = protowire.ConsumeBytes(b)
192 if n < 0 {
193 return 0, protowire.ParseError(n)
194 }
195 var (
196 keyField = fd.MapKey()
197 valField = fd.MapValue()
198 key protoreflect.Value
199 val protoreflect.Value
200 haveKey bool
201 haveVal bool
202 )
203 switch valField.Kind() {
204 case protoreflect.GroupKind, protoreflect.MessageKind:
205 val = mapv.NewValue()
206 }
207 // Map entries are represented as a two-element message with fields
208 // containing the key and value.
209 for len(b) > 0 {
210 num, wtyp, n := protowire.ConsumeTag(b)
211 if n < 0 {
212 return 0, protowire.ParseError(n)
213 }
214 if num > protowire.MaxValidNumber {
215 return 0, errors.New("invalid field number")
216 }
217 b = b[n:]
218 err = errUnknown
219 switch num {
220 case 1:
221 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
222 if err != nil {
223 break
224 }
225 haveKey = true
226 case 2:
227 var v protoreflect.Value
228 v, n, err = o.unmarshalScalar(b, wtyp, valField)
229 if err != nil {
230 break
231 }
232 switch valField.Kind() {
233 case protoreflect.GroupKind, protoreflect.MessageKind:
234 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
235 return 0, err
236 }
237 default:
238 val = v
239 }
240 haveVal = true
241 }
242 if err == errUnknown {
243 n = protowire.ConsumeFieldValue(num, wtyp, b)
244 if n < 0 {
245 return 0, protowire.ParseError(n)
246 }
247 } else if err != nil {
248 return 0, err
249 }
250 b = b[n:]
251 }
252 // Every map entry should have entries for key and value, but this is not strictly required.
253 if !haveKey {
254 key = keyField.Default()
255 }
256 if !haveVal {
257 switch valField.Kind() {
258 case protoreflect.GroupKind, protoreflect.MessageKind:
259 default:
260 val = valField.Default()
261 }
262 }
263 mapv.Set(key.MapKey(), val)
264 return n, nil
265}
266
var (
	// errUnknown is an internal sentinel. Decoding helpers return it to mark
	// a field that should be skipped and, unless unknown fields are being
	// discarded, preserved in the message's unknown field set. It is never
	// returned from an exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")
)