khenaidoo | a46458b | 2021-12-15 16:50:44 -0500 | [diff] [blame] | 1 | package codec |
| 2 | |
| 3 | import ( |
| 4 | "fmt" |
| 5 | "math" |
| 6 | "reflect" |
| 7 | "sort" |
| 8 | |
| 9 | "github.com/golang/protobuf/proto" |
| 10 | "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| 11 | |
| 12 | "github.com/jhump/protoreflect/desc" |
| 13 | ) |
| 14 | |
// EncodeZigZag64 maps a signed 64-bit integer onto an unsigned one using
// zig-zag encoding. Values of small magnitude, negative ones included, map to
// small unsigned values and therefore encode compactly as varints.
func EncodeZigZag64(v int64) uint64 {
	// The arithmetic shift yields all ones for negative inputs and all zeros
	// otherwise; XOR-ing it in flips every shifted bit for negative values.
	signMask := uint64(v >> 63)
	shifted := uint64(v) << 1
	return shifted ^ signMask
}
| 21 | |
// EncodeZigZag32 maps a signed 32-bit integer onto an unsigned one using
// zig-zag encoding. Values of small magnitude, negative ones included, map to
// small unsigned values and therefore encode compactly as varints. The
// computation is done in 32 bits and the result is widened to uint64.
func EncodeZigZag32(v int32) uint64 {
	// The arithmetic shift yields all ones for negative inputs and all zeros
	// otherwise; XOR-ing it in flips every shifted bit for negative values.
	signMask := uint32(v >> 31)
	shifted := uint32(v) << 1
	return uint64(shifted ^ signMask)
}
| 28 | |
| 29 | func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { |
| 30 | if fd.IsMap() { |
| 31 | mp := val.(map[interface{}]interface{}) |
| 32 | entryType := fd.GetMessageType() |
| 33 | keyType := entryType.FindFieldByNumber(1) |
| 34 | valType := entryType.FindFieldByNumber(2) |
| 35 | var entryBuffer Buffer |
| 36 | if cb.IsDeterministic() { |
| 37 | entryBuffer.SetDeterministic(true) |
| 38 | keys := make([]interface{}, 0, len(mp)) |
| 39 | for k := range mp { |
| 40 | keys = append(keys, k) |
| 41 | } |
| 42 | sort.Sort(sortable(keys)) |
| 43 | for _, k := range keys { |
| 44 | v := mp[k] |
| 45 | entryBuffer.Reset() |
| 46 | if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { |
| 47 | return err |
| 48 | } |
| 49 | rv := reflect.ValueOf(v) |
| 50 | if rv.Kind() != reflect.Ptr || !rv.IsNil() { |
| 51 | if err := entryBuffer.encodeFieldElement(valType, v); err != nil { |
| 52 | return err |
| 53 | } |
| 54 | } |
| 55 | if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { |
| 56 | return err |
| 57 | } |
| 58 | if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { |
| 59 | return err |
| 60 | } |
| 61 | } |
| 62 | } else { |
| 63 | for k, v := range mp { |
| 64 | entryBuffer.Reset() |
| 65 | if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { |
| 66 | return err |
| 67 | } |
| 68 | rv := reflect.ValueOf(v) |
| 69 | if rv.Kind() != reflect.Ptr || !rv.IsNil() { |
| 70 | if err := entryBuffer.encodeFieldElement(valType, v); err != nil { |
| 71 | return err |
| 72 | } |
| 73 | } |
| 74 | if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { |
| 75 | return err |
| 76 | } |
| 77 | if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { |
| 78 | return err |
| 79 | } |
| 80 | } |
| 81 | } |
| 82 | return nil |
| 83 | } else if fd.IsRepeated() { |
| 84 | sl := val.([]interface{}) |
| 85 | wt, err := getWireType(fd.GetType()) |
| 86 | if err != nil { |
| 87 | return err |
| 88 | } |
| 89 | if isPacked(fd) && len(sl) > 0 && |
| 90 | (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) { |
| 91 | // packed repeated field |
| 92 | var packedBuffer Buffer |
| 93 | for _, v := range sl { |
| 94 | if err := packedBuffer.encodeFieldValue(fd, v); err != nil { |
| 95 | return err |
| 96 | } |
| 97 | } |
| 98 | if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { |
| 99 | return err |
| 100 | } |
| 101 | return cb.EncodeRawBytes(packedBuffer.Bytes()) |
| 102 | } else { |
| 103 | // non-packed repeated field |
| 104 | for _, v := range sl { |
| 105 | if err := cb.encodeFieldElement(fd, v); err != nil { |
| 106 | return err |
| 107 | } |
| 108 | } |
| 109 | return nil |
| 110 | } |
| 111 | } else { |
| 112 | return cb.encodeFieldElement(fd, val) |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | func isPacked(fd *desc.FieldDescriptor) bool { |
| 117 | opts := fd.AsFieldDescriptorProto().GetOptions() |
| 118 | // if set, use that value |
| 119 | if opts != nil && opts.Packed != nil { |
| 120 | return opts.GetPacked() |
| 121 | } |
| 122 | // if unset: proto2 defaults to false, proto3 to true |
| 123 | return fd.GetFile().IsProto3() |
| 124 | } |
| 125 | |
// sortable implements sort.Interface over a slice of map keys. Keys are
// expected to be integers (int32, int64, uint32 or uint64), bools or strings,
// and all keys in one slice are expected to share the same type.
type sortable []interface{}

// Len returns the number of keys.
func (s sortable) Len() int {
	n := len(s)
	return n
}

// Less reports whether the key at index i orders before the key at index j.
// The comparison is selected by the kind of the key at i; bools order false
// before true. It panics when that kind is not one of the supported key kinds.
func (s sortable) Less(i, j int) bool {
	left, right := s[i], s[j]
	switch kind := reflect.TypeOf(left).Kind(); kind {
	case reflect.String:
		return left.(string) < right.(string)
	case reflect.Bool:
		// only "false before true" counts as strictly less
		return !left.(bool) && right.(bool)
	case reflect.Int32:
		return left.(int32) < right.(int32)
	case reflect.Uint32:
		return left.(uint32) < right.(uint32)
	case reflect.Int64:
		return left.(int64) < right.(int64)
	case reflect.Uint64:
		return left.(uint64) < right.(uint64)
	}
	panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(left)))
}

// Swap exchanges the keys at indexes i and j.
func (s sortable) Swap(i, j int) {
	first, second := s[i], s[j]
	s[i] = second
	s[j] = first
}
| 158 | |
| 159 | func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error { |
| 160 | wt, err := getWireType(fd.GetType()) |
| 161 | if err != nil { |
| 162 | return err |
| 163 | } |
| 164 | if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil { |
| 165 | return err |
| 166 | } |
| 167 | if err := b.encodeFieldValue(fd, val); err != nil { |
| 168 | return err |
| 169 | } |
| 170 | if wt == proto.WireStartGroup { |
| 171 | return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup) |
| 172 | } |
| 173 | return nil |
| 174 | } |
| 175 | |
| 176 | func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { |
| 177 | switch fd.GetType() { |
| 178 | case descriptor.FieldDescriptorProto_TYPE_BOOL: |
| 179 | v := val.(bool) |
| 180 | if v { |
| 181 | return b.EncodeVarint(1) |
| 182 | } |
| 183 | return b.EncodeVarint(0) |
| 184 | |
| 185 | case descriptor.FieldDescriptorProto_TYPE_ENUM, |
| 186 | descriptor.FieldDescriptorProto_TYPE_INT32: |
| 187 | v := val.(int32) |
| 188 | return b.EncodeVarint(uint64(v)) |
| 189 | |
| 190 | case descriptor.FieldDescriptorProto_TYPE_SFIXED32: |
| 191 | v := val.(int32) |
| 192 | return b.EncodeFixed32(uint64(v)) |
| 193 | |
| 194 | case descriptor.FieldDescriptorProto_TYPE_SINT32: |
| 195 | v := val.(int32) |
| 196 | return b.EncodeVarint(EncodeZigZag32(v)) |
| 197 | |
| 198 | case descriptor.FieldDescriptorProto_TYPE_UINT32: |
| 199 | v := val.(uint32) |
| 200 | return b.EncodeVarint(uint64(v)) |
| 201 | |
| 202 | case descriptor.FieldDescriptorProto_TYPE_FIXED32: |
| 203 | v := val.(uint32) |
| 204 | return b.EncodeFixed32(uint64(v)) |
| 205 | |
| 206 | case descriptor.FieldDescriptorProto_TYPE_INT64: |
| 207 | v := val.(int64) |
| 208 | return b.EncodeVarint(uint64(v)) |
| 209 | |
| 210 | case descriptor.FieldDescriptorProto_TYPE_SFIXED64: |
| 211 | v := val.(int64) |
| 212 | return b.EncodeFixed64(uint64(v)) |
| 213 | |
| 214 | case descriptor.FieldDescriptorProto_TYPE_SINT64: |
| 215 | v := val.(int64) |
| 216 | return b.EncodeVarint(EncodeZigZag64(v)) |
| 217 | |
| 218 | case descriptor.FieldDescriptorProto_TYPE_UINT64: |
| 219 | v := val.(uint64) |
| 220 | return b.EncodeVarint(v) |
| 221 | |
| 222 | case descriptor.FieldDescriptorProto_TYPE_FIXED64: |
| 223 | v := val.(uint64) |
| 224 | return b.EncodeFixed64(v) |
| 225 | |
| 226 | case descriptor.FieldDescriptorProto_TYPE_DOUBLE: |
| 227 | v := val.(float64) |
| 228 | return b.EncodeFixed64(math.Float64bits(v)) |
| 229 | |
| 230 | case descriptor.FieldDescriptorProto_TYPE_FLOAT: |
| 231 | v := val.(float32) |
| 232 | return b.EncodeFixed32(uint64(math.Float32bits(v))) |
| 233 | |
| 234 | case descriptor.FieldDescriptorProto_TYPE_BYTES: |
| 235 | v := val.([]byte) |
| 236 | return b.EncodeRawBytes(v) |
| 237 | |
| 238 | case descriptor.FieldDescriptorProto_TYPE_STRING: |
| 239 | v := val.(string) |
| 240 | return b.EncodeRawBytes(([]byte)(v)) |
| 241 | |
| 242 | case descriptor.FieldDescriptorProto_TYPE_MESSAGE: |
| 243 | return b.EncodeDelimitedMessage(val.(proto.Message)) |
| 244 | |
| 245 | case descriptor.FieldDescriptorProto_TYPE_GROUP: |
| 246 | // just append the nested message to this buffer |
| 247 | return b.EncodeMessage(val.(proto.Message)) |
| 248 | // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag |
| 249 | |
| 250 | default: |
| 251 | return fmt.Errorf("unrecognized field type: %v", fd.GetType()) |
| 252 | } |
| 253 | } |
| 254 | |
| 255 | func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) { |
| 256 | switch t { |
| 257 | case descriptor.FieldDescriptorProto_TYPE_ENUM, |
| 258 | descriptor.FieldDescriptorProto_TYPE_BOOL, |
| 259 | descriptor.FieldDescriptorProto_TYPE_INT32, |
| 260 | descriptor.FieldDescriptorProto_TYPE_SINT32, |
| 261 | descriptor.FieldDescriptorProto_TYPE_UINT32, |
| 262 | descriptor.FieldDescriptorProto_TYPE_INT64, |
| 263 | descriptor.FieldDescriptorProto_TYPE_SINT64, |
| 264 | descriptor.FieldDescriptorProto_TYPE_UINT64: |
| 265 | return proto.WireVarint, nil |
| 266 | |
| 267 | case descriptor.FieldDescriptorProto_TYPE_FIXED32, |
| 268 | descriptor.FieldDescriptorProto_TYPE_SFIXED32, |
| 269 | descriptor.FieldDescriptorProto_TYPE_FLOAT: |
| 270 | return proto.WireFixed32, nil |
| 271 | |
| 272 | case descriptor.FieldDescriptorProto_TYPE_FIXED64, |
| 273 | descriptor.FieldDescriptorProto_TYPE_SFIXED64, |
| 274 | descriptor.FieldDescriptorProto_TYPE_DOUBLE: |
| 275 | return proto.WireFixed64, nil |
| 276 | |
| 277 | case descriptor.FieldDescriptorProto_TYPE_BYTES, |
| 278 | descriptor.FieldDescriptorProto_TYPE_STRING, |
| 279 | descriptor.FieldDescriptorProto_TYPE_MESSAGE: |
| 280 | return proto.WireBytes, nil |
| 281 | |
| 282 | case descriptor.FieldDescriptorProto_TYPE_GROUP: |
| 283 | return proto.WireStartGroup, nil |
| 284 | |
| 285 | default: |
| 286 | return 0, ErrBadWireType |
| 287 | } |
| 288 | } |