Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 1 | // Copyright 2010 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 4 | |
| 5 | package proto |
| 6 | |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 7 | import ( |
| 8 | "errors" |
| 9 | "fmt" |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 10 | "reflect" |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 11 | |
| 12 | "google.golang.org/protobuf/encoding/protowire" |
| 13 | "google.golang.org/protobuf/proto" |
| 14 | "google.golang.org/protobuf/reflect/protoreflect" |
| 15 | "google.golang.org/protobuf/reflect/protoregistry" |
| 16 | "google.golang.org/protobuf/runtime/protoiface" |
| 17 | "google.golang.org/protobuf/runtime/protoimpl" |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 18 | ) |
| 19 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 20 | type ( |
| 21 | // ExtensionDesc represents an extension descriptor and |
| 22 | // is used to interact with an extension field in a message. |
| 23 | // |
| 24 | // Variables of this type are generated in code by protoc-gen-go. |
| 25 | ExtensionDesc = protoimpl.ExtensionInfo |
| 26 | |
| 27 | // ExtensionRange represents a range of message extensions. |
| 28 | // Used in code generated by protoc-gen-go. |
| 29 | ExtensionRange = protoiface.ExtensionRangeV1 |
| 30 | |
| 31 | // Deprecated: Do not use; this is an internal type. |
| 32 | Extension = protoimpl.ExtensionFieldV1 |
| 33 | |
| 34 | // Deprecated: Do not use; this is an internal type. |
| 35 | XXX_InternalExtensions = protoimpl.ExtensionFields |
| 36 | ) |
| 37 | |
| 38 | // ErrMissingExtension reports whether the extension was not present. |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 39 | var ErrMissingExtension = errors.New("proto: missing extension") |
| 40 | |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 41 | var errNotExtendable = errors.New("proto: not an extendable proto.Message") |
| 42 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 43 | // HasExtension reports whether the extension field is present in m |
| 44 | // either as an explicitly populated field or as an unknown field. |
| 45 | func HasExtension(m Message, xt *ExtensionDesc) (has bool) { |
| 46 | mr := MessageReflect(m) |
| 47 | if mr == nil || !mr.IsValid() { |
| 48 | return false |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 49 | } |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 50 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 51 | // Check whether any populated known field matches the field number. |
| 52 | xtd := xt.TypeDescriptor() |
| 53 | if isValidExtension(mr.Descriptor(), xtd) { |
| 54 | has = mr.Has(xtd) |
| 55 | } else { |
| 56 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { |
| 57 | has = int32(fd.Number()) == xt.Field |
| 58 | return !has |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 59 | }) |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 60 | } |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 61 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 62 | // Check whether any unknown field matches the field number. |
| 63 | for b := mr.GetUnknown(); !has && len(b) > 0; { |
| 64 | num, _, n := protowire.ConsumeField(b) |
| 65 | has = int32(num) == xt.Field |
| 66 | b = b[n:] |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 67 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 68 | return has |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 69 | } |
| 70 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 71 | // ClearExtension removes the extension field from m |
| 72 | // either as an explicitly populated field or as an unknown field. |
| 73 | func ClearExtension(m Message, xt *ExtensionDesc) { |
| 74 | mr := MessageReflect(m) |
| 75 | if mr == nil || !mr.IsValid() { |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 76 | return |
| 77 | } |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 78 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 79 | xtd := xt.TypeDescriptor() |
| 80 | if isValidExtension(mr.Descriptor(), xtd) { |
| 81 | mr.Clear(xtd) |
| 82 | } else { |
| 83 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { |
| 84 | if int32(fd.Number()) == xt.Field { |
| 85 | mr.Clear(fd) |
| 86 | return false |
| 87 | } |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 88 | return true |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 89 | }) |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 90 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 91 | clearUnknown(mr, fieldNum(xt.Field)) |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 92 | } |
| 93 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 94 | // ClearAllExtensions clears all extensions from m. |
| 95 | // This includes populated fields and unknown fields in the extension range. |
| 96 | func ClearAllExtensions(m Message) { |
| 97 | mr := MessageReflect(m) |
| 98 | if mr == nil || !mr.IsValid() { |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 99 | return |
| 100 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 101 | |
| 102 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { |
| 103 | if fd.IsExtension() { |
| 104 | mr.Clear(fd) |
| 105 | } |
| 106 | return true |
| 107 | }) |
| 108 | clearUnknown(mr, mr.Descriptor().ExtensionRanges()) |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 109 | } |
| 110 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 111 | // GetExtension retrieves a proto2 extended field from m. |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 112 | // |
| 113 | // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), |
| 114 | // then GetExtension parses the encoded field and returns a Go value of the specified type. |
| 115 | // If the field is not present, then the default value is returned (if one is specified), |
| 116 | // otherwise ErrMissingExtension is reported. |
| 117 | // |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 118 | // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil), |
| 119 | // then GetExtension returns the raw encoded bytes for the extension field. |
| 120 | func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) { |
| 121 | mr := MessageReflect(m) |
| 122 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { |
| 123 | return nil, errNotExtendable |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 124 | } |
| 125 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 126 | // Retrieve the unknown fields for this extension field. |
| 127 | var bo protoreflect.RawFields |
| 128 | for bi := mr.GetUnknown(); len(bi) > 0; { |
| 129 | num, _, n := protowire.ConsumeField(bi) |
| 130 | if int32(num) == xt.Field { |
| 131 | bo = append(bo, bi[:n]...) |
| 132 | } |
| 133 | bi = bi[n:] |
| 134 | } |
| 135 | |
| 136 | // For type incomplete descriptors, only retrieve the unknown fields. |
| 137 | if xt.ExtensionType == nil { |
| 138 | return []byte(bo), nil |
| 139 | } |
| 140 | |
| 141 | // If the extension field only exists as unknown fields, unmarshal it. |
| 142 | // This is rarely done since proto.Unmarshal eagerly unmarshals extensions. |
| 143 | xtd := xt.TypeDescriptor() |
| 144 | if !isValidExtension(mr.Descriptor(), xtd) { |
| 145 | return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m) |
| 146 | } |
| 147 | if !mr.Has(xtd) && len(bo) > 0 { |
| 148 | m2 := mr.New() |
| 149 | if err := (proto.UnmarshalOptions{ |
| 150 | Resolver: extensionResolver{xt}, |
| 151 | }.Unmarshal(bo, m2.Interface())); err != nil { |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 152 | return nil, err |
| 153 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 154 | if m2.Has(xtd) { |
| 155 | mr.Set(xtd, m2.Get(xtd)) |
| 156 | clearUnknown(mr, fieldNum(xt.Field)) |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 157 | } |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 158 | } |
| 159 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 160 | // Check whether the message has the extension field set or a default. |
| 161 | var pv protoreflect.Value |
| 162 | switch { |
| 163 | case mr.Has(xtd): |
| 164 | pv = mr.Get(xtd) |
| 165 | case xtd.HasDefault(): |
| 166 | pv = xtd.Default() |
| 167 | default: |
khenaidoo | ac63710 | 2019-01-14 15:44:34 -0500 | [diff] [blame] | 168 | return nil, ErrMissingExtension |
| 169 | } |
| 170 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 171 | v := xt.InterfaceOf(pv) |
| 172 | rv := reflect.ValueOf(v) |
| 173 | if isScalarKind(rv.Kind()) { |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 174 | rv2 := reflect.New(rv.Type()) |
| 175 | rv2.Elem().Set(rv) |
| 176 | v = rv2.Interface() |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 177 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 178 | return v, nil |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 179 | } |
| 180 | |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 181 | // extensionResolver is a custom extension resolver that stores a single |
| 182 | // extension type that takes precedence over the global registry. |
| 183 | type extensionResolver struct{ xt protoreflect.ExtensionType } |
| 184 | |
| 185 | func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { |
| 186 | if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field { |
| 187 | return r.xt, nil |
| 188 | } |
| 189 | return protoregistry.GlobalTypes.FindExtensionByName(field) |
| 190 | } |
| 191 | |
| 192 | func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { |
| 193 | if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field { |
| 194 | return r.xt, nil |
| 195 | } |
| 196 | return protoregistry.GlobalTypes.FindExtensionByNumber(message, field) |
| 197 | } |
| 198 | |
| 199 | // GetExtensions returns a list of the extensions values present in m, |
| 200 | // corresponding with the provided list of extension descriptors, xts. |
| 201 | // If an extension is missing in m, the corresponding value is nil. |
| 202 | func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) { |
| 203 | mr := MessageReflect(m) |
| 204 | if mr == nil || !mr.IsValid() { |
| 205 | return nil, errNotExtendable |
| 206 | } |
| 207 | |
| 208 | vs := make([]interface{}, len(xts)) |
| 209 | for i, xt := range xts { |
| 210 | v, err := GetExtension(m, xt) |
| 211 | if err != nil { |
| 212 | if err == ErrMissingExtension { |
| 213 | continue |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 214 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 215 | return vs, err |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 216 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 217 | vs[i] = v |
| 218 | } |
| 219 | return vs, nil |
| 220 | } |
| 221 | |
| 222 | // SetExtension sets an extension field in m to the provided value. |
| 223 | func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error { |
| 224 | mr := MessageReflect(m) |
| 225 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { |
| 226 | return errNotExtendable |
| 227 | } |
| 228 | |
| 229 | rv := reflect.ValueOf(v) |
| 230 | if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) { |
| 231 | return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType) |
| 232 | } |
| 233 | if rv.Kind() == reflect.Ptr { |
| 234 | if rv.IsNil() { |
| 235 | return fmt.Errorf("proto: SetExtension called with nil value of type %T", v) |
| 236 | } |
| 237 | if isScalarKind(rv.Elem().Kind()) { |
| 238 | v = rv.Elem().Interface() |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 239 | } |
| 240 | } |
Andrea Campanella | 3614a92 | 2021-02-25 12:40:42 +0100 | [diff] [blame^] | 241 | |
| 242 | xtd := xt.TypeDescriptor() |
| 243 | if !isValidExtension(mr.Descriptor(), xtd) { |
| 244 | return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m) |
| 245 | } |
| 246 | mr.Set(xtd, xt.ValueOf(v)) |
| 247 | clearUnknown(mr, fieldNum(xt.Field)) |
| 248 | return nil |
| 249 | } |
| 250 | |
| 251 | // SetRawExtension inserts b into the unknown fields of m. |
| 252 | // |
| 253 | // Deprecated: Use Message.ProtoReflect.SetUnknown instead. |
| 254 | func SetRawExtension(m Message, fnum int32, b []byte) { |
| 255 | mr := MessageReflect(m) |
| 256 | if mr == nil || !mr.IsValid() { |
| 257 | return |
| 258 | } |
| 259 | |
| 260 | // Verify that the raw field is valid. |
| 261 | for b0 := b; len(b0) > 0; { |
| 262 | num, _, n := protowire.ConsumeField(b0) |
| 263 | if int32(num) != fnum { |
| 264 | panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum)) |
| 265 | } |
| 266 | b0 = b0[n:] |
| 267 | } |
| 268 | |
| 269 | ClearExtension(m, &ExtensionDesc{Field: fnum}) |
| 270 | mr.SetUnknown(append(mr.GetUnknown(), b...)) |
| 271 | } |
| 272 | |
| 273 | // ExtensionDescs returns a list of extension descriptors found in m, |
| 274 | // containing descriptors for both populated extension fields in m and |
| 275 | // also unknown fields of m that are in the extension range. |
| 276 | // For the later case, an type incomplete descriptor is provided where only |
| 277 | // the ExtensionDesc.Field field is populated. |
| 278 | // The order of the extension descriptors is undefined. |
| 279 | func ExtensionDescs(m Message) ([]*ExtensionDesc, error) { |
| 280 | mr := MessageReflect(m) |
| 281 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { |
| 282 | return nil, errNotExtendable |
| 283 | } |
| 284 | |
| 285 | // Collect a set of known extension descriptors. |
| 286 | extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc) |
| 287 | mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { |
| 288 | if fd.IsExtension() { |
| 289 | xt := fd.(protoreflect.ExtensionTypeDescriptor) |
| 290 | if xd, ok := xt.Type().(*ExtensionDesc); ok { |
| 291 | extDescs[fd.Number()] = xd |
| 292 | } |
| 293 | } |
| 294 | return true |
| 295 | }) |
| 296 | |
| 297 | // Collect a set of unknown extension descriptors. |
| 298 | extRanges := mr.Descriptor().ExtensionRanges() |
| 299 | for b := mr.GetUnknown(); len(b) > 0; { |
| 300 | num, _, n := protowire.ConsumeField(b) |
| 301 | if extRanges.Has(num) && extDescs[num] == nil { |
| 302 | extDescs[num] = nil |
| 303 | } |
| 304 | b = b[n:] |
| 305 | } |
| 306 | |
| 307 | // Transpose the set of descriptors into a list. |
| 308 | var xts []*ExtensionDesc |
| 309 | for num, xt := range extDescs { |
| 310 | if xt == nil { |
| 311 | xt = &ExtensionDesc{Field: int32(num)} |
| 312 | } |
| 313 | xts = append(xts, xt) |
| 314 | } |
| 315 | return xts, nil |
| 316 | } |
| 317 | |
| 318 | // isValidExtension reports whether xtd is a valid extension descriptor for md. |
| 319 | func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool { |
| 320 | return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number()) |
| 321 | } |
| 322 | |
// isScalarKind reports whether k is the Go kind of a protobuf scalar other
// than bytes. It exists for historical reasons: v1 represents scalar
// extension values as *T while v2 uses T, so callers use it to decide when
// to add or strip a pointer.
func isScalarKind(k reflect.Kind) bool {
	return k == reflect.Bool || k == reflect.String ||
		k == reflect.Int32 || k == reflect.Int64 ||
		k == reflect.Uint32 || k == reflect.Uint64 ||
		k == reflect.Float32 || k == reflect.Float64
}
| 334 | |
| 335 | // clearUnknown removes unknown fields from m where remover.Has reports true. |
| 336 | func clearUnknown(m protoreflect.Message, remover interface { |
| 337 | Has(protoreflect.FieldNumber) bool |
| 338 | }) { |
| 339 | var bo protoreflect.RawFields |
| 340 | for bi := m.GetUnknown(); len(bi) > 0; { |
| 341 | num, _, n := protowire.ConsumeField(bi) |
| 342 | if !remover.Has(num) { |
| 343 | bo = append(bo, bi[:n]...) |
| 344 | } |
| 345 | bi = bi[n:] |
| 346 | } |
| 347 | if bi := m.GetUnknown(); len(bi) != len(bo) { |
| 348 | m.SetUnknown(bo) |
| 349 | } |
| 350 | } |
| 351 | |
| 352 | type fieldNum protoreflect.FieldNumber |
| 353 | |
| 354 | func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool { |
| 355 | return protoreflect.FieldNumber(n1) == n2 |
William Kurkian | daa6bb2 | 2019-03-07 12:26:28 -0500 | [diff] [blame] | 356 | } |