blob: 08d35170b66cc6768261b2bac2e422ba490fbc5b [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "sync"
9 "sync/atomic"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/errors"
13 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16type extensionFieldInfo struct {
17 wiretag uint64
18 tagsize int
19 unmarshalNeedsValue bool
20 funcs valueCoderFuncs
21 validation validationInfo
22}
23
// legacyExtensionFieldInfoCache is the table legacyLoadExtensionFieldInfo
// consults for extension types that are not *ExtensionInfo.
var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo
25
26func getExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
27 if xi, ok := xt.(*ExtensionInfo); ok {
28 xi.lazyInit()
29 return xi.info
30 }
31 return legacyLoadExtensionFieldInfo(xt)
32}
33
34// legacyLoadExtensionFieldInfo dynamically loads a *ExtensionInfo for xt.
35func legacyLoadExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
36 if xi, ok := legacyExtensionFieldInfoCache.Load(xt); ok {
37 return xi.(*extensionFieldInfo)
38 }
39 e := makeExtensionFieldInfo(xt.TypeDescriptor())
40 if e, ok := legacyMessageTypeCache.LoadOrStore(xt, e); ok {
41 return e.(*extensionFieldInfo)
42 }
43 return e
44}
45
46func makeExtensionFieldInfo(xd pref.ExtensionDescriptor) *extensionFieldInfo {
47 var wiretag uint64
48 if !xd.IsPacked() {
49 wiretag = protowire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
50 } else {
51 wiretag = protowire.EncodeTag(xd.Number(), protowire.BytesType)
52 }
53 e := &extensionFieldInfo{
54 wiretag: wiretag,
55 tagsize: protowire.SizeVarint(wiretag),
56 funcs: encoderFuncsForValue(xd),
57 }
58 // Does the unmarshal function need a value passed to it?
59 // This is true for composite types, where we pass in a message, list, or map to fill in,
60 // and for enums, where we pass in a prototype value to specify the concrete enum type.
61 switch xd.Kind() {
62 case pref.MessageKind, pref.GroupKind, pref.EnumKind:
63 e.unmarshalNeedsValue = true
64 default:
65 if xd.Cardinality() == pref.Repeated {
66 e.unmarshalNeedsValue = true
67 }
68 }
69 return e
70}
71
72type lazyExtensionValue struct {
73 atomicOnce uint32 // atomically set if value is valid
74 mu sync.Mutex
75 xi *extensionFieldInfo
76 value pref.Value
77 b []byte
78 fn func() pref.Value
79}
80
81type ExtensionField struct {
82 typ pref.ExtensionType
83
84 // value is either the value of GetValue,
85 // or a *lazyExtensionValue that then returns the value of GetValue.
86 value pref.Value
87 lazy *lazyExtensionValue
88}
89
90func (f *ExtensionField) appendLazyBytes(xt pref.ExtensionType, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, b []byte) {
91 if f.lazy == nil {
92 f.lazy = &lazyExtensionValue{xi: xi}
93 }
94 f.typ = xt
95 f.lazy.xi = xi
96 f.lazy.b = protowire.AppendTag(f.lazy.b, num, wtyp)
97 f.lazy.b = append(f.lazy.b, b...)
98}
99
100func (f *ExtensionField) canLazy(xt pref.ExtensionType) bool {
101 if f.typ == nil {
102 return true
103 }
104 if f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
105 return true
106 }
107 return false
108}
109
110func (f *ExtensionField) lazyInit() {
111 f.lazy.mu.Lock()
112 defer f.lazy.mu.Unlock()
113 if atomic.LoadUint32(&f.lazy.atomicOnce) == 1 {
114 return
115 }
116 if f.lazy.xi != nil {
117 b := f.lazy.b
118 val := f.typ.New()
119 for len(b) > 0 {
120 var tag uint64
121 if b[0] < 0x80 {
122 tag = uint64(b[0])
123 b = b[1:]
124 } else if len(b) >= 2 && b[1] < 128 {
125 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
126 b = b[2:]
127 } else {
128 var n int
129 tag, n = protowire.ConsumeVarint(b)
130 if n < 0 {
131 panic(errors.New("bad tag in lazy extension decoding"))
132 }
133 b = b[n:]
134 }
135 num := protowire.Number(tag >> 3)
136 wtyp := protowire.Type(tag & 7)
137 var out unmarshalOutput
138 var err error
139 val, out, err = f.lazy.xi.funcs.unmarshal(b, val, num, wtyp, lazyUnmarshalOptions)
140 if err != nil {
141 panic(errors.New("decode failure in lazy extension decoding: %v", err))
142 }
143 b = b[out.n:]
144 }
145 f.lazy.value = val
146 } else {
147 f.lazy.value = f.lazy.fn()
148 }
149 f.lazy.xi = nil
150 f.lazy.fn = nil
151 f.lazy.b = nil
152 atomic.StoreUint32(&f.lazy.atomicOnce, 1)
153}
154
155// Set sets the type and value of the extension field.
156// This must not be called concurrently.
157func (f *ExtensionField) Set(t pref.ExtensionType, v pref.Value) {
158 f.typ = t
159 f.value = v
160 f.lazy = nil
161}
162
163// SetLazy sets the type and a value that is to be lazily evaluated upon first use.
164// This must not be called concurrently.
165func (f *ExtensionField) SetLazy(t pref.ExtensionType, fn func() pref.Value) {
166 f.typ = t
167 f.lazy = &lazyExtensionValue{fn: fn}
168}
169
170// Value returns the value of the extension field.
171// This may be called concurrently.
172func (f *ExtensionField) Value() pref.Value {
173 if f.lazy != nil {
174 if atomic.LoadUint32(&f.lazy.atomicOnce) == 0 {
175 f.lazyInit()
176 }
177 return f.lazy.value
178 }
179 return f.value
180}
181
182// Type returns the type of the extension field.
183// This may be called concurrently.
184func (f ExtensionField) Type() pref.ExtensionType {
185 return f.typ
186}
187
188// IsSet returns whether the extension field is set.
189// This may be called concurrently.
190func (f ExtensionField) IsSet() bool {
191 return f.typ != nil
192}
193
194// IsLazy reports whether a field is lazily encoded.
195// It is exported for testing.
196func IsLazy(m pref.Message, fd pref.FieldDescriptor) bool {
197 var mi *MessageInfo
198 var p pointer
199 switch m := m.(type) {
200 case *messageState:
201 mi = m.messageInfo()
202 p = m.pointer()
203 case *messageReflectWrapper:
204 mi = m.messageInfo()
205 p = m.pointer()
206 default:
207 return false
208 }
209 xd, ok := fd.(pref.ExtensionTypeDescriptor)
210 if !ok {
211 return false
212 }
213 xt := xd.Type()
214 ext := mi.extensionMap(p)
215 if ext == nil {
216 return false
217 }
218 f, ok := (*ext)[int32(fd.Number())]
219 if !ok {
220 return false
221 }
222 return f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0
223}