blob: 49f9b8c88cfd05a5a79bb3b53bcbc304276231d5 [file] [log] [blame]
David K. Bainbridgebd6b2882021-08-26 13:31:02 +00001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
11 "google.golang.org/protobuf/internal/flags"
12 "google.golang.org/protobuf/internal/genid"
13 "google.golang.org/protobuf/internal/pragma"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 "google.golang.org/protobuf/reflect/protoregistry"
16 "google.golang.org/protobuf/runtime/protoiface"
17)
18
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is usable: it resets the destination message before
// decoding, requires all required fields to be present, keeps unknown fields,
// and resolves extensions through protoregistry.GlobalTypes.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// Embedded marker type; presumably forbids unkeyed composite literals of
	// UnmarshalOptions (see package pragma) so fields can be added later.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored: they are dropped
	// instead of being appended to the message's unknown field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	// When decoding a field number that falls in the message's extension
	// ranges, the reflection-based decoder calls FindExtensionByNumber; a
	// protoregistry.NotFound result causes the field to be kept as unknown.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
46
47// Unmarshal parses the wire-format message in b and places the result in m.
48// The provided message must be mutable (e.g., a non-nil pointer to a message).
49func Unmarshal(b []byte, m Message) error {
50 _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
51 return err
52}
53
54// Unmarshal parses the wire-format message in b and places the result in m.
55// The provided message must be mutable (e.g., a non-nil pointer to a message).
56func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
57 _, err := o.unmarshal(b, m.ProtoReflect())
58 return err
59}
60
61// UnmarshalState parses a wire-format message and places the result in m.
62//
63// This method permits fine-grained control over the unmarshaler.
64// Most users should use Unmarshal instead.
65func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
66 return o.unmarshal(in.Buf, in.Message)
67}
68
69// unmarshal is a centralized function that all unmarshal operations go through.
70// For profiling purposes, avoid changing the name of this function or
71// introducing other code paths for unmarshal that do not go through this.
72func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
73 if o.Resolver == nil {
74 o.Resolver = protoregistry.GlobalTypes
75 }
76 if !o.Merge {
77 Reset(m.Interface())
78 }
79 allowPartial := o.AllowPartial
80 o.Merge = true
81 o.AllowPartial = true
82 methods := protoMethods(m)
83 if methods != nil && methods.Unmarshal != nil &&
84 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
85 in := protoiface.UnmarshalInput{
86 Message: m,
87 Buf: b,
88 Resolver: o.Resolver,
89 }
90 if o.DiscardUnknown {
91 in.Flags |= protoiface.UnmarshalDiscardUnknown
92 }
93 out, err = methods.Unmarshal(in)
94 } else {
95 err = o.unmarshalMessageSlow(b, m)
96 }
97 if err != nil {
98 return out, err
99 }
100 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
101 return out, nil
102 }
103 return out, checkInitialized(m)
104}
105
106func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
107 _, err := o.unmarshal(b, m)
108 return err
109}
110
111func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
112 md := m.Descriptor()
113 if messageset.IsMessageSet(md) {
114 return o.unmarshalMessageSet(b, m)
115 }
116 fields := md.Fields()
117 for len(b) > 0 {
118 // Parse the tag (field number and wire type).
119 num, wtyp, tagLen := protowire.ConsumeTag(b)
120 if tagLen < 0 {
121 return errDecode
122 }
123 if num > protowire.MaxValidNumber {
124 return errDecode
125 }
126
127 // Find the field descriptor for this field number.
128 fd := fields.ByNumber(num)
129 if fd == nil && md.ExtensionRanges().Has(num) {
130 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
131 if err != nil && err != protoregistry.NotFound {
132 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
133 }
134 if extType != nil {
135 fd = extType.TypeDescriptor()
136 }
137 }
138 var err error
139 if fd == nil {
140 err = errUnknown
141 } else if flags.ProtoLegacy {
142 if fd.IsWeak() && fd.Message().IsPlaceholder() {
143 err = errUnknown // weak referent is not linked in
144 }
145 }
146
147 // Parse the field value.
148 var valLen int
149 switch {
150 case err != nil:
151 case fd.IsList():
152 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
153 case fd.IsMap():
154 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
155 default:
156 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
157 }
158 if err != nil {
159 if err != errUnknown {
160 return err
161 }
162 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
163 if valLen < 0 {
164 return errDecode
165 }
166 if !o.DiscardUnknown {
167 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
168 }
169 }
170 b = b[tagLen+valLen:]
171 }
172 return nil
173}
174
175func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
176 v, n, err := o.unmarshalScalar(b, wtyp, fd)
177 if err != nil {
178 return 0, err
179 }
180 switch fd.Kind() {
181 case protoreflect.GroupKind, protoreflect.MessageKind:
182 m2 := m.Mutable(fd).Message()
183 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
184 return n, err
185 }
186 default:
187 // Non-message scalars replace the previous value.
188 m.Set(fd, v)
189 }
190 return n, nil
191}
192
193func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
194 if wtyp != protowire.BytesType {
195 return 0, errUnknown
196 }
197 b, n = protowire.ConsumeBytes(b)
198 if n < 0 {
199 return 0, errDecode
200 }
201 var (
202 keyField = fd.MapKey()
203 valField = fd.MapValue()
204 key protoreflect.Value
205 val protoreflect.Value
206 haveKey bool
207 haveVal bool
208 )
209 switch valField.Kind() {
210 case protoreflect.GroupKind, protoreflect.MessageKind:
211 val = mapv.NewValue()
212 }
213 // Map entries are represented as a two-element message with fields
214 // containing the key and value.
215 for len(b) > 0 {
216 num, wtyp, n := protowire.ConsumeTag(b)
217 if n < 0 {
218 return 0, errDecode
219 }
220 if num > protowire.MaxValidNumber {
221 return 0, errDecode
222 }
223 b = b[n:]
224 err = errUnknown
225 switch num {
226 case genid.MapEntry_Key_field_number:
227 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
228 if err != nil {
229 break
230 }
231 haveKey = true
232 case genid.MapEntry_Value_field_number:
233 var v protoreflect.Value
234 v, n, err = o.unmarshalScalar(b, wtyp, valField)
235 if err != nil {
236 break
237 }
238 switch valField.Kind() {
239 case protoreflect.GroupKind, protoreflect.MessageKind:
240 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
241 return 0, err
242 }
243 default:
244 val = v
245 }
246 haveVal = true
247 }
248 if err == errUnknown {
249 n = protowire.ConsumeFieldValue(num, wtyp, b)
250 if n < 0 {
251 return 0, errDecode
252 }
253 } else if err != nil {
254 return 0, err
255 }
256 b = b[n:]
257 }
258 // Every map entry should have entries for key and value, but this is not strictly required.
259 if !haveKey {
260 key = keyField.Default()
261 }
262 if !haveVal {
263 switch valField.Kind() {
264 case protoreflect.GroupKind, protoreflect.MessageKind:
265 default:
266 val = valField.Default()
267 }
268 }
269 mapv.Set(key.MapKey(), val)
270 return n, nil
271}
272
// Sentinel errors shared by the decoding routines in this file.
var (
	// errUnknown is used internally to indicate fields which should be added
	// to the unknown field set of a message. It is never returned from an
	// exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode reports wire-format input that cannot be parsed (bad tags,
	// truncated values, out-of-range field numbers).
	errDecode = errors.New("cannot parse invalid wire-format data")
)