blob: 499aa9564a10d7ad29e4696dfa2a995ce3463f6b [file] [log] [blame]
khenaidooefff76e2021-12-15 16:51:30 -05001package codec
2
3import (
4 "fmt"
5 "math"
6 "reflect"
7 "sort"
8
9 "github.com/golang/protobuf/proto"
10 "github.com/golang/protobuf/protoc-gen-go/descriptor"
11
12 "github.com/jhump/protoreflect/desc"
13)
14
// EncodeZigZag64 maps a signed 64-bit integer onto an unsigned one by
// interleaving non-negative and negative values (0, -1, 1, -2, ... become
// 0, 1, 2, 3, ...). Small-magnitude negative numbers therefore produce short
// varints instead of always taking the maximum width.
func EncodeZigZag64(v int64) uint64 {
	// Arithmetic shift: all ones when v is negative, all zeros otherwise.
	sign := uint64(v >> 63)
	return uint64(v)<<1 ^ sign
}
21
// EncodeZigZag32 is the 32-bit counterpart of EncodeZigZag64. It interleaves
// non-negative and negative values so that small-magnitude negatives encode
// as short varints. The computation is done in 32 bits, so the upper 32 bits
// of the result are always zero.
func EncodeZigZag32(v int32) uint64 {
	doubled := uint32(v) << 1
	// Arithmetic shift: all ones when v is negative, all zeros otherwise.
	sign := uint32(v >> 31)
	return uint64(doubled ^ sign)
}
28
29func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
30 if fd.IsMap() {
31 mp := val.(map[interface{}]interface{})
32 entryType := fd.GetMessageType()
33 keyType := entryType.FindFieldByNumber(1)
34 valType := entryType.FindFieldByNumber(2)
35 var entryBuffer Buffer
36 if cb.IsDeterministic() {
37 entryBuffer.SetDeterministic(true)
38 keys := make([]interface{}, 0, len(mp))
39 for k := range mp {
40 keys = append(keys, k)
41 }
42 sort.Sort(sortable(keys))
43 for _, k := range keys {
44 v := mp[k]
45 entryBuffer.Reset()
46 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
47 return err
48 }
49 rv := reflect.ValueOf(v)
50 if rv.Kind() != reflect.Ptr || !rv.IsNil() {
51 if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
52 return err
53 }
54 }
55 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
56 return err
57 }
58 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
59 return err
60 }
61 }
62 } else {
63 for k, v := range mp {
64 entryBuffer.Reset()
65 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
66 return err
67 }
68 rv := reflect.ValueOf(v)
69 if rv.Kind() != reflect.Ptr || !rv.IsNil() {
70 if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
71 return err
72 }
73 }
74 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
75 return err
76 }
77 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
78 return err
79 }
80 }
81 }
82 return nil
83 } else if fd.IsRepeated() {
84 sl := val.([]interface{})
85 wt, err := getWireType(fd.GetType())
86 if err != nil {
87 return err
88 }
89 if isPacked(fd) && len(sl) > 0 &&
90 (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) {
91 // packed repeated field
92 var packedBuffer Buffer
93 for _, v := range sl {
94 if err := packedBuffer.encodeFieldValue(fd, v); err != nil {
95 return err
96 }
97 }
98 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
99 return err
100 }
101 return cb.EncodeRawBytes(packedBuffer.Bytes())
102 } else {
103 // non-packed repeated field
104 for _, v := range sl {
105 if err := cb.encodeFieldElement(fd, v); err != nil {
106 return err
107 }
108 }
109 return nil
110 }
111 } else {
112 return cb.encodeFieldElement(fd, val)
113 }
114}
115
116func isPacked(fd *desc.FieldDescriptor) bool {
117 opts := fd.AsFieldDescriptorProto().GetOptions()
118 // if set, use that value
119 if opts != nil && opts.Packed != nil {
120 return opts.GetPacked()
121 }
122 // if unset: proto2 defaults to false, proto3 to true
123 return fd.GetFile().IsProto3()
124}
125
// sortable implements sort.Interface over map keys so that map entries can be
// encoded in a deterministic order. All elements are expected to share one
// dynamic type: int32, int64, uint32, uint64, string or bool.
type sortable []interface{}

// Len returns the number of keys in s.
func (s sortable) Len() int {
	return len(s)
}

// Less reports whether key i orders before key j. Integers and strings use
// their natural ordering, and false orders before true. It panics if the
// element at i has any other kind.
func (s sortable) Less(i, j int) bool {
	left, right := s[i], s[j]
	typ := reflect.TypeOf(left)
	switch typ.Kind() {
	case reflect.Bool:
		// Only false < true; the right operand is inspected only when left is false.
		return !left.(bool) && right.(bool)
	case reflect.String:
		return left.(string) < right.(string)
	case reflect.Int32:
		return left.(int32) < right.(int32)
	case reflect.Int64:
		return left.(int64) < right.(int64)
	case reflect.Uint32:
		return left.(uint32) < right.(uint32)
	case reflect.Uint64:
		return left.(uint64) < right.(uint64)
	}
	panic(fmt.Sprintf("cannot compare keys of type %v", typ))
}

// Swap exchanges the keys at positions i and j.
func (s sortable) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}
158
159func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error {
160 wt, err := getWireType(fd.GetType())
161 if err != nil {
162 return err
163 }
164 if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil {
165 return err
166 }
167 if err := b.encodeFieldValue(fd, val); err != nil {
168 return err
169 }
170 if wt == proto.WireStartGroup {
171 return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup)
172 }
173 return nil
174}
175
176func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
177 switch fd.GetType() {
178 case descriptor.FieldDescriptorProto_TYPE_BOOL:
179 v := val.(bool)
180 if v {
181 return b.EncodeVarint(1)
182 }
183 return b.EncodeVarint(0)
184
185 case descriptor.FieldDescriptorProto_TYPE_ENUM,
186 descriptor.FieldDescriptorProto_TYPE_INT32:
187 v := val.(int32)
188 return b.EncodeVarint(uint64(v))
189
190 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
191 v := val.(int32)
192 return b.EncodeFixed32(uint64(v))
193
194 case descriptor.FieldDescriptorProto_TYPE_SINT32:
195 v := val.(int32)
196 return b.EncodeVarint(EncodeZigZag32(v))
197
198 case descriptor.FieldDescriptorProto_TYPE_UINT32:
199 v := val.(uint32)
200 return b.EncodeVarint(uint64(v))
201
202 case descriptor.FieldDescriptorProto_TYPE_FIXED32:
203 v := val.(uint32)
204 return b.EncodeFixed32(uint64(v))
205
206 case descriptor.FieldDescriptorProto_TYPE_INT64:
207 v := val.(int64)
208 return b.EncodeVarint(uint64(v))
209
210 case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
211 v := val.(int64)
212 return b.EncodeFixed64(uint64(v))
213
214 case descriptor.FieldDescriptorProto_TYPE_SINT64:
215 v := val.(int64)
216 return b.EncodeVarint(EncodeZigZag64(v))
217
218 case descriptor.FieldDescriptorProto_TYPE_UINT64:
219 v := val.(uint64)
220 return b.EncodeVarint(v)
221
222 case descriptor.FieldDescriptorProto_TYPE_FIXED64:
223 v := val.(uint64)
224 return b.EncodeFixed64(v)
225
226 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
227 v := val.(float64)
228 return b.EncodeFixed64(math.Float64bits(v))
229
230 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
231 v := val.(float32)
232 return b.EncodeFixed32(uint64(math.Float32bits(v)))
233
234 case descriptor.FieldDescriptorProto_TYPE_BYTES:
235 v := val.([]byte)
236 return b.EncodeRawBytes(v)
237
238 case descriptor.FieldDescriptorProto_TYPE_STRING:
239 v := val.(string)
240 return b.EncodeRawBytes(([]byte)(v))
241
242 case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
243 return b.EncodeDelimitedMessage(val.(proto.Message))
244
245 case descriptor.FieldDescriptorProto_TYPE_GROUP:
246 // just append the nested message to this buffer
247 return b.EncodeMessage(val.(proto.Message))
248 // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag
249
250 default:
251 return fmt.Errorf("unrecognized field type: %v", fd.GetType())
252 }
253}
254
255func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) {
256 switch t {
257 case descriptor.FieldDescriptorProto_TYPE_ENUM,
258 descriptor.FieldDescriptorProto_TYPE_BOOL,
259 descriptor.FieldDescriptorProto_TYPE_INT32,
260 descriptor.FieldDescriptorProto_TYPE_SINT32,
261 descriptor.FieldDescriptorProto_TYPE_UINT32,
262 descriptor.FieldDescriptorProto_TYPE_INT64,
263 descriptor.FieldDescriptorProto_TYPE_SINT64,
264 descriptor.FieldDescriptorProto_TYPE_UINT64:
265 return proto.WireVarint, nil
266
267 case descriptor.FieldDescriptorProto_TYPE_FIXED32,
268 descriptor.FieldDescriptorProto_TYPE_SFIXED32,
269 descriptor.FieldDescriptorProto_TYPE_FLOAT:
270 return proto.WireFixed32, nil
271
272 case descriptor.FieldDescriptorProto_TYPE_FIXED64,
273 descriptor.FieldDescriptorProto_TYPE_SFIXED64,
274 descriptor.FieldDescriptorProto_TYPE_DOUBLE:
275 return proto.WireFixed64, nil
276
277 case descriptor.FieldDescriptorProto_TYPE_BYTES,
278 descriptor.FieldDescriptorProto_TYPE_STRING,
279 descriptor.FieldDescriptorProto_TYPE_MESSAGE:
280 return proto.WireBytes, nil
281
282 case descriptor.FieldDescriptorProto_TYPE_GROUP:
283 return proto.WireStartGroup, nil
284
285 default:
286 return 0, ErrBadWireType
287 }
288}