// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
| 4 | |
| 5 | // Package protolazy contains internal data structures for lazy message decoding. |
| 6 | package protolazy |
| 7 | |
| 8 | import ( |
| 9 | "fmt" |
| 10 | "sort" |
| 11 | |
| 12 | "google.golang.org/protobuf/encoding/protowire" |
| 13 | piface "google.golang.org/protobuf/runtime/protoiface" |
| 14 | ) |
| 15 | |
// IndexEntry is the structure for an index of the fields in a message of a
// proto (not descending to sub-messages)
type IndexEntry struct {
	// Field number of the field this entry covers (the wire tag shifted
	// right by 3 to drop the wire-type bits).
	FieldNum uint32
	// first byte of this tag/field
	Start uint32
	// first byte after a contiguous sequence of bytes for this tag/field, which could
	// include a single encoding of the field, or multiple encodings for the field
	End uint32
	// True if this protobuf segment includes multiple encodings of the field
	MultipleContiguous bool
}
| 28 | |
// XXX_lazyUnmarshalInfo has information about a particular lazily decoded message
//
// Deprecated: Do not use. This will be deleted in the near future.
type XXX_lazyUnmarshalInfo struct {
	// Index of fields and their positions in the protobuf for this
	// message. Make index be a pointer to a slice so it can be updated
	// atomically. The index pointer is only set once (lazily when/if
	// the index is first needed), and must always be SET and LOADED
	// ATOMICALLY.
	index *[]IndexEntry
	// The protobuf associated with this lazily decoded message. It is
	// only set during proto.Unmarshal(). It doesn't need to be set and
	// loaded atomically, since any simultaneous set (Unmarshal) and read
	// (during a get) would already be a race in the app code.
	// A nil Protobuf means there is no backing buffer for this message
	// (see FindFieldInProto).
	Protobuf []byte
	// The flags present when Unmarshal was originally called for this particular message
	unmarshalFlags piface.UnmarshalInputFlags
}
| 47 | |
| 48 | // The Buffer and SetBuffer methods let v2/internal/impl interact with |
| 49 | // XXX_lazyUnmarshalInfo via an interface, to avoid an import cycle. |
| 50 | |
// Buffer returns the lazy unmarshal buffer.
// A nil result means no backing protobuf was recorded for this message.
//
// Deprecated: Do not use. This will be deleted in the near future.
func (lazy *XXX_lazyUnmarshalInfo) Buffer() []byte {
	return lazy.Protobuf
}
| 57 | |
// SetBuffer sets the lazy unmarshal buffer.
// The buffer is stored non-atomically; see the comment on the Protobuf
// field for why that is safe.
//
// Deprecated: Do not use. This will be deleted in the near future.
func (lazy *XXX_lazyUnmarshalInfo) SetBuffer(b []byte) {
	lazy.Protobuf = b
}
| 64 | |
// SetUnmarshalFlags is called to set a copy of the original unmarshalInputFlags.
// The flags should reflect how Unmarshal was called.
func (lazy *XXX_lazyUnmarshalInfo) SetUnmarshalFlags(f piface.UnmarshalInputFlags) {
	lazy.unmarshalFlags = f
}
| 70 | |
// UnmarshalFlags returns the original unmarshalInputFlags recorded by
// SetUnmarshalFlags.
func (lazy *XXX_lazyUnmarshalInfo) UnmarshalFlags() piface.UnmarshalInputFlags {
	return lazy.unmarshalFlags
}
| 75 | |
// AllowedPartial returns true if the user originally unmarshalled this message with
// AllowPartial set to true
func (lazy *XXX_lazyUnmarshalInfo) AllowedPartial() bool {
	// UnmarshalCheckRequired is set when required-field checking was
	// requested, i.e. when AllowPartial was false; its absence therefore
	// means a partial message was acceptable.
	return (lazy.unmarshalFlags & piface.UnmarshalCheckRequired) == 0
}
| 81 | |
// protoFieldNumber extracts the field number from a wire-format tag by
// discarding the three low-order bits that carry the wire type.
func protoFieldNumber(tag uint32) uint32 {
	return tag / 8
}
| 85 | |
// buildIndex builds an index of the specified protobuf, return the index
// array and an error.
//
// The index has one entry per run of contiguous encodings of the same
// field number; a run covering more than one encoding is marked
// MultipleContiguous. If field numbers appear out of order in the buffer,
// the finished index is sorted by (FieldNum, Start) so lookupField can
// rely on ordering.
func buildIndex(buf []byte) ([]IndexEntry, error) {
	index := make([]IndexEntry, 0, 16)
	var lastProtoFieldNum uint32
	var outOfOrder bool

	var r BufferReader = NewBufferReader(buf)

	for !r.Done() {
		var tag uint32
		var err error
		var curPos = r.Pos
		// INLINED: tag, err = r.DecodeVarint32()
		{
			i := r.Pos
			buf := r.Buf

			if i >= len(buf) {
				return nil, errOutOfBounds
			} else if buf[i] < 0x80 {
				// Fast path: single-byte tag.
				r.Pos++
				tag = uint32(buf[i])
			} else if r.Remaining() < 5 {
				// Near the end of the buffer: fall back to the
				// bounds-checked slow decoder.
				// NOTE(review): err set here is only examined at
				// out2, after the skip switch below, which may
				// reassign err — confirm a failed tag decode
				// cannot be masked by a succeeding skip.
				var v uint64
				v, err = r.DecodeVarintSlow()
				tag = uint32(v)
			} else {
				// Unrolled multi-byte decode: at least 5 bytes
				// remain, so the byte reads below need no bounds
				// checks. A 32-bit tag uses at most 5 varint bytes.
				var v uint32
				// we already checked the first byte
				tag = uint32(buf[i]) & 127
				i++

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 7
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 14
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 21
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 28
				if v < 128 {
					goto done
				}

				// Continuation bit still set after 5 bytes:
				// not a valid 32-bit tag.
				return nil, errOutOfBounds

			done:
				r.Pos = i
			}
		}
		// DONE: tag, err = r.DecodeVarint32()

		fieldNum := protoFieldNumber(tag)
		if fieldNum < lastProtoFieldNum {
			// Remember that a sort is needed before returning.
			outOfOrder = true
		}

		// Skip the current value -- will skip over an entire group as well.
		// INLINED: err = r.SkipValue(tag)
		wireType := tag & 0x7
		switch protowire.Type(wireType) {
		case protowire.VarintType:
			// INLINED: err = r.SkipVarint()
			i := r.Pos

			if len(r.Buf)-i < 10 {
				// Use DecodeVarintSlow() to skip while
				// checking for buffer overflow, but ignore result
				_, err = r.DecodeVarintSlow()
				goto out2
			}
			// Unrolled skip: a 64-bit varint is at most 10 bytes,
			// and we just verified 10 bytes remain.
			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			// Continuation bit set on all 10 bytes: malformed varint.
			return nil, errOverflow
		out:
			r.Pos = i + 1
			// DONE: err = r.SkipVarint()
		case protowire.Fixed64Type:
			err = r.SkipFixed64()
		case protowire.BytesType:
			// Length-delimited: read the length, then skip that many bytes.
			var n uint32
			n, err = r.DecodeVarint32()
			if err == nil {
				err = r.Skip(int(n))
			}
		case protowire.StartGroupType:
			err = r.SkipGroup(tag)
		case protowire.Fixed32Type:
			err = r.SkipFixed32()
		default:
			err = fmt.Errorf("Unexpected wire type (%d)", wireType)
		}
		// DONE: err = r.SkipValue(tag)

	out2:
		if err != nil {
			return nil, err
		}
		if fieldNum != lastProtoFieldNum {
			// New field number: open a fresh index entry spanning
			// from the first tag byte to the current position.
			index = append(index, IndexEntry{FieldNum: fieldNum,
				Start: uint32(curPos),
				End:   uint32(r.Pos)},
			)
		} else {
			// Same field as the previous entry: extend its range and
			// mark it as holding multiple contiguous encodings.
			// NOTE(review): a malformed first tag with field number 0
			// would match lastProtoFieldNum's zero value and index an
			// empty slice here — confirm field number 0 is rejected
			// upstream.
			index[len(index)-1].End = uint32(r.Pos)
			index[len(index)-1].MultipleContiguous = true
		}
		lastProtoFieldNum = fieldNum
	}
	if outOfOrder {
		// Restore (FieldNum, Start) ordering so binary-search-style
		// early exit in lookupField remains valid.
		sort.Slice(index, func(i, j int) bool {
			return index[i].FieldNum < index[j].FieldNum ||
				(index[i].FieldNum == index[j].FieldNum &&
					index[i].Start < index[j].Start)
		})
	}
	return index, nil
}
| 267 | |
| 268 | func (lazy *XXX_lazyUnmarshalInfo) SizeField(num uint32) (size int) { |
| 269 | start, end, found, _, multipleEntries := lazy.FindFieldInProto(num) |
| 270 | if multipleEntries != nil { |
| 271 | for _, entry := range multipleEntries { |
| 272 | size += int(entry.End - entry.Start) |
| 273 | } |
| 274 | return size |
| 275 | } |
| 276 | if !found { |
| 277 | return 0 |
| 278 | } |
| 279 | return int(end - start) |
| 280 | } |
| 281 | |
| 282 | func (lazy *XXX_lazyUnmarshalInfo) AppendField(b []byte, num uint32) ([]byte, bool) { |
| 283 | start, end, found, _, multipleEntries := lazy.FindFieldInProto(num) |
| 284 | if multipleEntries != nil { |
| 285 | for _, entry := range multipleEntries { |
| 286 | b = append(b, lazy.Protobuf[entry.Start:entry.End]...) |
| 287 | } |
| 288 | return b, true |
| 289 | } |
| 290 | if !found { |
| 291 | return nil, false |
| 292 | } |
| 293 | b = append(b, lazy.Protobuf[start:end]...) |
| 294 | return b, true |
| 295 | } |
| 296 | |
// SetIndex installs a prebuilt field index for this message. The store is
// atomic because readers (FindFieldInProto) load the index pointer without
// any lock; see the comment on the index field.
func (lazy *XXX_lazyUnmarshalInfo) SetIndex(index []IndexEntry) {
	atomicStoreIndex(&lazy.index, &index)
}
| 300 | |
| 301 | // FindFieldInProto looks for field fieldNum in lazyUnmarshalInfo information |
| 302 | // (including protobuf), returns startOffset/endOffset/found. |
| 303 | func (lazy *XXX_lazyUnmarshalInfo) FindFieldInProto(fieldNum uint32) (start, end uint32, found, multipleContiguous bool, multipleEntries []IndexEntry) { |
| 304 | if lazy.Protobuf == nil { |
| 305 | // There is no backing protobuf for this message -- it was made from a builder |
| 306 | return 0, 0, false, false, nil |
| 307 | } |
| 308 | index := atomicLoadIndex(&lazy.index) |
| 309 | if index == nil { |
| 310 | r, err := buildIndex(lazy.Protobuf) |
| 311 | if err != nil { |
| 312 | panic(fmt.Sprintf("findFieldInfo: error building index when looking for field %d: %v", fieldNum, err)) |
| 313 | } |
| 314 | // lazy.index is a pointer to the slice returned by BuildIndex |
| 315 | index = &r |
| 316 | atomicStoreIndex(&lazy.index, index) |
| 317 | } |
| 318 | return lookupField(index, fieldNum) |
| 319 | } |
| 320 | |
| 321 | // lookupField returns the offset at which the indicated field starts using |
| 322 | // the index, offset immediately after field ends (including all instances of |
| 323 | // a repeated field), and bools indicating if field was found and if there |
| 324 | // are multiple encodings of the field in the byte range. |
| 325 | // |
| 326 | // To hande the uncommon case where there are repeated encodings for the same |
| 327 | // field which are not consecutive in the protobuf (so we need to returns |
| 328 | // multiple start/end offsets), we also return a slice multipleEntries. If |
| 329 | // multipleEntries is non-nil, then multiple entries were found, and the |
| 330 | // values in the slice should be used, rather than start/end/found. |
| 331 | func lookupField(indexp *[]IndexEntry, fieldNum uint32) (start, end uint32, found bool, multipleContiguous bool, multipleEntries []IndexEntry) { |
| 332 | // The pointer indexp to the index was already loaded atomically. |
| 333 | // The slice is uniquely associated with the pointer, so it doesn't |
| 334 | // need to be loaded atomically. |
| 335 | index := *indexp |
| 336 | for i, entry := range index { |
| 337 | if fieldNum == entry.FieldNum { |
| 338 | if i < len(index)-1 && entry.FieldNum == index[i+1].FieldNum { |
| 339 | // Handle the uncommon case where there are |
| 340 | // repeated entries for the same field which |
| 341 | // are not contiguous in the protobuf. |
| 342 | multiple := make([]IndexEntry, 1, 2) |
| 343 | multiple[0] = IndexEntry{fieldNum, entry.Start, entry.End, entry.MultipleContiguous} |
| 344 | i++ |
| 345 | for i < len(index) && index[i].FieldNum == fieldNum { |
| 346 | multiple = append(multiple, IndexEntry{fieldNum, index[i].Start, index[i].End, index[i].MultipleContiguous}) |
| 347 | i++ |
| 348 | } |
| 349 | return 0, 0, false, false, multiple |
| 350 | |
| 351 | } |
| 352 | return entry.Start, entry.End, true, entry.MultipleContiguous, nil |
| 353 | } |
| 354 | if fieldNum < entry.FieldNum { |
| 355 | return 0, 0, false, false, nil |
| 356 | } |
| 357 | } |
| 358 | return 0, 0, false, false, nil |
| 359 | } |