blob: 42fc120c972b8399411cf9e5e3bc73644d166525 [file] [log] [blame]
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
khenaidooac637102019-01-14 15:44:34 -05004
5package proto
6
khenaidooac637102019-01-14 15:44:34 -05007import (
8 "errors"
9 "fmt"
khenaidooac637102019-01-14 15:44:34 -050010 "reflect"
Andrea Campanella3614a922021-02-25 12:40:42 +010011
12 "google.golang.org/protobuf/encoding/protowire"
13 "google.golang.org/protobuf/proto"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 "google.golang.org/protobuf/reflect/protoregistry"
16 "google.golang.org/protobuf/runtime/protoiface"
17 "google.golang.org/protobuf/runtime/protoimpl"
khenaidooac637102019-01-14 15:44:34 -050018)
19
Andrea Campanella3614a922021-02-25 12:40:42 +010020type (
21 // ExtensionDesc represents an extension descriptor and
22 // is used to interact with an extension field in a message.
23 //
24 // Variables of this type are generated in code by protoc-gen-go.
25 ExtensionDesc = protoimpl.ExtensionInfo
26
27 // ExtensionRange represents a range of message extensions.
28 // Used in code generated by protoc-gen-go.
29 ExtensionRange = protoiface.ExtensionRangeV1
30
31 // Deprecated: Do not use; this is an internal type.
32 Extension = protoimpl.ExtensionFieldV1
33
34 // Deprecated: Do not use; this is an internal type.
35 XXX_InternalExtensions = protoimpl.ExtensionFields
36)
37
var (
	// ErrMissingExtension is the error reported by GetExtension when the
	// requested extension field is neither set in the message nor has a
	// default value.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is returned when a message is nil or invalid, or
	// (for most callers) declares no extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
42
Andrea Campanella3614a922021-02-25 12:40:42 +010043// HasExtension reports whether the extension field is present in m
44// either as an explicitly populated field or as an unknown field.
45func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
46 mr := MessageReflect(m)
47 if mr == nil || !mr.IsValid() {
48 return false
khenaidooac637102019-01-14 15:44:34 -050049 }
khenaidooac637102019-01-14 15:44:34 -050050
Andrea Campanella3614a922021-02-25 12:40:42 +010051 // Check whether any populated known field matches the field number.
52 xtd := xt.TypeDescriptor()
53 if isValidExtension(mr.Descriptor(), xtd) {
54 has = mr.Has(xtd)
55 } else {
56 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
57 has = int32(fd.Number()) == xt.Field
58 return !has
khenaidooac637102019-01-14 15:44:34 -050059 })
khenaidooac637102019-01-14 15:44:34 -050060 }
khenaidooac637102019-01-14 15:44:34 -050061
Andrea Campanella3614a922021-02-25 12:40:42 +010062 // Check whether any unknown field matches the field number.
63 for b := mr.GetUnknown(); !has && len(b) > 0; {
64 num, _, n := protowire.ConsumeField(b)
65 has = int32(num) == xt.Field
66 b = b[n:]
khenaidooac637102019-01-14 15:44:34 -050067 }
Andrea Campanella3614a922021-02-25 12:40:42 +010068 return has
khenaidooac637102019-01-14 15:44:34 -050069}
70
Andrea Campanella3614a922021-02-25 12:40:42 +010071// ClearExtension removes the extension field from m
72// either as an explicitly populated field or as an unknown field.
73func ClearExtension(m Message, xt *ExtensionDesc) {
74 mr := MessageReflect(m)
75 if mr == nil || !mr.IsValid() {
khenaidooac637102019-01-14 15:44:34 -050076 return
77 }
khenaidooac637102019-01-14 15:44:34 -050078
Andrea Campanella3614a922021-02-25 12:40:42 +010079 xtd := xt.TypeDescriptor()
80 if isValidExtension(mr.Descriptor(), xtd) {
81 mr.Clear(xtd)
82 } else {
83 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
84 if int32(fd.Number()) == xt.Field {
85 mr.Clear(fd)
86 return false
87 }
khenaidooac637102019-01-14 15:44:34 -050088 return true
Andrea Campanella3614a922021-02-25 12:40:42 +010089 })
khenaidooac637102019-01-14 15:44:34 -050090 }
Andrea Campanella3614a922021-02-25 12:40:42 +010091 clearUnknown(mr, fieldNum(xt.Field))
khenaidooac637102019-01-14 15:44:34 -050092}
93
Andrea Campanella3614a922021-02-25 12:40:42 +010094// ClearAllExtensions clears all extensions from m.
95// This includes populated fields and unknown fields in the extension range.
96func ClearAllExtensions(m Message) {
97 mr := MessageReflect(m)
98 if mr == nil || !mr.IsValid() {
khenaidooac637102019-01-14 15:44:34 -050099 return
100 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100101
102 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
103 if fd.IsExtension() {
104 mr.Clear(fd)
105 }
106 return true
107 })
108 clearUnknown(mr, mr.Descriptor().ExtensionRanges())
khenaidooac637102019-01-14 15:44:34 -0500109}
110
Andrea Campanella3614a922021-02-25 12:40:42 +0100111// GetExtension retrieves a proto2 extended field from m.
khenaidooac637102019-01-14 15:44:34 -0500112//
113// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
114// then GetExtension parses the encoded field and returns a Go value of the specified type.
115// If the field is not present, then the default value is returned (if one is specified),
116// otherwise ErrMissingExtension is reported.
117//
Andrea Campanella3614a922021-02-25 12:40:42 +0100118// If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
119// then GetExtension returns the raw encoded bytes for the extension field.
120func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
121 mr := MessageReflect(m)
122 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
123 return nil, errNotExtendable
khenaidooac637102019-01-14 15:44:34 -0500124 }
125
Andrea Campanella3614a922021-02-25 12:40:42 +0100126 // Retrieve the unknown fields for this extension field.
127 var bo protoreflect.RawFields
128 for bi := mr.GetUnknown(); len(bi) > 0; {
129 num, _, n := protowire.ConsumeField(bi)
130 if int32(num) == xt.Field {
131 bo = append(bo, bi[:n]...)
132 }
133 bi = bi[n:]
134 }
135
136 // For type incomplete descriptors, only retrieve the unknown fields.
137 if xt.ExtensionType == nil {
138 return []byte(bo), nil
139 }
140
141 // If the extension field only exists as unknown fields, unmarshal it.
142 // This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
143 xtd := xt.TypeDescriptor()
144 if !isValidExtension(mr.Descriptor(), xtd) {
145 return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
146 }
147 if !mr.Has(xtd) && len(bo) > 0 {
148 m2 := mr.New()
149 if err := (proto.UnmarshalOptions{
150 Resolver: extensionResolver{xt},
151 }.Unmarshal(bo, m2.Interface())); err != nil {
khenaidooac637102019-01-14 15:44:34 -0500152 return nil, err
153 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100154 if m2.Has(xtd) {
155 mr.Set(xtd, m2.Get(xtd))
156 clearUnknown(mr, fieldNum(xt.Field))
khenaidooac637102019-01-14 15:44:34 -0500157 }
khenaidooac637102019-01-14 15:44:34 -0500158 }
159
Andrea Campanella3614a922021-02-25 12:40:42 +0100160 // Check whether the message has the extension field set or a default.
161 var pv protoreflect.Value
162 switch {
163 case mr.Has(xtd):
164 pv = mr.Get(xtd)
165 case xtd.HasDefault():
166 pv = xtd.Default()
167 default:
khenaidooac637102019-01-14 15:44:34 -0500168 return nil, ErrMissingExtension
169 }
170
Andrea Campanella3614a922021-02-25 12:40:42 +0100171 v := xt.InterfaceOf(pv)
172 rv := reflect.ValueOf(v)
173 if isScalarKind(rv.Kind()) {
William Kurkiandaa6bb22019-03-07 12:26:28 -0500174 rv2 := reflect.New(rv.Type())
175 rv2.Elem().Set(rv)
176 v = rv2.Interface()
William Kurkiandaa6bb22019-03-07 12:26:28 -0500177 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100178 return v, nil
William Kurkiandaa6bb22019-03-07 12:26:28 -0500179}
180
Andrea Campanella3614a922021-02-25 12:40:42 +0100181// extensionResolver is a custom extension resolver that stores a single
182// extension type that takes precedence over the global registry.
183type extensionResolver struct{ xt protoreflect.ExtensionType }
184
185func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
186 if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
187 return r.xt, nil
188 }
189 return protoregistry.GlobalTypes.FindExtensionByName(field)
190}
191
192func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
193 if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
194 return r.xt, nil
195 }
196 return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
197}
198
199// GetExtensions returns a list of the extensions values present in m,
200// corresponding with the provided list of extension descriptors, xts.
201// If an extension is missing in m, the corresponding value is nil.
202func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
203 mr := MessageReflect(m)
204 if mr == nil || !mr.IsValid() {
205 return nil, errNotExtendable
206 }
207
208 vs := make([]interface{}, len(xts))
209 for i, xt := range xts {
210 v, err := GetExtension(m, xt)
211 if err != nil {
212 if err == ErrMissingExtension {
213 continue
William Kurkiandaa6bb22019-03-07 12:26:28 -0500214 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100215 return vs, err
William Kurkiandaa6bb22019-03-07 12:26:28 -0500216 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100217 vs[i] = v
218 }
219 return vs, nil
220}
221
222// SetExtension sets an extension field in m to the provided value.
223func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
224 mr := MessageReflect(m)
225 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
226 return errNotExtendable
227 }
228
229 rv := reflect.ValueOf(v)
230 if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
231 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
232 }
233 if rv.Kind() == reflect.Ptr {
234 if rv.IsNil() {
235 return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
236 }
237 if isScalarKind(rv.Elem().Kind()) {
238 v = rv.Elem().Interface()
William Kurkiandaa6bb22019-03-07 12:26:28 -0500239 }
240 }
Andrea Campanella3614a922021-02-25 12:40:42 +0100241
242 xtd := xt.TypeDescriptor()
243 if !isValidExtension(mr.Descriptor(), xtd) {
244 return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
245 }
246 mr.Set(xtd, xt.ValueOf(v))
247 clearUnknown(mr, fieldNum(xt.Field))
248 return nil
249}
250
251// SetRawExtension inserts b into the unknown fields of m.
252//
253// Deprecated: Use Message.ProtoReflect.SetUnknown instead.
254func SetRawExtension(m Message, fnum int32, b []byte) {
255 mr := MessageReflect(m)
256 if mr == nil || !mr.IsValid() {
257 return
258 }
259
260 // Verify that the raw field is valid.
261 for b0 := b; len(b0) > 0; {
262 num, _, n := protowire.ConsumeField(b0)
263 if int32(num) != fnum {
264 panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
265 }
266 b0 = b0[n:]
267 }
268
269 ClearExtension(m, &ExtensionDesc{Field: fnum})
270 mr.SetUnknown(append(mr.GetUnknown(), b...))
271}
272
273// ExtensionDescs returns a list of extension descriptors found in m,
274// containing descriptors for both populated extension fields in m and
275// also unknown fields of m that are in the extension range.
276// For the later case, an type incomplete descriptor is provided where only
277// the ExtensionDesc.Field field is populated.
278// The order of the extension descriptors is undefined.
279func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
280 mr := MessageReflect(m)
281 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
282 return nil, errNotExtendable
283 }
284
285 // Collect a set of known extension descriptors.
286 extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
287 mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
288 if fd.IsExtension() {
289 xt := fd.(protoreflect.ExtensionTypeDescriptor)
290 if xd, ok := xt.Type().(*ExtensionDesc); ok {
291 extDescs[fd.Number()] = xd
292 }
293 }
294 return true
295 })
296
297 // Collect a set of unknown extension descriptors.
298 extRanges := mr.Descriptor().ExtensionRanges()
299 for b := mr.GetUnknown(); len(b) > 0; {
300 num, _, n := protowire.ConsumeField(b)
301 if extRanges.Has(num) && extDescs[num] == nil {
302 extDescs[num] = nil
303 }
304 b = b[n:]
305 }
306
307 // Transpose the set of descriptors into a list.
308 var xts []*ExtensionDesc
309 for num, xt := range extDescs {
310 if xt == nil {
311 xt = &ExtensionDesc{Field: int32(num)}
312 }
313 xts = append(xts, xt)
314 }
315 return xts, nil
316}
317
318// isValidExtension reports whether xtd is a valid extension descriptor for md.
319func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
320 return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
321}
322
// isScalarKind reports whether k is a protobuf scalar kind (except bytes).
// It exists for historical reasons: v1 represents scalar extension values as
// *T whereas v2 uses T, so callers use it to decide when to add or strip a
// pointer.
func isScalarKind(k reflect.Kind) bool {
	boolOrString := k == reflect.Bool || k == reflect.String
	integer := k == reflect.Int32 || k == reflect.Int64 || k == reflect.Uint32 || k == reflect.Uint64
	float := k == reflect.Float32 || k == reflect.Float64
	return boolOrString || integer || float
}
334
335// clearUnknown removes unknown fields from m where remover.Has reports true.
336func clearUnknown(m protoreflect.Message, remover interface {
337 Has(protoreflect.FieldNumber) bool
338}) {
339 var bo protoreflect.RawFields
340 for bi := m.GetUnknown(); len(bi) > 0; {
341 num, _, n := protowire.ConsumeField(bi)
342 if !remover.Has(num) {
343 bo = append(bo, bi[:n]...)
344 }
345 bi = bi[n:]
346 }
347 if bi := m.GetUnknown(); len(bi) != len(bo) {
348 m.SetUnknown(bo)
349 }
350}
351
352type fieldNum protoreflect.FieldNumber
353
354func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
355 return protoreflect.FieldNumber(n1) == n2
William Kurkiandaa6bb22019-03-07 12:26:28 -0500356}