Scott Baker | 2d89798 | 2019-09-24 11:50:08 -0700 | [diff] [blame] | 1 | // Protocol Buffers for Go with Gadgets |
| 2 | // |
| 3 | // Copyright (c) 2013, The GoGo Authors. All rights reserved. |
| 4 | // http://github.com/gogo/protobuf |
| 5 | // |
| 6 | // Redistribution and use in source and binary forms, with or without |
| 7 | // modification, are permitted provided that the following conditions are |
| 8 | // met: |
| 9 | // |
| 10 | // * Redistributions of source code must retain the above copyright |
| 11 | // notice, this list of conditions and the following disclaimer. |
| 12 | // * Redistributions in binary form must reproduce the above |
| 13 | // copyright notice, this list of conditions and the following disclaimer |
| 14 | // in the documentation and/or other materials provided with the |
| 15 | // distribution. |
| 16 | // |
| 17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 18 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 19 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 20 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 21 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 22 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 23 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 24 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 25 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 26 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 27 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 28 | |
| 29 | package proto |
| 30 | |
| 31 | import ( |
| 32 | "bytes" |
| 33 | "errors" |
| 34 | "fmt" |
| 35 | "io" |
| 36 | "reflect" |
| 37 | "sort" |
| 38 | "strings" |
| 39 | "sync" |
| 40 | ) |
| 41 | |
| 42 | type extensionsBytes interface { |
| 43 | Message |
| 44 | ExtensionRangeArray() []ExtensionRange |
| 45 | GetExtensions() *[]byte |
| 46 | } |
| 47 | |
| 48 | type slowExtensionAdapter struct { |
| 49 | extensionsBytes |
| 50 | } |
| 51 | |
| 52 | func (s slowExtensionAdapter) extensionsWrite() map[int32]Extension { |
| 53 | panic("Please report a bug to github.com/gogo/protobuf if you see this message: Writing extensions is not supported for extensions stored in a byte slice field.") |
| 54 | } |
| 55 | |
| 56 | func (s slowExtensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { |
| 57 | b := s.GetExtensions() |
| 58 | m, err := BytesToExtensionsMap(*b) |
| 59 | if err != nil { |
| 60 | panic(err) |
| 61 | } |
| 62 | return m, notLocker{} |
| 63 | } |
| 64 | |
| 65 | func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool { |
| 66 | if reflect.ValueOf(pb).IsNil() { |
| 67 | return ifnotset |
| 68 | } |
| 69 | value, err := GetExtension(pb, extension) |
| 70 | if err != nil { |
| 71 | return ifnotset |
| 72 | } |
| 73 | if value == nil { |
| 74 | return ifnotset |
| 75 | } |
| 76 | if value.(*bool) == nil { |
| 77 | return ifnotset |
| 78 | } |
| 79 | return *(value.(*bool)) |
| 80 | } |
| 81 | |
| 82 | func (this *Extension) Equal(that *Extension) bool { |
| 83 | if err := this.Encode(); err != nil { |
| 84 | return false |
| 85 | } |
| 86 | if err := that.Encode(); err != nil { |
| 87 | return false |
| 88 | } |
| 89 | return bytes.Equal(this.enc, that.enc) |
| 90 | } |
| 91 | |
| 92 | func (this *Extension) Compare(that *Extension) int { |
| 93 | if err := this.Encode(); err != nil { |
| 94 | return 1 |
| 95 | } |
| 96 | if err := that.Encode(); err != nil { |
| 97 | return -1 |
| 98 | } |
| 99 | return bytes.Compare(this.enc, that.enc) |
| 100 | } |
| 101 | |
| 102 | func SizeOfInternalExtension(m extendableProto) (n int) { |
| 103 | info := getMarshalInfo(reflect.TypeOf(m)) |
| 104 | return info.sizeV1Extensions(m.extensionsWrite()) |
| 105 | } |
| 106 | |
| 107 | type sortableMapElem struct { |
| 108 | field int32 |
| 109 | ext Extension |
| 110 | } |
| 111 | |
| 112 | func newSortableExtensionsFromMap(m map[int32]Extension) sortableExtensions { |
| 113 | s := make(sortableExtensions, 0, len(m)) |
| 114 | for k, v := range m { |
| 115 | s = append(s, &sortableMapElem{field: k, ext: v}) |
| 116 | } |
| 117 | return s |
| 118 | } |
| 119 | |
| 120 | type sortableExtensions []*sortableMapElem |
| 121 | |
| 122 | func (this sortableExtensions) Len() int { return len(this) } |
| 123 | |
| 124 | func (this sortableExtensions) Swap(i, j int) { this[i], this[j] = this[j], this[i] } |
| 125 | |
| 126 | func (this sortableExtensions) Less(i, j int) bool { return this[i].field < this[j].field } |
| 127 | |
| 128 | func (this sortableExtensions) String() string { |
| 129 | sort.Sort(this) |
| 130 | ss := make([]string, len(this)) |
| 131 | for i := range this { |
| 132 | ss[i] = fmt.Sprintf("%d: %v", this[i].field, this[i].ext) |
| 133 | } |
| 134 | return "map[" + strings.Join(ss, ",") + "]" |
| 135 | } |
| 136 | |
| 137 | func StringFromInternalExtension(m extendableProto) string { |
| 138 | return StringFromExtensionsMap(m.extensionsWrite()) |
| 139 | } |
| 140 | |
| 141 | func StringFromExtensionsMap(m map[int32]Extension) string { |
| 142 | return newSortableExtensionsFromMap(m).String() |
| 143 | } |
| 144 | |
| 145 | func StringFromExtensionsBytes(ext []byte) string { |
| 146 | m, err := BytesToExtensionsMap(ext) |
| 147 | if err != nil { |
| 148 | panic(err) |
| 149 | } |
| 150 | return StringFromExtensionsMap(m) |
| 151 | } |
| 152 | |
| 153 | func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error) { |
| 154 | return EncodeExtensionMap(m.extensionsWrite(), data) |
| 155 | } |
| 156 | |
Scott Baker | 8487c5d | 2019-10-18 12:49:46 -0700 | [diff] [blame] | 157 | func EncodeInternalExtensionBackwards(m extendableProto, data []byte) (n int, err error) { |
| 158 | return EncodeExtensionMapBackwards(m.extensionsWrite(), data) |
| 159 | } |
| 160 | |
Scott Baker | 2d89798 | 2019-09-24 11:50:08 -0700 | [diff] [blame] | 161 | func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) { |
| 162 | o := 0 |
| 163 | for _, e := range m { |
| 164 | if err := e.Encode(); err != nil { |
| 165 | return 0, err |
| 166 | } |
| 167 | n := copy(data[o:], e.enc) |
| 168 | if n != len(e.enc) { |
| 169 | return 0, io.ErrShortBuffer |
| 170 | } |
| 171 | o += n |
| 172 | } |
| 173 | return o, nil |
| 174 | } |
| 175 | |
Scott Baker | 8487c5d | 2019-10-18 12:49:46 -0700 | [diff] [blame] | 176 | func EncodeExtensionMapBackwards(m map[int32]Extension, data []byte) (n int, err error) { |
| 177 | o := 0 |
| 178 | end := len(data) |
| 179 | for _, e := range m { |
| 180 | if err := e.Encode(); err != nil { |
| 181 | return 0, err |
| 182 | } |
| 183 | n := copy(data[end-len(e.enc):], e.enc) |
| 184 | if n != len(e.enc) { |
| 185 | return 0, io.ErrShortBuffer |
| 186 | } |
| 187 | end -= n |
| 188 | o += n |
| 189 | } |
| 190 | return o, nil |
| 191 | } |
| 192 | |
Scott Baker | 2d89798 | 2019-09-24 11:50:08 -0700 | [diff] [blame] | 193 | func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) { |
| 194 | e := m[id] |
| 195 | if err := e.Encode(); err != nil { |
| 196 | return nil, err |
| 197 | } |
| 198 | return e.enc, nil |
| 199 | } |
| 200 | |
| 201 | func size(buf []byte, wire int) (int, error) { |
| 202 | switch wire { |
| 203 | case WireVarint: |
| 204 | _, n := DecodeVarint(buf) |
| 205 | return n, nil |
| 206 | case WireFixed64: |
| 207 | return 8, nil |
| 208 | case WireBytes: |
| 209 | v, n := DecodeVarint(buf) |
| 210 | return int(v) + n, nil |
| 211 | case WireFixed32: |
| 212 | return 4, nil |
| 213 | case WireStartGroup: |
| 214 | offset := 0 |
| 215 | for { |
| 216 | u, n := DecodeVarint(buf[offset:]) |
| 217 | fwire := int(u & 0x7) |
| 218 | offset += n |
| 219 | if fwire == WireEndGroup { |
| 220 | return offset, nil |
| 221 | } |
| 222 | s, err := size(buf[offset:], wire) |
| 223 | if err != nil { |
| 224 | return 0, err |
| 225 | } |
| 226 | offset += s |
| 227 | } |
| 228 | } |
| 229 | return 0, fmt.Errorf("proto: can't get size for unknown wire type %d", wire) |
| 230 | } |
| 231 | |
| 232 | func BytesToExtensionsMap(buf []byte) (map[int32]Extension, error) { |
| 233 | m := make(map[int32]Extension) |
| 234 | i := 0 |
| 235 | for i < len(buf) { |
| 236 | tag, n := DecodeVarint(buf[i:]) |
| 237 | if n <= 0 { |
| 238 | return nil, fmt.Errorf("unable to decode varint") |
| 239 | } |
| 240 | fieldNum := int32(tag >> 3) |
| 241 | wireType := int(tag & 0x7) |
| 242 | l, err := size(buf[i+n:], wireType) |
| 243 | if err != nil { |
| 244 | return nil, err |
| 245 | } |
| 246 | end := i + int(l) + n |
| 247 | m[int32(fieldNum)] = Extension{enc: buf[i:end]} |
| 248 | i = end |
| 249 | } |
| 250 | return m, nil |
| 251 | } |
| 252 | |
| 253 | func NewExtension(e []byte) Extension { |
| 254 | ee := Extension{enc: make([]byte, len(e))} |
| 255 | copy(ee.enc, e) |
| 256 | return ee |
| 257 | } |
| 258 | |
| 259 | func AppendExtension(e Message, tag int32, buf []byte) { |
| 260 | if ee, eok := e.(extensionsBytes); eok { |
| 261 | ext := ee.GetExtensions() |
| 262 | *ext = append(*ext, buf...) |
| 263 | return |
| 264 | } |
| 265 | if ee, eok := e.(extendableProto); eok { |
| 266 | m := ee.extensionsWrite() |
| 267 | ext := m[int32(tag)] // may be missing |
| 268 | ext.enc = append(ext.enc, buf...) |
| 269 | m[int32(tag)] = ext |
| 270 | } |
| 271 | } |
| 272 | |
| 273 | func encodeExtension(extension *ExtensionDesc, value interface{}) ([]byte, error) { |
| 274 | u := getMarshalInfo(reflect.TypeOf(extension.ExtendedType)) |
| 275 | ei := u.getExtElemInfo(extension) |
| 276 | v := value |
| 277 | p := toAddrPointer(&v, ei.isptr) |
| 278 | siz := ei.sizer(p, SizeVarint(ei.wiretag)) |
| 279 | buf := make([]byte, 0, siz) |
| 280 | return ei.marshaler(buf, p, ei.wiretag, false) |
| 281 | } |
| 282 | |
| 283 | func decodeExtensionFromBytes(extension *ExtensionDesc, buf []byte) (interface{}, error) { |
| 284 | o := 0 |
| 285 | for o < len(buf) { |
| 286 | tag, n := DecodeVarint((buf)[o:]) |
| 287 | fieldNum := int32(tag >> 3) |
| 288 | wireType := int(tag & 0x7) |
| 289 | if o+n > len(buf) { |
| 290 | return nil, fmt.Errorf("unable to decode extension") |
| 291 | } |
| 292 | l, err := size((buf)[o+n:], wireType) |
| 293 | if err != nil { |
| 294 | return nil, err |
| 295 | } |
| 296 | if int32(fieldNum) == extension.Field { |
| 297 | if o+n+l > len(buf) { |
| 298 | return nil, fmt.Errorf("unable to decode extension") |
| 299 | } |
| 300 | v, err := decodeExtension((buf)[o:o+n+l], extension) |
| 301 | if err != nil { |
| 302 | return nil, err |
| 303 | } |
| 304 | return v, nil |
| 305 | } |
| 306 | o += n + l |
| 307 | } |
| 308 | return defaultExtensionValue(extension) |
| 309 | } |
| 310 | |
| 311 | func (this *Extension) Encode() error { |
| 312 | if this.enc == nil { |
| 313 | var err error |
| 314 | this.enc, err = encodeExtension(this.desc, this.value) |
| 315 | if err != nil { |
| 316 | return err |
| 317 | } |
| 318 | } |
| 319 | return nil |
| 320 | } |
| 321 | |
| 322 | func (this Extension) GoString() string { |
| 323 | if err := this.Encode(); err != nil { |
| 324 | return fmt.Sprintf("error encoding extension: %v", err) |
| 325 | } |
| 326 | return fmt.Sprintf("proto.NewExtension(%#v)", this.enc) |
| 327 | } |
| 328 | |
| 329 | func SetUnsafeExtension(pb Message, fieldNum int32, value interface{}) error { |
| 330 | typ := reflect.TypeOf(pb).Elem() |
| 331 | ext, ok := extensionMaps[typ] |
| 332 | if !ok { |
| 333 | return fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) |
| 334 | } |
| 335 | desc, ok := ext[fieldNum] |
| 336 | if !ok { |
| 337 | return errors.New("proto: bad extension number; not in declared ranges") |
| 338 | } |
| 339 | return SetExtension(pb, desc, value) |
| 340 | } |
| 341 | |
| 342 | func GetUnsafeExtension(pb Message, fieldNum int32) (interface{}, error) { |
| 343 | typ := reflect.TypeOf(pb).Elem() |
| 344 | ext, ok := extensionMaps[typ] |
| 345 | if !ok { |
| 346 | return nil, fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) |
| 347 | } |
| 348 | desc, ok := ext[fieldNum] |
| 349 | if !ok { |
| 350 | return nil, fmt.Errorf("unregistered field number %d", fieldNum) |
| 351 | } |
| 352 | return GetExtension(pb, desc) |
| 353 | } |
| 354 | |
| 355 | func NewUnsafeXXX_InternalExtensions(m map[int32]Extension) XXX_InternalExtensions { |
| 356 | x := &XXX_InternalExtensions{ |
| 357 | p: new(struct { |
| 358 | mu sync.Mutex |
| 359 | extensionMap map[int32]Extension |
| 360 | }), |
| 361 | } |
| 362 | x.p.extensionMap = m |
| 363 | return *x |
| 364 | } |
| 365 | |
| 366 | func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension { |
| 367 | pb := extendable.(extendableProto) |
| 368 | return pb.extensionsWrite() |
| 369 | } |
| 370 | |
| 371 | func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int { |
| 372 | ext := pb.GetExtensions() |
| 373 | for offset < len(*ext) { |
| 374 | tag, n1 := DecodeVarint((*ext)[offset:]) |
| 375 | fieldNum := int32(tag >> 3) |
| 376 | wireType := int(tag & 0x7) |
| 377 | n2, err := size((*ext)[offset+n1:], wireType) |
| 378 | if err != nil { |
| 379 | panic(err) |
| 380 | } |
| 381 | newOffset := offset + n1 + n2 |
| 382 | if fieldNum == theFieldNum { |
| 383 | *ext = append((*ext)[:offset], (*ext)[newOffset:]...) |
| 384 | return offset |
| 385 | } |
| 386 | offset = newOffset |
| 387 | } |
| 388 | return -1 |
| 389 | } |