// Copyright 2024 The Go Authors. All rights reserved.
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package impl |
| 6 | |
| 7 | import ( |
| 8 | "fmt" |
| 9 | "math/bits" |
| 10 | "os" |
| 11 | "reflect" |
| 12 | "sort" |
| 13 | "sync/atomic" |
| 14 | |
| 15 | "google.golang.org/protobuf/encoding/protowire" |
| 16 | "google.golang.org/protobuf/internal/errors" |
| 17 | "google.golang.org/protobuf/internal/protolazy" |
| 18 | "google.golang.org/protobuf/reflect/protoreflect" |
| 19 | preg "google.golang.org/protobuf/reflect/protoregistry" |
| 20 | piface "google.golang.org/protobuf/runtime/protoiface" |
| 21 | ) |
| 22 | |
// enableLazy is 1 when lazy unmarshaling is enabled and 0 when it is not.
// It starts enabled unless GOPROTODEBUG=nolazy is set in the environment
// at process start. Accessed atomically; mutate only via EnableLazyUnmarshal.
var enableLazy int32 = func() int32 {
	enabled := int32(1)
	if os.Getenv("GOPROTODEBUG") == "nolazy" {
		enabled = 0
	}
	return enabled
}()
| 29 | |
| 30 | // EnableLazyUnmarshal enables lazy unmarshaling. |
| 31 | func EnableLazyUnmarshal(enable bool) { |
| 32 | if enable { |
| 33 | atomic.StoreInt32(&enableLazy, 1) |
| 34 | return |
| 35 | } |
| 36 | atomic.StoreInt32(&enableLazy, 0) |
| 37 | } |
| 38 | |
| 39 | // LazyEnabled reports whether lazy unmarshalling is currently enabled. |
| 40 | func LazyEnabled() bool { |
| 41 | return atomic.LoadInt32(&enableLazy) != 0 |
| 42 | } |
| 43 | |
| 44 | // UnmarshalField unmarshals a field in a message. |
| 45 | func UnmarshalField(m interface{}, num protowire.Number) { |
| 46 | switch m := m.(type) { |
| 47 | case *messageState: |
| 48 | m.messageInfo().lazyUnmarshal(m.pointer(), num) |
| 49 | case *messageReflectWrapper: |
| 50 | m.messageInfo().lazyUnmarshal(m.pointer(), num) |
| 51 | default: |
| 52 | panic(fmt.Sprintf("unsupported wrapper type %T", m)) |
| 53 | } |
| 54 | } |
| 55 | |
// lazyUnmarshal materializes the lazily-deferred field num of the message at
// p by decoding it from the buffer retained in the message's lazy-unmarshal
// info, then atomically publishing the result if the field is still unset.
//
// It panics if no coder field info exists for num, or if the lazy index has
// no byte range recorded for num — callers only invoke this for fields that
// were previously recorded as lazy.
func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
	var f *coderFieldInfo
	// Look up coder info, preferring the dense slice for small field numbers.
	if int(num) < len(mi.denseCoderFields) {
		f = mi.denseCoderFields[num]
	} else {
		f = mi.coderFields[num]
	}
	if f == nil {
		panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num))
	}
	lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
	start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
	if !found && multipleEntries == nil {
		panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
	}
	// The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races.
	// Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil.
	fp := pointerOfValue(reflect.New(f.ft))
	if multipleEntries != nil {
		// The field's bytes occupy non-contiguous ranges of the buffer;
		// decode each range into the same destination.
		for _, entry := range multipleEntries {
			mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
		}
	} else {
		mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
	}
	// Publish only if still nil: if another goroutine won the race, its
	// fully-built value is kept and ours is discarded.
	p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
}
| 83 | |
// unmarshalField decodes every occurrence of the field f.num found in the
// wire-format bytes b into p, using the standard lazy unmarshal options with
// flags OR'ed on top. Tags for other field numbers in b are skipped over.
// It returns an error on malformed wire data or when decoding the target
// field fails. (lazyInfo is not used by this function.)
func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
	opts := lazyUnmarshalOptions
	opts.flags |= flags
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Inlined fast paths for one- and two-byte varint tags before
		// falling back to the general varint decoder.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return errors.New("invalid wire data")
			}
			b = b[n:]
		}
		// Reject field numbers outside the valid proto range.
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return errors.New("invalid wire data")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if num == f.num {
			// This is the field we want: decode it in place.
			o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
			if err == nil {
				b = b[o.n:]
				continue
			}
			if err != errUnknown {
				return err
			}
			// errUnknown (e.g. unexpected wire type): fall through and
			// skip the value like any other unknown field.
		}
		// Skip over a field we are not decoding.
		n := protowire.ConsumeFieldValue(num, wtyp, b)
		if n < 0 {
			return errors.New("invalid wire data")
		}
		b = b[n:]
	}
	return nil
}
| 129 | |
// skipField validates — without storing — a single message- or group-typed
// field value so that the field can be deferred for lazy unmarshaling.
//
// It resolves the sub-message's *MessageInfo from the cached validation info
// when available, otherwise by looking the message type up in the global
// registry. It returns ValidationUnknown whenever the type cannot be
// resolved (or the validation type is neither message nor group), which
// signals the caller to fall back to eager unmarshaling. out.n reports the
// number of bytes consumed for the bytes-delimited case.
func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	fmi := f.validation.mi
	if fmi == nil {
		// No cached MessageInfo; resolve via descriptor + global registry.
		fd := mi.Desc.Fields().ByNumber(f.num)
		if fd == nil {
			return out, ValidationUnknown
		}
		messageName := fd.Message().FullName()
		messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
		if err != nil {
			return out, ValidationUnknown
		}
		var ok bool
		fmi, ok = messageType.(*MessageInfo)
		if !ok {
			return out, ValidationUnknown
		}
	}
	fmi.init()
	switch f.validation.typ {
	case validationTypeMessage:
		// Length-delimited sub-message: validate the payload bytes.
		if wtyp != protowire.BytesType {
			return out, ValidationWrongWireType
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			return out, ValidationInvalid
		}
		out, st := fmi.validate(v, 0, opts)
		out.n = n
		return out, st
	case validationTypeGroup:
		// Group: validate from here until the matching end-group tag.
		if wtyp != protowire.StartGroupType {
			return out, ValidationWrongWireType
		}
		out, st := fmi.validate(b, f.num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}
| 171 | |
// unmarshalPointerLazy is similar to unmarshalPointerEager, but it
// specifically handles lazy unmarshalling. It expects lazyOffset and
// presenceOffset to both be valid.
//
// When lazy decoding is possible (the message is empty and opts allow it),
// lazy message-typed fields are validated and skipped rather than decoded,
// and an index of their byte ranges within the retained buffer is recorded
// so they can be unmarshaled on first access.
func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	initialized := true
	var requiredMask uint64 // bitmask of required fields seen so far
	var lazy **protolazy.XXX_lazyUnmarshalInfo
	var presence presence
	var lazyIndex []protolazy.IndexEntry // byte ranges of lazy fields within b
	var lastNum protowire.Number
	outOfOrder := false
	lazyDecode := false
	presence = p.Apply(mi.presenceOffset).PresenceInfo()
	lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
	if !presence.AnyPresent(mi.presenceSize) {
		if opts.CanBeLazy() {
			// If the message contains existing data, we need to merge into it.
			// Lazy unmarshaling doesn't merge, so only enable it when the
			// message is empty (has no presence bitmap).
			lazyDecode = true
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			(*lazy).SetUnmarshalFlags(opts.flags)
			if !opts.AliasBuffer() {
				// Make a copy of the buffer for lazy unmarshaling.
				// Set the AliasBuffer flag so recursive unmarshal
				// operations reuse the copy.
				b = append([]byte{}, b...)
				opts.flags |= piface.UnmarshalAliasBuffer
			}
			(*lazy).SetBuffer(b)
		}
	}
	// Track special handling of lazy fields.
	//
	// In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
	// In the event that validation for a field fails, this map tracks handling of the field.
	type lazyAction uint8
	const (
		lazyValidateOnly   lazyAction = iota // validate the field only
		lazyUnmarshalNow                     // eagerly unmarshal the field
		lazyUnmarshalLater                   // unmarshal the field after the message is fully processed
	)
	var lazyFields map[*coderFieldInfo]lazyAction
	var exts *map[int32]ExtensionField
	start := len(b)
	pos := 0
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Inlined fast paths for one- and two-byte varint tags before
		// falling back to the general varint decoder.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			groupTag = 0
			break
		}

		// Look up coder info, preferring the dense slice for small numbers.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		err := errUnknown
		discardUnknown := false
	Field:
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			if f.isLazy && lazyDecode {
				switch {
				case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
					// Attempt to validate this field and leave it for later lazy unmarshaling.
					o, valid := mi.skipField(b, f, wtyp, opts)
					switch valid {
					case ValidationValid:
						// Skip over the valid field and continue.
						err = nil
						presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
						requiredMask |= f.validation.requiredBit
						if !o.initialized {
							initialized = false
						}
						n = o.n
						break Field
					case ValidationInvalid:
						return out, errors.New("invalid proto wire format")
					case ValidationWrongWireType:
						break Field
					case ValidationUnknown:
						if lazyFields == nil {
							lazyFields = make(map[*coderFieldInfo]lazyAction)
						}
						if presence.Present(f.presenceIndex) {
							// We were unable to determine if the field is valid or not,
							// and we've already skipped over at least one instance of this
							// field. Clear the presence bit (so if we stop decoding early,
							// we don't leave a partially-initialized field around) and flag
							// the field for unmarshaling before we return.
							presence.ClearPresent(f.presenceIndex)
							lazyFields[f] = lazyUnmarshalLater
							discardUnknown = true
							break Field
						} else {
							// We were unable to determine if the field is valid or not,
							// but this is the first time we've seen it. Flag it as needing
							// eager unmarshaling and fall through to the eager unmarshal case below.
							lazyFields[f] = lazyUnmarshalNow
						}
					}
				case lazyFields[f] == lazyUnmarshalLater:
					// This field will be unmarshaled in a separate pass below.
					// Skip over it here.
					discardUnknown = true
					break Field
				default:
					// Eagerly unmarshal the field.
				}
			}
			if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
				// Not lazy-decoding, and the field is marked present but its
				// pointer is still nil: force the deferred unmarshal before
				// merging new data into it.
				if p.Apply(f.offset).AtomicGetPointer().IsNil() {
					mi.lazyUnmarshal(p, f.num)
				}
			}
			// Eager decode of the field value.
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Unknown field: skip its value and, unless suppressed, retain
			// the raw bytes in the message's unknown-fields section.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
		end := start - len(b)
		if lazyDecode && f != nil && f.isLazy {
			// Record (or extend) the byte range of this lazy field in the
			// index; contiguous repeats of the same number share one entry.
			if num != lastNum {
				lazyIndex = append(lazyIndex, protolazy.IndexEntry{
					FieldNum: uint32(num),
					Start:    uint32(pos),
					End:      uint32(end),
				})
			} else {
				i := len(lazyIndex) - 1
				lazyIndex[i].End = uint32(end)
				lazyIndex[i].MultipleContiguous = true
			}
		}
		if num < lastNum {
			outOfOrder = true
		}
		pos = end
		lastNum = num
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	if lazyFields != nil {
		// Some fields failed validation, and now need to be unmarshaled.
		for f, action := range lazyFields {
			if action != lazyUnmarshalLater {
				continue
			}
			initialized = false
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
				return out, err
			}
			presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
		}
	}
	if lazyDecode {
		if outOfOrder {
			// Keep the index sorted by field number (then start offset) so
			// lookups during lazy unmarshal can rely on the ordering.
			sort.Slice(lazyIndex, func(i, j int) bool {
				return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
					(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
						lazyIndex[i].Start < lazyIndex[j].Start)
			})
		}
		if *lazy == nil {
			*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
		}

		(*lazy).SetIndex(lazyIndex)
	}
	// The message is uninitialized unless every required field was seen.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}