blob: 949dc49a65b3cf4f9892292d9a556a66fe33d42d [file] [log] [blame]
khenaidoo5fc5cea2021-08-11 17:39:16 -04001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "math/bits"
9
10 "google.golang.org/protobuf/encoding/protowire"
11 "google.golang.org/protobuf/internal/errors"
12 "google.golang.org/protobuf/internal/flags"
13 "google.golang.org/protobuf/proto"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 preg "google.golang.org/protobuf/reflect/protoregistry"
16 "google.golang.org/protobuf/runtime/protoiface"
17 piface "google.golang.org/protobuf/runtime/protoiface"
18)
19
var (
	// errDecode is returned when the input is not well-formed wire-format data
	// and the unmarshal operation as a whole must fail.
	errDecode = errors.New("cannot parse invalid wire-format data")
)
21
// unmarshalOptions is the decoding configuration passed through the internal
// unmarshal functions in this file (unmarshalPointer, unmarshalExtension,
// skipExtension).
type unmarshalOptions struct {
	// flags are the caller-supplied input flags; DiscardUnknown tests the
	// UnmarshalDiscardUnknown bit, and IsDefault requires all flags to be 0.
	flags protoiface.UnmarshalInputFlags
	// resolver looks up extension types. unmarshalExtension calls
	// FindExtensionByNumber for extension fields whose type is not yet known.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
29
30func (o unmarshalOptions) Options() proto.UnmarshalOptions {
31 return proto.UnmarshalOptions{
32 Merge: true,
33 AllowPartial: true,
34 DiscardUnknown: o.DiscardUnknown(),
35 Resolver: o.resolver,
36 }
37}
38
39func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
40
41func (o unmarshalOptions) IsDefault() bool {
42 return o.flags == 0 && o.resolver == preg.GlobalTypes
43}
44
45var lazyUnmarshalOptions = unmarshalOptions{
46 resolver: preg.GlobalTypes,
47}
48
// unmarshalOutput is the result of an internal unmarshal step.
type unmarshalOutput struct {
	n int // number of bytes consumed
	// initialized reports whether the decoded value was found to be fully
	// initialized (for a message: all required fields seen and no nested
	// value reported as uninitialized).
	initialized bool
}
53
54// unmarshal is protoreflect.Methods.Unmarshal.
55func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
56 var p pointer
57 if ms, ok := in.Message.(*messageState); ok {
58 p = ms.pointer()
59 } else {
60 p = in.Message.(*messageReflectWrapper).pointer()
61 }
62 out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
63 flags: in.Flags,
64 resolver: in.Resolver,
65 })
66 var flags piface.UnmarshalOutputFlags
67 if out.initialized {
68 flags |= piface.UnmarshalInitialized
69 }
70 return piface.UnmarshalOutput{
71 Flags: flags,
72 }, err
73}
74
var (
	// errUnknown is an internal sentinel returned while unmarshaling a single
	// field. It signals that the field could not be decoded as expected (for
	// example, its wire type does not match the field's declared type) and
	// should be kept in the unknown-fields section. It does not mean the whole
	// input is invalid. Failures that invalidate the entire operation, such as
	// a field extending past the end of the input, use a different error.
	//
	// This sentinel must never be returned to the user.
	errUnknown = errors.New("unknown")
)
82
83func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
84 mi.init()
85 if flags.ProtoLegacy && mi.isMessageSet {
86 return unmarshalMessageSet(mi, b, p, opts)
87 }
88 initialized := true
89 var requiredMask uint64
90 var exts *map[int32]ExtensionField
91 start := len(b)
92 for len(b) > 0 {
93 // Parse the tag (field number and wire type).
94 var tag uint64
95 if b[0] < 0x80 {
96 tag = uint64(b[0])
97 b = b[1:]
98 } else if len(b) >= 2 && b[1] < 128 {
99 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
100 b = b[2:]
101 } else {
102 var n int
103 tag, n = protowire.ConsumeVarint(b)
104 if n < 0 {
105 return out, errDecode
106 }
107 b = b[n:]
108 }
109 var num protowire.Number
110 if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
111 return out, errDecode
112 } else {
113 num = protowire.Number(n)
114 }
115 wtyp := protowire.Type(tag & 7)
116
117 if wtyp == protowire.EndGroupType {
118 if num != groupTag {
119 return out, errDecode
120 }
121 groupTag = 0
122 break
123 }
124
125 var f *coderFieldInfo
126 if int(num) < len(mi.denseCoderFields) {
127 f = mi.denseCoderFields[num]
128 } else {
129 f = mi.coderFields[num]
130 }
131 var n int
132 err := errUnknown
133 switch {
134 case f != nil:
135 if f.funcs.unmarshal == nil {
136 break
137 }
138 var o unmarshalOutput
139 o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
140 n = o.n
141 if err != nil {
142 break
143 }
144 requiredMask |= f.validation.requiredBit
145 if f.funcs.isInit != nil && !o.initialized {
146 initialized = false
147 }
148 default:
149 // Possible extension.
150 if exts == nil && mi.extensionOffset.IsValid() {
151 exts = p.Apply(mi.extensionOffset).Extensions()
152 if *exts == nil {
153 *exts = make(map[int32]ExtensionField)
154 }
155 }
156 if exts == nil {
157 break
158 }
159 var o unmarshalOutput
160 o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
161 if err != nil {
162 break
163 }
164 n = o.n
165 if !o.initialized {
166 initialized = false
167 }
168 }
169 if err != nil {
170 if err != errUnknown {
171 return out, err
172 }
173 n = protowire.ConsumeFieldValue(num, wtyp, b)
174 if n < 0 {
175 return out, errDecode
176 }
177 if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
178 u := mi.mutableUnknownBytes(p)
179 *u = protowire.AppendTag(*u, num, wtyp)
180 *u = append(*u, b[:n]...)
181 }
182 }
183 b = b[n:]
184 }
185 if groupTag != 0 {
186 return out, errDecode
187 }
188 if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
189 initialized = false
190 }
191 if initialized {
192 out.initialized = true
193 }
194 out.n = start - len(b)
195 return out, nil
196}
197
198func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
199 x := exts[int32(num)]
200 xt := x.Type()
201 if xt == nil {
202 var err error
203 xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
204 if err != nil {
205 if err == preg.NotFound {
206 return out, errUnknown
207 }
208 return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
209 }
210 }
211 xi := getExtensionFieldInfo(xt)
212 if xi.funcs.unmarshal == nil {
213 return out, errUnknown
214 }
215 if flags.LazyUnmarshalExtensions {
216 if opts.IsDefault() && x.canLazy(xt) {
217 out, valid := skipExtension(b, xi, num, wtyp, opts)
218 switch valid {
219 case ValidationValid:
220 if out.initialized {
221 x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
222 exts[int32(num)] = x
223 return out, nil
224 }
225 case ValidationInvalid:
226 return out, errDecode
227 case ValidationUnknown:
228 }
229 }
230 }
231 ival := x.Value()
232 if !ival.IsValid() && xi.unmarshalNeedsValue {
233 // Create a new message, list, or map value to fill in.
234 // For enums, create a prototype value to let the unmarshal func know the
235 // concrete type.
236 ival = xt.New()
237 }
238 v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
239 if err != nil {
240 return out, err
241 }
242 if xi.funcs.isInit == nil {
243 out.initialized = true
244 }
245 x.Set(xt, v)
246 exts[int32(num)] = x
247 return out, nil
248}
249
250func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
251 if xi.validation.mi == nil {
252 return out, ValidationUnknown
253 }
254 xi.validation.mi.init()
255 switch xi.validation.typ {
256 case validationTypeMessage:
257 if wtyp != protowire.BytesType {
258 return out, ValidationUnknown
259 }
260 v, n := protowire.ConsumeBytes(b)
261 if n < 0 {
262 return out, ValidationUnknown
263 }
264 out, st := xi.validation.mi.validate(v, 0, opts)
265 out.n = n
266 return out, st
267 case validationTypeGroup:
268 if wtyp != protowire.StartGroupType {
269 return out, ValidationUnknown
270 }
271 out, st := xi.validation.mi.validate(b, num, opts)
272 return out, st
273 default:
274 return out, ValidationUnknown
275 }
276}