blob: 4cbf1aeaf79c93e5b9474f69b792970e1a0cefaa [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
khenaidoo7d3c5582021-08-11 18:09:44 -040011 "google.golang.org/protobuf/internal/genid"
12 "google.golang.org/protobuf/internal/pragma"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/reflect/protoregistry"
15 "google.golang.org/protobuf/runtime/protoiface"
16)
17
18// UnmarshalOptions configures the unmarshaler.
19//
20// Example usage:
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053021//
22// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
khenaidoo7d3c5582021-08-11 18:09:44 -040023type UnmarshalOptions struct {
24 pragma.NoUnkeyedLiterals
25
26 // Merge merges the input into the destination message.
27 // The default behavior is to always reset the message before unmarshaling,
28 // unless Merge is specified.
29 Merge bool
30
31 // AllowPartial accepts input for messages that will result in missing
32 // required fields. If AllowPartial is false (the default), Unmarshal will
33 // return an error if there are any missing required fields.
34 AllowPartial bool
35
36 // If DiscardUnknown is set, unknown fields are ignored.
37 DiscardUnknown bool
38
39 // Resolver is used for looking up types when unmarshaling extension fields.
40 // If nil, this defaults to using protoregistry.GlobalTypes.
41 Resolver interface {
42 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
43 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
44 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053045
46 // RecursionLimit limits how deeply messages may be nested.
47 // If zero, a default limit is applied.
48 RecursionLimit int
49
50 //
51 // NoLazyDecoding turns off lazy decoding, which otherwise is enabled by
52 // default. Lazy decoding only affects submessages (annotated with [lazy =
53 // true] in the .proto file) within messages that use the Opaque API.
54 NoLazyDecoding bool
khenaidoo7d3c5582021-08-11 18:09:44 -040055}
56
57// Unmarshal parses the wire-format message in b and places the result in m.
58// The provided message must be mutable (e.g., a non-nil pointer to a message).
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053059//
60// See the [UnmarshalOptions] type if you need more control.
khenaidoo7d3c5582021-08-11 18:09:44 -040061func Unmarshal(b []byte, m Message) error {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053062 _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
khenaidoo7d3c5582021-08-11 18:09:44 -040063 return err
64}
65
66// Unmarshal parses the wire-format message in b and places the result in m.
67// The provided message must be mutable (e.g., a non-nil pointer to a message).
68func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053069 if o.RecursionLimit == 0 {
70 o.RecursionLimit = protowire.DefaultRecursionLimit
71 }
khenaidoo7d3c5582021-08-11 18:09:44 -040072 _, err := o.unmarshal(b, m.ProtoReflect())
73 return err
74}
75
76// UnmarshalState parses a wire-format message and places the result in m.
77//
78// This method permits fine-grained control over the unmarshaler.
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053079// Most users should use [Unmarshal] instead.
khenaidoo7d3c5582021-08-11 18:09:44 -040080func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053081 if o.RecursionLimit == 0 {
82 o.RecursionLimit = protowire.DefaultRecursionLimit
83 }
khenaidoo7d3c5582021-08-11 18:09:44 -040084 return o.unmarshal(in.Buf, in.Message)
85}
86
87// unmarshal is a centralized function that all unmarshal operations go through.
88// For profiling purposes, avoid changing the name of this function or
89// introducing other code paths for unmarshal that do not go through this.
90func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
91 if o.Resolver == nil {
92 o.Resolver = protoregistry.GlobalTypes
93 }
94 if !o.Merge {
95 Reset(m.Interface())
96 }
97 allowPartial := o.AllowPartial
98 o.Merge = true
99 o.AllowPartial = true
100 methods := protoMethods(m)
101 if methods != nil && methods.Unmarshal != nil &&
102 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
103 in := protoiface.UnmarshalInput{
104 Message: m,
105 Buf: b,
106 Resolver: o.Resolver,
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530107 Depth: o.RecursionLimit,
khenaidoo7d3c5582021-08-11 18:09:44 -0400108 }
109 if o.DiscardUnknown {
110 in.Flags |= protoiface.UnmarshalDiscardUnknown
111 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530112
113 if !allowPartial {
114 // This does not affect how current unmarshal functions work, it just allows them
115 // to record this for lazy the decoding case.
116 in.Flags |= protoiface.UnmarshalCheckRequired
117 }
118 if o.NoLazyDecoding {
119 in.Flags |= protoiface.UnmarshalNoLazyDecoding
120 }
121
khenaidoo7d3c5582021-08-11 18:09:44 -0400122 out, err = methods.Unmarshal(in)
123 } else {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530124 o.RecursionLimit--
125 if o.RecursionLimit < 0 {
126 return out, errors.New("exceeded max recursion depth")
127 }
khenaidoo7d3c5582021-08-11 18:09:44 -0400128 err = o.unmarshalMessageSlow(b, m)
129 }
130 if err != nil {
131 return out, err
132 }
133 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
134 return out, nil
135 }
136 return out, checkInitialized(m)
137}
138
139func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
140 _, err := o.unmarshal(b, m)
141 return err
142}
143
144func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
145 md := m.Descriptor()
146 if messageset.IsMessageSet(md) {
147 return o.unmarshalMessageSet(b, m)
148 }
149 fields := md.Fields()
150 for len(b) > 0 {
151 // Parse the tag (field number and wire type).
152 num, wtyp, tagLen := protowire.ConsumeTag(b)
153 if tagLen < 0 {
154 return errDecode
155 }
156 if num > protowire.MaxValidNumber {
157 return errDecode
158 }
159
160 // Find the field descriptor for this field number.
161 fd := fields.ByNumber(num)
162 if fd == nil && md.ExtensionRanges().Has(num) {
163 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
164 if err != nil && err != protoregistry.NotFound {
165 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
166 }
167 if extType != nil {
168 fd = extType.TypeDescriptor()
169 }
170 }
171 var err error
172 if fd == nil {
173 err = errUnknown
khenaidoo7d3c5582021-08-11 18:09:44 -0400174 }
175
176 // Parse the field value.
177 var valLen int
178 switch {
179 case err != nil:
180 case fd.IsList():
181 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
182 case fd.IsMap():
183 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
184 default:
185 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
186 }
187 if err != nil {
188 if err != errUnknown {
189 return err
190 }
191 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
192 if valLen < 0 {
193 return errDecode
194 }
195 if !o.DiscardUnknown {
196 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
197 }
198 }
199 b = b[tagLen+valLen:]
200 }
201 return nil
202}
203
204func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
205 v, n, err := o.unmarshalScalar(b, wtyp, fd)
206 if err != nil {
207 return 0, err
208 }
209 switch fd.Kind() {
210 case protoreflect.GroupKind, protoreflect.MessageKind:
211 m2 := m.Mutable(fd).Message()
212 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
213 return n, err
214 }
215 default:
216 // Non-message scalars replace the previous value.
217 m.Set(fd, v)
218 }
219 return n, nil
220}
221
222func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
223 if wtyp != protowire.BytesType {
224 return 0, errUnknown
225 }
226 b, n = protowire.ConsumeBytes(b)
227 if n < 0 {
228 return 0, errDecode
229 }
230 var (
231 keyField = fd.MapKey()
232 valField = fd.MapValue()
233 key protoreflect.Value
234 val protoreflect.Value
235 haveKey bool
236 haveVal bool
237 )
238 switch valField.Kind() {
239 case protoreflect.GroupKind, protoreflect.MessageKind:
240 val = mapv.NewValue()
241 }
242 // Map entries are represented as a two-element message with fields
243 // containing the key and value.
244 for len(b) > 0 {
245 num, wtyp, n := protowire.ConsumeTag(b)
246 if n < 0 {
247 return 0, errDecode
248 }
249 if num > protowire.MaxValidNumber {
250 return 0, errDecode
251 }
252 b = b[n:]
253 err = errUnknown
254 switch num {
255 case genid.MapEntry_Key_field_number:
256 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
257 if err != nil {
258 break
259 }
260 haveKey = true
261 case genid.MapEntry_Value_field_number:
262 var v protoreflect.Value
263 v, n, err = o.unmarshalScalar(b, wtyp, valField)
264 if err != nil {
265 break
266 }
267 switch valField.Kind() {
268 case protoreflect.GroupKind, protoreflect.MessageKind:
269 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
270 return 0, err
271 }
272 default:
273 val = v
274 }
275 haveVal = true
276 }
277 if err == errUnknown {
278 n = protowire.ConsumeFieldValue(num, wtyp, b)
279 if n < 0 {
280 return 0, errDecode
281 }
282 } else if err != nil {
283 return 0, err
284 }
285 b = b[n:]
286 }
287 // Every map entry should have entries for key and value, but this is not strictly required.
288 if !haveKey {
289 key = keyField.Default()
290 }
291 if !haveVal {
292 switch valField.Kind() {
293 case protoreflect.GroupKind, protoreflect.MessageKind:
294 default:
295 val = valField.Default()
296 }
297 }
298 mapv.Set(key.MapKey(), val)
299 return n, nil
300}
301
var (
	// errUnknown is an internal sentinel marking fields that should be added
	// to a message's unknown field set. It is never returned from an
	// exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode is returned when the input is not valid wire-format data.
	errDecode = errors.New("cannot parse invalid wire-format data")
)