blob: c1245fef487651a967fbf27bfe03344321c7a705 [file] [log] [blame]
David K. Bainbridgebd6b2882021-08-26 13:31:02 +00001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "reflect"
9 "sort"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/genid"
13 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16type mapInfo struct {
17 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs valueCoderFuncs
21 valFuncs valueCoderFuncs
22 keyZero pref.Value
23 keyKind pref.Kind
24 conv *mapConverter
25}
26
27func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
28 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
31 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
33 keyFuncs := encoderFuncsForValue(keyField)
34 valFuncs := encoderFuncsForValue(valField)
35 conv := newMapConverter(ft, fd)
36
37 mapi := &mapInfo{
38 goType: ft,
39 keyWiretag: keyWiretag,
40 valWiretag: valWiretag,
41 keyFuncs: keyFuncs,
42 valFuncs: valFuncs,
43 keyZero: keyField.Default(),
44 keyKind: keyField.Kind(),
45 conv: conv,
46 }
47 if valField.Kind() == pref.MessageKind {
48 valueMessage = getMessageInfo(ft.Elem())
49 }
50
51 funcs = pointerCoderFuncs{
52 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
53 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
54 },
55 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
56 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
57 },
58 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
59 mp := p.AsValueOf(ft)
60 if mp.Elem().IsNil() {
61 mp.Elem().Set(reflect.MakeMap(mapi.goType))
62 }
63 if f.mi == nil {
64 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
65 } else {
66 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
67 }
68 },
69 }
70 switch valField.Kind() {
71 case pref.MessageKind:
72 funcs.merge = mergeMapOfMessage
73 case pref.BytesKind:
74 funcs.merge = mergeMapOfBytes
75 default:
76 funcs.merge = mergeMap
77 }
78 if valFuncs.isInit != nil {
79 funcs.isInit = func(p pointer, f *coderFieldInfo) error {
80 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
81 }
82 }
83 return valueMessage, funcs
84}
85
// Encoded sizes, in bytes, of the tags for the key (field 1) and value
// (field 2) of a map entry. For every wire type (0-5), the tag value
// (num<<3 | wiretype) is below 0x80, so each tag fits in a single varint byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
90
91func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
92 if mapv.Len() == 0 {
93 return 0
94 }
95 n := 0
96 iter := mapRange(mapv)
97 for iter.Next() {
98 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
99 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
100 var valSize int
101 value := mapi.conv.valConv.PBValueOf(iter.Value())
102 if f.mi == nil {
103 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
104 } else {
105 p := pointerOfValue(iter.Value())
106 valSize += mapValTagSize
107 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
108 }
109 n += f.tagsize + protowire.SizeBytes(keySize+valSize)
110 }
111 return n
112}
113
114func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
115 if wtyp != protowire.BytesType {
116 return out, errUnknown
117 }
118 b, n := protowire.ConsumeBytes(b)
119 if n < 0 {
120 return out, errDecode
121 }
122 var (
123 key = mapi.keyZero
124 val = mapi.conv.valConv.New()
125 )
126 for len(b) > 0 {
127 num, wtyp, n := protowire.ConsumeTag(b)
128 if n < 0 {
129 return out, errDecode
130 }
131 if num > protowire.MaxValidNumber {
132 return out, errDecode
133 }
134 b = b[n:]
135 err := errUnknown
136 switch num {
137 case genid.MapEntry_Key_field_number:
138 var v pref.Value
139 var o unmarshalOutput
140 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
141 if err != nil {
142 break
143 }
144 key = v
145 n = o.n
146 case genid.MapEntry_Value_field_number:
147 var v pref.Value
148 var o unmarshalOutput
149 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
150 if err != nil {
151 break
152 }
153 val = v
154 n = o.n
155 }
156 if err == errUnknown {
157 n = protowire.ConsumeFieldValue(num, wtyp, b)
158 if n < 0 {
159 return out, errDecode
160 }
161 } else if err != nil {
162 return out, err
163 }
164 b = b[n:]
165 }
166 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
167 out.n = n
168 return out, nil
169}
170
171func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
172 if wtyp != protowire.BytesType {
173 return out, errUnknown
174 }
175 b, n := protowire.ConsumeBytes(b)
176 if n < 0 {
177 return out, errDecode
178 }
179 var (
180 key = mapi.keyZero
181 val = reflect.New(f.mi.GoReflectType.Elem())
182 )
183 for len(b) > 0 {
184 num, wtyp, n := protowire.ConsumeTag(b)
185 if n < 0 {
186 return out, errDecode
187 }
188 if num > protowire.MaxValidNumber {
189 return out, errDecode
190 }
191 b = b[n:]
192 err := errUnknown
193 switch num {
194 case 1:
195 var v pref.Value
196 var o unmarshalOutput
197 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
198 if err != nil {
199 break
200 }
201 key = v
202 n = o.n
203 case 2:
204 if wtyp != protowire.BytesType {
205 break
206 }
207 var v []byte
208 v, n = protowire.ConsumeBytes(b)
209 if n < 0 {
210 return out, errDecode
211 }
212 var o unmarshalOutput
213 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
214 if o.initialized {
215 // Consider this map item initialized so long as we see
216 // an initialized value.
217 out.initialized = true
218 }
219 }
220 if err == errUnknown {
221 n = protowire.ConsumeFieldValue(num, wtyp, b)
222 if n < 0 {
223 return out, errDecode
224 }
225 } else if err != nil {
226 return out, err
227 }
228 b = b[n:]
229 }
230 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
231 out.n = n
232 return out, nil
233}
234
235func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
236 if f.mi == nil {
237 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
238 val := mapi.conv.valConv.PBValueOf(valrv)
239 size := 0
240 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
241 size += mapi.valFuncs.size(val, mapValTagSize, opts)
242 b = protowire.AppendVarint(b, uint64(size))
243 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
244 if err != nil {
245 return nil, err
246 }
247 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
248 } else {
249 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
250 val := pointerOfValue(valrv)
251 valSize := f.mi.sizePointer(val, opts)
252 size := 0
253 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
254 size += mapValTagSize + protowire.SizeBytes(valSize)
255 b = protowire.AppendVarint(b, uint64(size))
256 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
257 if err != nil {
258 return nil, err
259 }
260 b = protowire.AppendVarint(b, mapi.valWiretag)
261 b = protowire.AppendVarint(b, uint64(valSize))
262 return f.mi.marshalAppendPointer(b, val, opts)
263 }
264}
265
266func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
267 if mapv.Len() == 0 {
268 return b, nil
269 }
270 if opts.Deterministic() {
271 return appendMapDeterministic(b, mapv, mapi, f, opts)
272 }
273 iter := mapRange(mapv)
274 for iter.Next() {
275 var err error
276 b = protowire.AppendVarint(b, f.wiretag)
277 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
278 if err != nil {
279 return b, err
280 }
281 }
282 return b, nil
283}
284
285func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
286 keys := mapv.MapKeys()
287 sort.Slice(keys, func(i, j int) bool {
288 switch keys[i].Kind() {
289 case reflect.Bool:
290 return !keys[i].Bool() && keys[j].Bool()
291 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
292 return keys[i].Int() < keys[j].Int()
293 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
294 return keys[i].Uint() < keys[j].Uint()
295 case reflect.Float32, reflect.Float64:
296 return keys[i].Float() < keys[j].Float()
297 case reflect.String:
298 return keys[i].String() < keys[j].String()
299 default:
300 panic("invalid kind: " + keys[i].Kind().String())
301 }
302 })
303 for _, key := range keys {
304 var err error
305 b = protowire.AppendVarint(b, f.wiretag)
306 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
307 if err != nil {
308 return b, err
309 }
310 }
311 return b, nil
312}
313
314func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
315 if mi := f.mi; mi != nil {
316 mi.init()
317 if !mi.needsInitCheck {
318 return nil
319 }
320 iter := mapRange(mapv)
321 for iter.Next() {
322 val := pointerOfValue(iter.Value())
323 if err := mi.checkInitializedPointer(val); err != nil {
324 return err
325 }
326 }
327 } else {
328 iter := mapRange(mapv)
329 for iter.Next() {
330 val := mapi.conv.valConv.PBValueOf(iter.Value())
331 if err := mapi.valFuncs.isInit(val); err != nil {
332 return err
333 }
334 }
335 }
336 return nil
337}
338
339func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
340 dstm := dst.AsValueOf(f.ft).Elem()
341 srcm := src.AsValueOf(f.ft).Elem()
342 if srcm.Len() == 0 {
343 return
344 }
345 if dstm.IsNil() {
346 dstm.Set(reflect.MakeMap(f.ft))
347 }
348 iter := mapRange(srcm)
349 for iter.Next() {
350 dstm.SetMapIndex(iter.Key(), iter.Value())
351 }
352}
353
354func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
355 dstm := dst.AsValueOf(f.ft).Elem()
356 srcm := src.AsValueOf(f.ft).Elem()
357 if srcm.Len() == 0 {
358 return
359 }
360 if dstm.IsNil() {
361 dstm.Set(reflect.MakeMap(f.ft))
362 }
363 iter := mapRange(srcm)
364 for iter.Next() {
365 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
366 }
367}
368
369func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
370 dstm := dst.AsValueOf(f.ft).Elem()
371 srcm := src.AsValueOf(f.ft).Elem()
372 if srcm.Len() == 0 {
373 return
374 }
375 if dstm.IsNil() {
376 dstm.Set(reflect.MakeMap(f.ft))
377 }
378 iter := mapRange(srcm)
379 for iter.Next() {
380 val := reflect.New(f.ft.Elem().Elem())
381 if f.mi != nil {
382 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
383 } else {
384 opts.Merge(asMessage(val), asMessage(iter.Value()))
385 }
386 dstm.SetMapIndex(iter.Key(), val)
387 }
388}