blob: c7de31e243eb1dfac7a13504c2a334468719a7e1 [file] [log] [blame]
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +05301// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math/bits"
10 "os"
11 "reflect"
12 "sort"
13 "sync/atomic"
14
15 "google.golang.org/protobuf/encoding/protowire"
16 "google.golang.org/protobuf/internal/errors"
17 "google.golang.org/protobuf/internal/protolazy"
18 "google.golang.org/protobuf/reflect/protoreflect"
19 preg "google.golang.org/protobuf/reflect/protoregistry"
20 piface "google.golang.org/protobuf/runtime/protoiface"
21)
22
// enableLazy is 1 while lazy unmarshaling is enabled and 0 otherwise; read
// and write it with sync/atomic. Lazy unmarshaling starts out enabled unless
// the GOPROTODEBUG environment variable is exactly "nolazy" when the package
// is initialized.
var enableLazy = func() int32 {
	switch os.Getenv("GOPROTODEBUG") {
	case "nolazy":
		return 0
	default:
		return 1
	}
}()

// EnableLazyUnmarshal turns lazy unmarshaling on (enable is true) or off.
// The flag is stored atomically, so it is safe to call concurrently with
// LazyEnabled.
func EnableLazyUnmarshal(enable bool) {
	var v int32
	if enable {
		v = 1
	}
	atomic.StoreInt32(&enableLazy, v)
}

// LazyEnabled reports whether lazy unmarshalling is currently enabled.
func LazyEnabled() bool {
	v := atomic.LoadInt32(&enableLazy)
	return v != 0
}
43
44// UnmarshalField unmarshals a field in a message.
45func UnmarshalField(m interface{}, num protowire.Number) {
46 switch m := m.(type) {
47 case *messageState:
48 m.messageInfo().lazyUnmarshal(m.pointer(), num)
49 case *messageReflectWrapper:
50 m.messageInfo().lazyUnmarshal(m.pointer(), num)
51 default:
52 panic(fmt.Sprintf("unsupported wrapper type %T", m))
53 }
54}
55
56func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
57 var f *coderFieldInfo
58 if int(num) < len(mi.denseCoderFields) {
59 f = mi.denseCoderFields[num]
60 } else {
61 f = mi.coderFields[num]
62 }
63 if f == nil {
64 panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num))
65 }
66 lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
67 start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
68 if !found && multipleEntries == nil {
69 panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
70 }
71 // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races.
72 // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil.
73 fp := pointerOfValue(reflect.New(f.ft))
74 if multipleEntries != nil {
75 for _, entry := range multipleEntries {
76 mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
77 }
78 } else {
79 mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
80 }
81 p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
82}
83
84func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
85 opts := lazyUnmarshalOptions
86 opts.flags |= flags
87 for len(b) > 0 {
88 // Parse the tag (field number and wire type).
89 var tag uint64
90 if b[0] < 0x80 {
91 tag = uint64(b[0])
92 b = b[1:]
93 } else if len(b) >= 2 && b[1] < 128 {
94 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
95 b = b[2:]
96 } else {
97 var n int
98 tag, n = protowire.ConsumeVarint(b)
99 if n < 0 {
100 return errors.New("invalid wire data")
101 }
102 b = b[n:]
103 }
104 var num protowire.Number
105 if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
106 return errors.New("invalid wire data")
107 } else {
108 num = protowire.Number(n)
109 }
110 wtyp := protowire.Type(tag & 7)
111 if num == f.num {
112 o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
113 if err == nil {
114 b = b[o.n:]
115 continue
116 }
117 if err != errUnknown {
118 return err
119 }
120 }
121 n := protowire.ConsumeFieldValue(num, wtyp, b)
122 if n < 0 {
123 return errors.New("invalid wire data")
124 }
125 b = b[n:]
126 }
127 return nil
128}
129
130func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
131 fmi := f.validation.mi
132 if fmi == nil {
133 fd := mi.Desc.Fields().ByNumber(f.num)
134 if fd == nil {
135 return out, ValidationUnknown
136 }
137 messageName := fd.Message().FullName()
138 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
139 if err != nil {
140 return out, ValidationUnknown
141 }
142 var ok bool
143 fmi, ok = messageType.(*MessageInfo)
144 if !ok {
145 return out, ValidationUnknown
146 }
147 }
148 fmi.init()
149 switch f.validation.typ {
150 case validationTypeMessage:
151 if wtyp != protowire.BytesType {
152 return out, ValidationWrongWireType
153 }
154 v, n := protowire.ConsumeBytes(b)
155 if n < 0 {
156 return out, ValidationInvalid
157 }
158 out, st := fmi.validate(v, 0, opts)
159 out.n = n
160 return out, st
161 case validationTypeGroup:
162 if wtyp != protowire.StartGroupType {
163 return out, ValidationWrongWireType
164 }
165 out, st := fmi.validate(b, f.num, opts)
166 return out, st
167 default:
168 return out, ValidationUnknown
169 }
170}
171
// unmarshalPointerLazy is similar to unmarshalPointerEager, but it
// specifically handles lazy unmarshaling. It expects lazyOffset and
// presenceOffset to both be valid.
//
// b holds the encoded message and p points at the destination message. A
// non-zero groupTag means the message is a group: decoding stops after the
// end-group marker whose field number equals groupTag, and a mismatched or
// missing marker is an error.
//
// Lazy decoding (lazyDecode) is used only when no presence bit is set on the
// message yet and opts.CanBeLazy() is true. In that mode:
//   - b (copied first unless opts.AliasBuffer() is set) and opts.flags are
//     stored in the message's lazy-unmarshal info;
//   - each lazy field is checked with skipField and, if valid, only marked
//     present; the byte range of every lazy field occurrence is recorded in an
//     index that is stored with SetIndex after the loop;
//   - a lazy field whose validity cannot be determined (ValidationUnknown) is
//     decoded eagerly, or, if an earlier occurrence was already skipped, in a
//     second pass over the stored buffer after the loop.
//
// On success out.n is the number of bytes consumed, including the end-group
// tag for groups. out.initialized is false if required fields are missing, if
// a decoded field or extension reports itself uninitialized, or if any field
// needed the second pass.
func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	initialized := true
	var requiredMask uint64
	var lazy **protolazy.XXX_lazyUnmarshalInfo
	var presence presence
	var lazyIndex []protolazy.IndexEntry
	var lastNum protowire.Number
	outOfOrder := false
	lazyDecode := false
	presence = p.Apply(mi.presenceOffset).PresenceInfo()
	lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
	if !presence.AnyPresent(mi.presenceSize) {
		if opts.CanBeLazy() {
			// If the message contains existing data, we need to merge into it.
			// Lazy unmarshaling doesn't merge, so only enable it when the
			// message is empty (has no presence bitmap).
			lazyDecode = true
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			(*lazy).SetUnmarshalFlags(opts.flags)
			if !opts.AliasBuffer() {
				// Make a copy of the buffer for lazy unmarshaling.
				// Set the AliasBuffer flag so recursive unmarshal
				// operations reuse the copy.
				b = append([]byte{}, b...)
				opts.flags |= piface.UnmarshalAliasBuffer
			}
			// NOTE(review): for a group, b also contains the bytes that follow
			// the end-group marker. Index ranges are bounded, but the second
			// pass below scans the whole buffer and would reach that marker —
			// confirm whether lazy decoding can run for groups.
			(*lazy).SetBuffer(b)
		}
	}
	// Track special handling of lazy fields.
	//
	// In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
	// In the event that validation for a field fails, this map tracks handling of the field.
	type lazyAction uint8
	const (
		lazyValidateOnly   lazyAction = iota // validate the field only
		lazyUnmarshalNow                     // eagerly unmarshal the field
		lazyUnmarshalLater                   // unmarshal the field after the message is fully processed
	)
	var lazyFields map[*coderFieldInfo]lazyAction
	var exts *map[int32]ExtensionField
	// start is the length of the (possibly copied) input, so start-len(b) is
	// the offset consumed so far. pos is the offset at which the current
	// field's tag begins.
	start := len(b)
	pos := 0
	for len(b) > 0 {
		// Parse the tag (field number and wire type), with inline fast paths
		// for one- and two-byte varints.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		// The end-group marker terminates this group. Clearing groupTag lets the
		// "missing end group marker" check after the loop pass.
		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			groupTag = 0
			break
		}

		// Look up the field's coder info: denseCoderFields for numbers within
		// its length, coderFields otherwise. A nil result means the number is
		// not a known field (possibly an extension).
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// err starts as errUnknown; any path that leaves it unchanged causes the
		// field to be skipped (and possibly kept as unknown bytes) below.
		err := errUnknown
		// discardUnknown suppresses keeping skipped bytes as unknown fields. It
		// is set when a lazy field's bytes will be decoded in the second pass.
		discardUnknown := false
	Field:
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			if f.isLazy && lazyDecode {
				switch {
				case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
					// Attempt to validate this field and leave it for later lazy unmarshaling.
					o, valid := mi.skipField(b, f, wtyp, opts)
					switch valid {
					case ValidationValid:
						// Skip over the valid field and continue.
						err = nil
						presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
						requiredMask |= f.validation.requiredBit
						if !o.initialized {
							initialized = false
						}
						n = o.n
						break Field
					case ValidationInvalid:
						return out, errors.New("invalid proto wire format")
					case ValidationWrongWireType:
						// err is still errUnknown, so the bytes are handled as an
						// unknown field below.
						break Field
					case ValidationUnknown:
						if lazyFields == nil {
							lazyFields = make(map[*coderFieldInfo]lazyAction)
						}
						if presence.Present(f.presenceIndex) {
							// We were unable to determine if the field is valid or not,
							// and we've already skipped over at least one instance of this
							// field. Clear the presence bit (so if we stop decoding early,
							// we don't leave a partially-initialized field around) and flag
							// the field for unmarshaling before we return.
							presence.ClearPresent(f.presenceIndex)
							lazyFields[f] = lazyUnmarshalLater
							discardUnknown = true
							break Field
						} else {
							// We were unable to determine if the field is valid or not,
							// but this is the first time we've seen it. Flag it as needing
							// eager unmarshaling and fall through to the eager unmarshal case below.
							lazyFields[f] = lazyUnmarshalNow
						}
					}
				case lazyFields[f] == lazyUnmarshalLater:
					// This field will be unmarshaled in a separate pass below.
					// Skip over it here.
					discardUnknown = true
					break Field
				default:
					// Eagerly unmarshal the field.
				}
			}
			// When not lazy decoding (merging into a message that already has
			// data), a lazy field may be marked present while its value is still
			// undecoded. Decode the stored bytes first so this occurrence is
			// applied on top of them rather than to an empty field.
			if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
				if p.Apply(f.offset).AtomicGetPointer().IsNil() {
					mi.lazyUnmarshal(p, f.num)
				}
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		// Any field not consumed above (err == errUnknown) is skipped. Its raw
		// bytes, tag included, are appended to the unknown fields unless
		// discarding was requested or the message has no unknown-field storage.
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
		end := start - len(b)
		// Record the byte range [pos, end) of each lazy field occurrence, tag
		// included, relative to the start of the buffer. Consecutive
		// occurrences of the same field number extend the previous entry and
		// mark it MultipleContiguous.
		if lazyDecode && f != nil && f.isLazy {
			if num != lastNum {
				lazyIndex = append(lazyIndex, protolazy.IndexEntry{
					FieldNum: uint32(num),
					Start:    uint32(pos),
					End:      uint32(end),
				})
			} else {
				i := len(lazyIndex) - 1
				lazyIndex[i].End = uint32(end)
				lazyIndex[i].MultipleContiguous = true
			}
		}
		// Note any decrease in field number so the index is sorted below.
		if num < lastNum {
			outOfOrder = true
		}
		pos = end
		lastNum = num
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	if lazyFields != nil {
		// Some fields failed validation, and now need to be unmarshaled.
		// Fields marked lazyUnmarshalLater had occurrences skipped above, so
		// every occurrence is decoded again from the stored buffer.
		for f, action := range lazyFields {
			if action != lazyUnmarshalLater {
				continue
			}
			initialized = false
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
				return out, err
			}
			presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
		}
	}
	if lazyDecode {
		// Order the index by field number, then by start offset, when fields
		// arrived out of order. NOTE(review): presumably FindFieldInProto
		// relies on this ordering — confirm.
		if outOfOrder {
			sort.Slice(lazyIndex, func(i, j int) bool {
				return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
					(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
						lazyIndex[i].Start < lazyIndex[j].Start)
			})
		}
		if *lazy == nil {
			*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
		}

		(*lazy).SetIndex(lazyIndex)
	}
	// Each decoded or validated required field contributes its requiredBit. If
	// fewer bits are set than the message's required-field count, a required
	// field is missing.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}