David K. Bainbridge | 215e024 | 2017-09-05 23:18:24 -0700 | [diff] [blame] | 1 | // Protocol Buffers for Go with Gadgets |
| 2 | // |
| 3 | // Copyright (c) 2013, The GoGo Authors. All rights reserved. |
| 4 | // http://github.com/gogo/protobuf |
| 5 | // |
| 6 | // Redistribution and use in source and binary forms, with or without |
| 7 | // modification, are permitted provided that the following conditions are |
| 8 | // met: |
| 9 | // |
| 10 | // * Redistributions of source code must retain the above copyright |
| 11 | // notice, this list of conditions and the following disclaimer. |
| 12 | // * Redistributions in binary form must reproduce the above |
| 13 | // copyright notice, this list of conditions and the following disclaimer |
| 14 | // in the documentation and/or other materials provided with the |
| 15 | // distribution. |
| 16 | // |
| 17 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 18 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 19 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 20 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 21 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 22 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 23 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 24 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 25 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 26 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 27 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 28 | |
| 29 | package proto |
| 30 | |
| 31 | import ( |
| 32 | "bytes" |
| 33 | "errors" |
| 34 | "fmt" |
| 35 | "reflect" |
| 36 | "sort" |
| 37 | "strings" |
| 38 | "sync" |
| 39 | ) |
| 40 | |
| 41 | func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool { |
| 42 | if reflect.ValueOf(pb).IsNil() { |
| 43 | return ifnotset |
| 44 | } |
| 45 | value, err := GetExtension(pb, extension) |
| 46 | if err != nil { |
| 47 | return ifnotset |
| 48 | } |
| 49 | if value == nil { |
| 50 | return ifnotset |
| 51 | } |
| 52 | if value.(*bool) == nil { |
| 53 | return ifnotset |
| 54 | } |
| 55 | return *(value.(*bool)) |
| 56 | } |
| 57 | |
| 58 | func (this *Extension) Equal(that *Extension) bool { |
| 59 | return bytes.Equal(this.enc, that.enc) |
| 60 | } |
| 61 | |
| 62 | func (this *Extension) Compare(that *Extension) int { |
| 63 | return bytes.Compare(this.enc, that.enc) |
| 64 | } |
| 65 | |
| 66 | func SizeOfInternalExtension(m extendableProto) (n int) { |
| 67 | return SizeOfExtensionMap(m.extensionsWrite()) |
| 68 | } |
| 69 | |
| 70 | func SizeOfExtensionMap(m map[int32]Extension) (n int) { |
| 71 | return extensionsMapSize(m) |
| 72 | } |
| 73 | |
| 74 | type sortableMapElem struct { |
| 75 | field int32 |
| 76 | ext Extension |
| 77 | } |
| 78 | |
| 79 | func newSortableExtensionsFromMap(m map[int32]Extension) sortableExtensions { |
| 80 | s := make(sortableExtensions, 0, len(m)) |
| 81 | for k, v := range m { |
| 82 | s = append(s, &sortableMapElem{field: k, ext: v}) |
| 83 | } |
| 84 | return s |
| 85 | } |
| 86 | |
| 87 | type sortableExtensions []*sortableMapElem |
| 88 | |
| 89 | func (this sortableExtensions) Len() int { return len(this) } |
| 90 | |
| 91 | func (this sortableExtensions) Swap(i, j int) { this[i], this[j] = this[j], this[i] } |
| 92 | |
| 93 | func (this sortableExtensions) Less(i, j int) bool { return this[i].field < this[j].field } |
| 94 | |
| 95 | func (this sortableExtensions) String() string { |
| 96 | sort.Sort(this) |
| 97 | ss := make([]string, len(this)) |
| 98 | for i := range this { |
| 99 | ss[i] = fmt.Sprintf("%d: %v", this[i].field, this[i].ext) |
| 100 | } |
| 101 | return "map[" + strings.Join(ss, ",") + "]" |
| 102 | } |
| 103 | |
| 104 | func StringFromInternalExtension(m extendableProto) string { |
| 105 | return StringFromExtensionsMap(m.extensionsWrite()) |
| 106 | } |
| 107 | |
| 108 | func StringFromExtensionsMap(m map[int32]Extension) string { |
| 109 | return newSortableExtensionsFromMap(m).String() |
| 110 | } |
| 111 | |
| 112 | func StringFromExtensionsBytes(ext []byte) string { |
| 113 | m, err := BytesToExtensionsMap(ext) |
| 114 | if err != nil { |
| 115 | panic(err) |
| 116 | } |
| 117 | return StringFromExtensionsMap(m) |
| 118 | } |
| 119 | |
| 120 | func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error) { |
| 121 | return EncodeExtensionMap(m.extensionsWrite(), data) |
| 122 | } |
| 123 | |
| 124 | func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) { |
| 125 | if err := encodeExtensionsMap(m); err != nil { |
| 126 | return 0, err |
| 127 | } |
| 128 | keys := make([]int, 0, len(m)) |
| 129 | for k := range m { |
| 130 | keys = append(keys, int(k)) |
| 131 | } |
| 132 | sort.Ints(keys) |
| 133 | for _, k := range keys { |
| 134 | n += copy(data[n:], m[int32(k)].enc) |
| 135 | } |
| 136 | return n, nil |
| 137 | } |
| 138 | |
| 139 | func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) { |
| 140 | if m[id].value == nil || m[id].desc == nil { |
| 141 | return m[id].enc, nil |
| 142 | } |
| 143 | if err := encodeExtensionsMap(m); err != nil { |
| 144 | return nil, err |
| 145 | } |
| 146 | return m[id].enc, nil |
| 147 | } |
| 148 | |
| 149 | func size(buf []byte, wire int) (int, error) { |
| 150 | switch wire { |
| 151 | case WireVarint: |
| 152 | _, n := DecodeVarint(buf) |
| 153 | return n, nil |
| 154 | case WireFixed64: |
| 155 | return 8, nil |
| 156 | case WireBytes: |
| 157 | v, n := DecodeVarint(buf) |
| 158 | return int(v) + n, nil |
| 159 | case WireFixed32: |
| 160 | return 4, nil |
| 161 | case WireStartGroup: |
| 162 | offset := 0 |
| 163 | for { |
| 164 | u, n := DecodeVarint(buf[offset:]) |
| 165 | fwire := int(u & 0x7) |
| 166 | offset += n |
| 167 | if fwire == WireEndGroup { |
| 168 | return offset, nil |
| 169 | } |
| 170 | s, err := size(buf[offset:], wire) |
| 171 | if err != nil { |
| 172 | return 0, err |
| 173 | } |
| 174 | offset += s |
| 175 | } |
| 176 | } |
| 177 | return 0, fmt.Errorf("proto: can't get size for unknown wire type %d", wire) |
| 178 | } |
| 179 | |
| 180 | func BytesToExtensionsMap(buf []byte) (map[int32]Extension, error) { |
| 181 | m := make(map[int32]Extension) |
| 182 | i := 0 |
| 183 | for i < len(buf) { |
| 184 | tag, n := DecodeVarint(buf[i:]) |
| 185 | if n <= 0 { |
| 186 | return nil, fmt.Errorf("unable to decode varint") |
| 187 | } |
| 188 | fieldNum := int32(tag >> 3) |
| 189 | wireType := int(tag & 0x7) |
| 190 | l, err := size(buf[i+n:], wireType) |
| 191 | if err != nil { |
| 192 | return nil, err |
| 193 | } |
| 194 | end := i + int(l) + n |
| 195 | m[int32(fieldNum)] = Extension{enc: buf[i:end]} |
| 196 | i = end |
| 197 | } |
| 198 | return m, nil |
| 199 | } |
| 200 | |
| 201 | func NewExtension(e []byte) Extension { |
| 202 | ee := Extension{enc: make([]byte, len(e))} |
| 203 | copy(ee.enc, e) |
| 204 | return ee |
| 205 | } |
| 206 | |
| 207 | func AppendExtension(e Message, tag int32, buf []byte) { |
| 208 | if ee, eok := e.(extensionsBytes); eok { |
| 209 | ext := ee.GetExtensions() |
| 210 | *ext = append(*ext, buf...) |
| 211 | return |
| 212 | } |
| 213 | if ee, eok := e.(extendableProto); eok { |
| 214 | m := ee.extensionsWrite() |
| 215 | ext := m[int32(tag)] // may be missing |
| 216 | ext.enc = append(ext.enc, buf...) |
| 217 | m[int32(tag)] = ext |
| 218 | } |
| 219 | } |
| 220 | |
| 221 | func encodeExtension(e *Extension) error { |
| 222 | if e.value == nil || e.desc == nil { |
| 223 | // Extension is only in its encoded form. |
| 224 | return nil |
| 225 | } |
| 226 | // We don't skip extensions that have an encoded form set, |
| 227 | // because the extension value may have been mutated after |
| 228 | // the last time this function was called. |
| 229 | |
| 230 | et := reflect.TypeOf(e.desc.ExtensionType) |
| 231 | props := extensionProperties(e.desc) |
| 232 | |
| 233 | p := NewBuffer(nil) |
| 234 | // If e.value has type T, the encoder expects a *struct{ X T }. |
| 235 | // Pass a *T with a zero field and hope it all works out. |
| 236 | x := reflect.New(et) |
| 237 | x.Elem().Set(reflect.ValueOf(e.value)) |
| 238 | if err := props.enc(p, props, toStructPointer(x)); err != nil { |
| 239 | return err |
| 240 | } |
| 241 | e.enc = p.buf |
| 242 | return nil |
| 243 | } |
| 244 | |
| 245 | func (this Extension) GoString() string { |
| 246 | if this.enc == nil { |
| 247 | if err := encodeExtension(&this); err != nil { |
| 248 | panic(err) |
| 249 | } |
| 250 | } |
| 251 | return fmt.Sprintf("proto.NewExtension(%#v)", this.enc) |
| 252 | } |
| 253 | |
| 254 | func SetUnsafeExtension(pb Message, fieldNum int32, value interface{}) error { |
| 255 | typ := reflect.TypeOf(pb).Elem() |
| 256 | ext, ok := extensionMaps[typ] |
| 257 | if !ok { |
| 258 | return fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) |
| 259 | } |
| 260 | desc, ok := ext[fieldNum] |
| 261 | if !ok { |
| 262 | return errors.New("proto: bad extension number; not in declared ranges") |
| 263 | } |
| 264 | return SetExtension(pb, desc, value) |
| 265 | } |
| 266 | |
| 267 | func GetUnsafeExtension(pb Message, fieldNum int32) (interface{}, error) { |
| 268 | typ := reflect.TypeOf(pb).Elem() |
| 269 | ext, ok := extensionMaps[typ] |
| 270 | if !ok { |
| 271 | return nil, fmt.Errorf("proto: bad extended type; %s is not extendable", typ.String()) |
| 272 | } |
| 273 | desc, ok := ext[fieldNum] |
| 274 | if !ok { |
| 275 | return nil, fmt.Errorf("unregistered field number %d", fieldNum) |
| 276 | } |
| 277 | return GetExtension(pb, desc) |
| 278 | } |
| 279 | |
| 280 | func NewUnsafeXXX_InternalExtensions(m map[int32]Extension) XXX_InternalExtensions { |
| 281 | x := &XXX_InternalExtensions{ |
| 282 | p: new(struct { |
| 283 | mu sync.Mutex |
| 284 | extensionMap map[int32]Extension |
| 285 | }), |
| 286 | } |
| 287 | x.p.extensionMap = m |
| 288 | return *x |
| 289 | } |
| 290 | |
| 291 | func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension { |
| 292 | pb := extendable.(extendableProto) |
| 293 | return pb.extensionsWrite() |
| 294 | } |