// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protolazy contains internal data structures for lazy message decoding.
6package protolazy
7
8import (
9 "fmt"
10 "sort"
11
12 "google.golang.org/protobuf/encoding/protowire"
13 piface "google.golang.org/protobuf/runtime/protoiface"
14)
15
// IndexEntry describes one contiguous byte range of a message's wire
// encoding that holds one or more encodings of a single field. The index
// covers only the fields of the message itself; it does not descend into
// sub-messages.
type IndexEntry struct {
	// FieldNum is the field number this entry refers to.
	FieldNum uint32
	// Start is the offset of the first byte of this tag/field.
	Start uint32
	// End is the offset of the first byte after a contiguous sequence of
	// bytes for this tag/field, which could include a single encoding of
	// the field, or multiple encodings for the field.
	End uint32
	// MultipleContiguous is true if this protobuf segment includes
	// multiple encodings of the field.
	MultipleContiguous bool
}
28
// XXX_lazyUnmarshalInfo has information about a particular lazily decoded
// message: the undecoded bytes, a lazily built index over them, and the
// flags the message was unmarshalled with.
//
// Deprecated: Do not use. This will be deleted in the near future.
type XXX_lazyUnmarshalInfo struct {
	// index is the index of fields and their positions in the protobuf
	// for this message. It is a pointer to a slice so that it can be
	// updated atomically. The pointer is populated lazily, when/if the
	// index is first needed (or explicitly through SetIndex), and must
	// always be SET and LOADED ATOMICALLY.
	index *[]IndexEntry
	// Protobuf is the protobuf associated with this lazily decoded
	// message. It is only set during proto.Unmarshal(). It doesn't need to
	// be set and loaded atomically, since any simultaneous set (Unmarshal)
	// and read (during a get) would already be a race in the app code.
	Protobuf []byte
	// unmarshalFlags holds the flags present when Unmarshal was originally
	// called for this particular message.
	unmarshalFlags piface.UnmarshalInputFlags
}
47
48// The Buffer and SetBuffer methods let v2/internal/impl interact with
49// XXX_lazyUnmarshalInfo via an interface, to avoid an import cycle.
50
51// Buffer returns the lazy unmarshal buffer.
52//
53// Deprecated: Do not use. This will be deleted in the near future.
54func (lazy *XXX_lazyUnmarshalInfo) Buffer() []byte {
55 return lazy.Protobuf
56}
57
58// SetBuffer sets the lazy unmarshal buffer.
59//
60// Deprecated: Do not use. This will be deleted in the near future.
61func (lazy *XXX_lazyUnmarshalInfo) SetBuffer(b []byte) {
62 lazy.Protobuf = b
63}
64
65// SetUnmarshalFlags is called to set a copy of the original unmarshalInputFlags.
66// The flags should reflect how Unmarshal was called.
67func (lazy *XXX_lazyUnmarshalInfo) SetUnmarshalFlags(f piface.UnmarshalInputFlags) {
68 lazy.unmarshalFlags = f
69}
70
71// UnmarshalFlags returns the original unmarshalInputFlags.
72func (lazy *XXX_lazyUnmarshalInfo) UnmarshalFlags() piface.UnmarshalInputFlags {
73 return lazy.unmarshalFlags
74}
75
76// AllowedPartial returns true if the user originally unmarshalled this message with
77// AllowPartial set to true
78func (lazy *XXX_lazyUnmarshalInfo) AllowedPartial() bool {
79 return (lazy.unmarshalFlags & piface.UnmarshalCheckRequired) == 0
80}
81
// protoFieldNumber extracts the field number from a decoded tag by
// discarding the low three bits, which hold the wire type.
func protoFieldNumber(tag uint32) uint32 {
	const wireTypeBits = 3
	return tag >> wireTypeBits
}
85
86// buildIndex builds an index of the specified protobuf, return the index
87// array and an error.
88func buildIndex(buf []byte) ([]IndexEntry, error) {
89 index := make([]IndexEntry, 0, 16)
90 var lastProtoFieldNum uint32
91 var outOfOrder bool
92
93 var r BufferReader = NewBufferReader(buf)
94
95 for !r.Done() {
96 var tag uint32
97 var err error
98 var curPos = r.Pos
99 // INLINED: tag, err = r.DecodeVarint32()
100 {
101 i := r.Pos
102 buf := r.Buf
103
104 if i >= len(buf) {
105 return nil, errOutOfBounds
106 } else if buf[i] < 0x80 {
107 r.Pos++
108 tag = uint32(buf[i])
109 } else if r.Remaining() < 5 {
110 var v uint64
111 v, err = r.DecodeVarintSlow()
112 tag = uint32(v)
113 } else {
114 var v uint32
115 // we already checked the first byte
116 tag = uint32(buf[i]) & 127
117 i++
118
119 v = uint32(buf[i])
120 i++
121 tag |= (v & 127) << 7
122 if v < 128 {
123 goto done
124 }
125
126 v = uint32(buf[i])
127 i++
128 tag |= (v & 127) << 14
129 if v < 128 {
130 goto done
131 }
132
133 v = uint32(buf[i])
134 i++
135 tag |= (v & 127) << 21
136 if v < 128 {
137 goto done
138 }
139
140 v = uint32(buf[i])
141 i++
142 tag |= (v & 127) << 28
143 if v < 128 {
144 goto done
145 }
146
147 return nil, errOutOfBounds
148
149 done:
150 r.Pos = i
151 }
152 }
153 // DONE: tag, err = r.DecodeVarint32()
154
155 fieldNum := protoFieldNumber(tag)
156 if fieldNum < lastProtoFieldNum {
157 outOfOrder = true
158 }
159
160 // Skip the current value -- will skip over an entire group as well.
161 // INLINED: err = r.SkipValue(tag)
162 wireType := tag & 0x7
163 switch protowire.Type(wireType) {
164 case protowire.VarintType:
165 // INLINED: err = r.SkipVarint()
166 i := r.Pos
167
168 if len(r.Buf)-i < 10 {
169 // Use DecodeVarintSlow() to skip while
170 // checking for buffer overflow, but ignore result
171 _, err = r.DecodeVarintSlow()
172 goto out2
173 }
174 if r.Buf[i] < 0x80 {
175 goto out
176 }
177 i++
178
179 if r.Buf[i] < 0x80 {
180 goto out
181 }
182 i++
183
184 if r.Buf[i] < 0x80 {
185 goto out
186 }
187 i++
188
189 if r.Buf[i] < 0x80 {
190 goto out
191 }
192 i++
193
194 if r.Buf[i] < 0x80 {
195 goto out
196 }
197 i++
198
199 if r.Buf[i] < 0x80 {
200 goto out
201 }
202 i++
203
204 if r.Buf[i] < 0x80 {
205 goto out
206 }
207 i++
208
209 if r.Buf[i] < 0x80 {
210 goto out
211 }
212 i++
213
214 if r.Buf[i] < 0x80 {
215 goto out
216 }
217 i++
218
219 if r.Buf[i] < 0x80 {
220 goto out
221 }
222 return nil, errOverflow
223 out:
224 r.Pos = i + 1
225 // DONE: err = r.SkipVarint()
226 case protowire.Fixed64Type:
227 err = r.SkipFixed64()
228 case protowire.BytesType:
229 var n uint32
230 n, err = r.DecodeVarint32()
231 if err == nil {
232 err = r.Skip(int(n))
233 }
234 case protowire.StartGroupType:
235 err = r.SkipGroup(tag)
236 case protowire.Fixed32Type:
237 err = r.SkipFixed32()
238 default:
239 err = fmt.Errorf("Unexpected wire type (%d)", wireType)
240 }
241 // DONE: err = r.SkipValue(tag)
242
243 out2:
244 if err != nil {
245 return nil, err
246 }
247 if fieldNum != lastProtoFieldNum {
248 index = append(index, IndexEntry{FieldNum: fieldNum,
249 Start: uint32(curPos),
250 End: uint32(r.Pos)},
251 )
252 } else {
253 index[len(index)-1].End = uint32(r.Pos)
254 index[len(index)-1].MultipleContiguous = true
255 }
256 lastProtoFieldNum = fieldNum
257 }
258 if outOfOrder {
259 sort.Slice(index, func(i, j int) bool {
260 return index[i].FieldNum < index[j].FieldNum ||
261 (index[i].FieldNum == index[j].FieldNum &&
262 index[i].Start < index[j].Start)
263 })
264 }
265 return index, nil
266}
267
268func (lazy *XXX_lazyUnmarshalInfo) SizeField(num uint32) (size int) {
269 start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
270 if multipleEntries != nil {
271 for _, entry := range multipleEntries {
272 size += int(entry.End - entry.Start)
273 }
274 return size
275 }
276 if !found {
277 return 0
278 }
279 return int(end - start)
280}
281
282func (lazy *XXX_lazyUnmarshalInfo) AppendField(b []byte, num uint32) ([]byte, bool) {
283 start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
284 if multipleEntries != nil {
285 for _, entry := range multipleEntries {
286 b = append(b, lazy.Protobuf[entry.Start:entry.End]...)
287 }
288 return b, true
289 }
290 if !found {
291 return nil, false
292 }
293 b = append(b, lazy.Protobuf[start:end]...)
294 return b, true
295}
296
297func (lazy *XXX_lazyUnmarshalInfo) SetIndex(index []IndexEntry) {
298 atomicStoreIndex(&lazy.index, &index)
299}
300
301// FindFieldInProto looks for field fieldNum in lazyUnmarshalInfo information
302// (including protobuf), returns startOffset/endOffset/found.
303func (lazy *XXX_lazyUnmarshalInfo) FindFieldInProto(fieldNum uint32) (start, end uint32, found, multipleContiguous bool, multipleEntries []IndexEntry) {
304 if lazy.Protobuf == nil {
305 // There is no backing protobuf for this message -- it was made from a builder
306 return 0, 0, false, false, nil
307 }
308 index := atomicLoadIndex(&lazy.index)
309 if index == nil {
310 r, err := buildIndex(lazy.Protobuf)
311 if err != nil {
312 panic(fmt.Sprintf("findFieldInfo: error building index when looking for field %d: %v", fieldNum, err))
313 }
314 // lazy.index is a pointer to the slice returned by BuildIndex
315 index = &r
316 atomicStoreIndex(&lazy.index, index)
317 }
318 return lookupField(index, fieldNum)
319}
320
321// lookupField returns the offset at which the indicated field starts using
322// the index, offset immediately after field ends (including all instances of
323// a repeated field), and bools indicating if field was found and if there
324// are multiple encodings of the field in the byte range.
325//
326// To hande the uncommon case where there are repeated encodings for the same
327// field which are not consecutive in the protobuf (so we need to returns
328// multiple start/end offsets), we also return a slice multipleEntries. If
329// multipleEntries is non-nil, then multiple entries were found, and the
330// values in the slice should be used, rather than start/end/found.
331func lookupField(indexp *[]IndexEntry, fieldNum uint32) (start, end uint32, found bool, multipleContiguous bool, multipleEntries []IndexEntry) {
332 // The pointer indexp to the index was already loaded atomically.
333 // The slice is uniquely associated with the pointer, so it doesn't
334 // need to be loaded atomically.
335 index := *indexp
336 for i, entry := range index {
337 if fieldNum == entry.FieldNum {
338 if i < len(index)-1 && entry.FieldNum == index[i+1].FieldNum {
339 // Handle the uncommon case where there are
340 // repeated entries for the same field which
341 // are not contiguous in the protobuf.
342 multiple := make([]IndexEntry, 1, 2)
343 multiple[0] = IndexEntry{fieldNum, entry.Start, entry.End, entry.MultipleContiguous}
344 i++
345 for i < len(index) && index[i].FieldNum == fieldNum {
346 multiple = append(multiple, IndexEntry{fieldNum, index[i].Start, index[i].End, index[i].MultipleContiguous})
347 i++
348 }
349 return 0, 0, false, false, multiple
350
351 }
352 return entry.Start, entry.End, true, entry.MultipleContiguous, nil
353 }
354 if fieldNum < entry.FieldNum {
355 return 0, 0, false, false, nil
356 }
357 }
358 return 0, 0, false, false, nil
359}