blob: cda7299bbb70fb2a846a20fe9f4f4eb49b65d7c5 [file] [log] [blame]
Scott Baker4a35a702019-11-26 08:17:33 -08001package codec
2
3import (
4 "fmt"
5 "math"
6 "reflect"
7 "sort"
8
9 "github.com/golang/protobuf/proto"
10 "github.com/golang/protobuf/protoc-gen-go/descriptor"
11
12 "github.com/jhump/protoreflect/desc"
13)
14
15func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
16 if fd.IsMap() {
17 mp := val.(map[interface{}]interface{})
18 entryType := fd.GetMessageType()
19 keyType := entryType.FindFieldByNumber(1)
20 valType := entryType.FindFieldByNumber(2)
21 var entryBuffer Buffer
22 if cb.deterministic {
23 keys := make([]interface{}, 0, len(mp))
24 for k := range mp {
25 keys = append(keys, k)
26 }
27 sort.Sort(sortable(keys))
28 for _, k := range keys {
29 v := mp[k]
30 entryBuffer.Reset()
31 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
32 return err
33 }
34 if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
35 return err
36 }
37 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
38 return err
39 }
40 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
41 return err
42 }
43 }
44 } else {
45 for k, v := range mp {
46 entryBuffer.Reset()
47 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
48 return err
49 }
50 if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
51 return err
52 }
53 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
54 return err
55 }
56 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
57 return err
58 }
59 }
60 }
61 return nil
62 } else if fd.IsRepeated() {
63 sl := val.([]interface{})
64 wt, err := getWireType(fd.GetType())
65 if err != nil {
66 return err
67 }
68 if isPacked(fd) && len(sl) > 1 &&
69 (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) {
70 // packed repeated field
71 var packedBuffer Buffer
72 for _, v := range sl {
73 if err := packedBuffer.encodeFieldValue(fd, v); err != nil {
74 return err
75 }
76 }
77 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
78 return err
79 }
80 return cb.EncodeRawBytes(packedBuffer.Bytes())
81 } else {
82 // non-packed repeated field
83 for _, v := range sl {
84 if err := cb.encodeFieldElement(fd, v); err != nil {
85 return err
86 }
87 }
88 return nil
89 }
90 } else {
91 return cb.encodeFieldElement(fd, val)
92 }
93}
94
95func isPacked(fd *desc.FieldDescriptor) bool {
96 opts := fd.AsFieldDescriptorProto().GetOptions()
97 // if set, use that value
98 if opts != nil && opts.Packed != nil {
99 return opts.GetPacked()
100 }
101 // if unset: proto2 defaults to false, proto3 to true
102 return fd.GetFile().IsProto3()
103}
104
// sortable implements sort.Interface over a slice of map keys, so that map
// entries can be written in a deterministic order. Keys will be integers
// (int32, int64, uint32, and uint64), bools, or strings, and all elements
// are expected to share one dynamic type. Comparing any other type panics,
// as does comparing elements of different types.
type sortable []interface{}

// Len returns the number of keys.
func (s sortable) Len() int {
	return len(s)
}

// Less orders integers and strings ascending and places false before true.
//
// It dispatches with a type switch on the concrete key type rather than on
// reflect.Kind. That avoids reflection on every comparison, and the type
// selected is exactly the one the comparison asserts.
func (s sortable) Less(i, j int) bool {
	switch vi := s[i].(type) {
	case int32:
		return vi < s[j].(int32)
	case int64:
		return vi < s[j].(int64)
	case uint32:
		return vi < s[j].(uint32)
	case uint64:
		return vi < s[j].(uint64)
	case string:
		return vi < s[j].(string)
	case bool:
		return !vi && s[j].(bool)
	default:
		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
	}
}

// Swap exchanges the keys at positions i and j.
func (s sortable) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}
137
138func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error {
139 wt, err := getWireType(fd.GetType())
140 if err != nil {
141 return err
142 }
143 if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil {
144 return err
145 }
146 if err := b.encodeFieldValue(fd, val); err != nil {
147 return err
148 }
149 if wt == proto.WireStartGroup {
150 return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup)
151 }
152 return nil
153}
154
155func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
156 switch fd.GetType() {
157 case descriptor.FieldDescriptorProto_TYPE_BOOL:
158 v := val.(bool)
159 if v {
160 return b.EncodeVarint(1)
161 }
162 return b.EncodeVarint(0)
163
164 case descriptor.FieldDescriptorProto_TYPE_ENUM,
165 descriptor.FieldDescriptorProto_TYPE_INT32:
166 v := val.(int32)
167 return b.EncodeVarint(uint64(v))
168
169 case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
170 v := val.(int32)
171 return b.EncodeFixed32(uint64(v))
172
173 case descriptor.FieldDescriptorProto_TYPE_SINT32:
174 v := val.(int32)
175 return b.EncodeVarint(EncodeZigZag32(v))
176
177 case descriptor.FieldDescriptorProto_TYPE_UINT32:
178 v := val.(uint32)
179 return b.EncodeVarint(uint64(v))
180
181 case descriptor.FieldDescriptorProto_TYPE_FIXED32:
182 v := val.(uint32)
183 return b.EncodeFixed32(uint64(v))
184
185 case descriptor.FieldDescriptorProto_TYPE_INT64:
186 v := val.(int64)
187 return b.EncodeVarint(uint64(v))
188
189 case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
190 v := val.(int64)
191 return b.EncodeFixed64(uint64(v))
192
193 case descriptor.FieldDescriptorProto_TYPE_SINT64:
194 v := val.(int64)
195 return b.EncodeVarint(EncodeZigZag64(v))
196
197 case descriptor.FieldDescriptorProto_TYPE_UINT64:
198 v := val.(uint64)
199 return b.EncodeVarint(v)
200
201 case descriptor.FieldDescriptorProto_TYPE_FIXED64:
202 v := val.(uint64)
203 return b.EncodeFixed64(v)
204
205 case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
206 v := val.(float64)
207 return b.EncodeFixed64(math.Float64bits(v))
208
209 case descriptor.FieldDescriptorProto_TYPE_FLOAT:
210 v := val.(float32)
211 return b.EncodeFixed32(uint64(math.Float32bits(v)))
212
213 case descriptor.FieldDescriptorProto_TYPE_BYTES:
214 v := val.([]byte)
215 return b.EncodeRawBytes(v)
216
217 case descriptor.FieldDescriptorProto_TYPE_STRING:
218 v := val.(string)
219 return b.EncodeRawBytes(([]byte)(v))
220
221 case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
222 return b.EncodeDelimitedMessage(val.(proto.Message))
223
224 case descriptor.FieldDescriptorProto_TYPE_GROUP:
225 // just append the nested message to this buffer
226 return b.EncodeMessage(val.(proto.Message))
227 // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag
228
229 default:
230 return fmt.Errorf("unrecognized field type: %v", fd.GetType())
231 }
232}
233
234func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) {
235 switch t {
236 case descriptor.FieldDescriptorProto_TYPE_ENUM,
237 descriptor.FieldDescriptorProto_TYPE_BOOL,
238 descriptor.FieldDescriptorProto_TYPE_INT32,
239 descriptor.FieldDescriptorProto_TYPE_SINT32,
240 descriptor.FieldDescriptorProto_TYPE_UINT32,
241 descriptor.FieldDescriptorProto_TYPE_INT64,
242 descriptor.FieldDescriptorProto_TYPE_SINT64,
243 descriptor.FieldDescriptorProto_TYPE_UINT64:
244 return proto.WireVarint, nil
245
246 case descriptor.FieldDescriptorProto_TYPE_FIXED32,
247 descriptor.FieldDescriptorProto_TYPE_SFIXED32,
248 descriptor.FieldDescriptorProto_TYPE_FLOAT:
249 return proto.WireFixed32, nil
250
251 case descriptor.FieldDescriptorProto_TYPE_FIXED64,
252 descriptor.FieldDescriptorProto_TYPE_SFIXED64,
253 descriptor.FieldDescriptorProto_TYPE_DOUBLE:
254 return proto.WireFixed64, nil
255
256 case descriptor.FieldDescriptorProto_TYPE_BYTES,
257 descriptor.FieldDescriptorProto_TYPE_STRING,
258 descriptor.FieldDescriptorProto_TYPE_MESSAGE:
259 return proto.WireBytes, nil
260
261 case descriptor.FieldDescriptorProto_TYPE_GROUP:
262 return proto.WireStartGroup, nil
263
264 default:
265 return 0, ErrBadWireType
266 }
267}