blob: 4974b16d544828a51a4cb4812962a4c663bd4d6e [file] [log] [blame]
Matteo Scandoloa4285862020-12-01 18:10:10 -08001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
11 "google.golang.org/protobuf/internal/flags"
12 "google.golang.org/protobuf/internal/pragma"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/reflect/protoregistry"
15 "google.golang.org/protobuf/runtime/protoiface"
16)
17
18// UnmarshalOptions configures the unmarshaler.
19//
20// Example usage:
21// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
22type UnmarshalOptions struct {
23 pragma.NoUnkeyedLiterals
24
25 // Merge merges the input into the destination message.
26 // The default behavior is to always reset the message before unmarshaling,
27 // unless Merge is specified.
28 Merge bool
29
30 // AllowPartial accepts input for messages that will result in missing
31 // required fields. If AllowPartial is false (the default), Unmarshal will
32 // return an error if there are any missing required fields.
33 AllowPartial bool
34
35 // If DiscardUnknown is set, unknown fields are ignored.
36 DiscardUnknown bool
37
38 // Resolver is used for looking up types when unmarshaling extension fields.
39 // If nil, this defaults to using protoregistry.GlobalTypes.
40 Resolver interface {
41 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
42 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
43 }
44}
45
46// Unmarshal parses the wire-format message in b and places the result in m.
47func Unmarshal(b []byte, m Message) error {
48 _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
49 return err
50}
51
52// Unmarshal parses the wire-format message in b and places the result in m.
53func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
54 _, err := o.unmarshal(b, m.ProtoReflect())
55 return err
56}
57
58// UnmarshalState parses a wire-format message and places the result in m.
59//
60// This method permits fine-grained control over the unmarshaler.
61// Most users should use Unmarshal instead.
62func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
63 return o.unmarshal(in.Buf, in.Message)
64}
65
66// unmarshal is a centralized function that all unmarshal operations go through.
67// For profiling purposes, avoid changing the name of this function or
68// introducing other code paths for unmarshal that do not go through this.
69func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
70 if o.Resolver == nil {
71 o.Resolver = protoregistry.GlobalTypes
72 }
73 if !o.Merge {
74 Reset(m.Interface())
75 }
76 allowPartial := o.AllowPartial
77 o.Merge = true
78 o.AllowPartial = true
79 methods := protoMethods(m)
80 if methods != nil && methods.Unmarshal != nil &&
81 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
82 in := protoiface.UnmarshalInput{
83 Message: m,
84 Buf: b,
85 Resolver: o.Resolver,
86 }
87 if o.DiscardUnknown {
88 in.Flags |= protoiface.UnmarshalDiscardUnknown
89 }
90 out, err = methods.Unmarshal(in)
91 } else {
92 err = o.unmarshalMessageSlow(b, m)
93 }
94 if err != nil {
95 return out, err
96 }
97 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
98 return out, nil
99 }
100 return out, checkInitialized(m)
101}
102
103func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
104 _, err := o.unmarshal(b, m)
105 return err
106}
107
108func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
109 md := m.Descriptor()
110 if messageset.IsMessageSet(md) {
111 return o.unmarshalMessageSet(b, m)
112 }
113 fields := md.Fields()
114 for len(b) > 0 {
115 // Parse the tag (field number and wire type).
116 num, wtyp, tagLen := protowire.ConsumeTag(b)
117 if tagLen < 0 {
118 return protowire.ParseError(tagLen)
119 }
120 if num > protowire.MaxValidNumber {
121 return errors.New("invalid field number")
122 }
123
124 // Find the field descriptor for this field number.
125 fd := fields.ByNumber(num)
126 if fd == nil && md.ExtensionRanges().Has(num) {
127 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
128 if err != nil && err != protoregistry.NotFound {
129 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
130 }
131 if extType != nil {
132 fd = extType.TypeDescriptor()
133 }
134 }
135 var err error
136 if fd == nil {
137 err = errUnknown
138 } else if flags.ProtoLegacy {
139 if fd.IsWeak() && fd.Message().IsPlaceholder() {
140 err = errUnknown // weak referent is not linked in
141 }
142 }
143
144 // Parse the field value.
145 var valLen int
146 switch {
147 case err != nil:
148 case fd.IsList():
149 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
150 case fd.IsMap():
151 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
152 default:
153 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
154 }
155 if err != nil {
156 if err != errUnknown {
157 return err
158 }
159 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
160 if valLen < 0 {
161 return protowire.ParseError(valLen)
162 }
163 if !o.DiscardUnknown {
164 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
165 }
166 }
167 b = b[tagLen+valLen:]
168 }
169 return nil
170}
171
172func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
173 v, n, err := o.unmarshalScalar(b, wtyp, fd)
174 if err != nil {
175 return 0, err
176 }
177 switch fd.Kind() {
178 case protoreflect.GroupKind, protoreflect.MessageKind:
179 m2 := m.Mutable(fd).Message()
180 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
181 return n, err
182 }
183 default:
184 // Non-message scalars replace the previous value.
185 m.Set(fd, v)
186 }
187 return n, nil
188}
189
190func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
191 if wtyp != protowire.BytesType {
192 return 0, errUnknown
193 }
194 b, n = protowire.ConsumeBytes(b)
195 if n < 0 {
196 return 0, protowire.ParseError(n)
197 }
198 var (
199 keyField = fd.MapKey()
200 valField = fd.MapValue()
201 key protoreflect.Value
202 val protoreflect.Value
203 haveKey bool
204 haveVal bool
205 )
206 switch valField.Kind() {
207 case protoreflect.GroupKind, protoreflect.MessageKind:
208 val = mapv.NewValue()
209 }
210 // Map entries are represented as a two-element message with fields
211 // containing the key and value.
212 for len(b) > 0 {
213 num, wtyp, n := protowire.ConsumeTag(b)
214 if n < 0 {
215 return 0, protowire.ParseError(n)
216 }
217 if num > protowire.MaxValidNumber {
218 return 0, errors.New("invalid field number")
219 }
220 b = b[n:]
221 err = errUnknown
222 switch num {
223 case 1:
224 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
225 if err != nil {
226 break
227 }
228 haveKey = true
229 case 2:
230 var v protoreflect.Value
231 v, n, err = o.unmarshalScalar(b, wtyp, valField)
232 if err != nil {
233 break
234 }
235 switch valField.Kind() {
236 case protoreflect.GroupKind, protoreflect.MessageKind:
237 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
238 return 0, err
239 }
240 default:
241 val = v
242 }
243 haveVal = true
244 }
245 if err == errUnknown {
246 n = protowire.ConsumeFieldValue(num, wtyp, b)
247 if n < 0 {
248 return 0, protowire.ParseError(n)
249 }
250 } else if err != nil {
251 return 0, err
252 }
253 b = b[n:]
254 }
255 // Every map entry should have entries for key and value, but this is not strictly required.
256 if !haveKey {
257 key = keyField.Default()
258 }
259 if !haveVal {
260 switch valField.Kind() {
261 case protoreflect.GroupKind, protoreflect.MessageKind:
262 default:
263 val = valField.Default()
264 }
265 }
266 mapv.Set(key.MapKey(), val)
267 return n, nil
268}
269
270// errUnknown is used internally to indicate fields which should be added
271// to the unknown field set of a message. It is never returned from an exported
272// function.
273var errUnknown = errors.New("BUG: internal error (unknown)")