Zack Williams | e940c7a | 2019-08-21 14:25:39 -0700 | [diff] [blame] | 1 | package dynamic |
| 2 | |
| 3 | import ( |
| 4 | "fmt" |
| 5 | "reflect" |
| 6 | "sync" |
| 7 | |
| 8 | "github.com/golang/protobuf/proto" |
| 9 | |
| 10 | "github.com/jhump/protoreflect/desc" |
| 11 | ) |
| 12 | |
| 13 | // ExtensionRegistry is a registry of known extension fields. This is used to parse |
| 14 | // extension fields encountered when de-serializing a dynamic message. |
| 15 | type ExtensionRegistry struct { |
| 16 | includeDefault bool |
| 17 | mu sync.RWMutex |
| 18 | exts map[string]map[int32]*desc.FieldDescriptor |
| 19 | } |
| 20 | |
| 21 | // NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions, |
| 22 | // which are those that are statically linked into the current program (e.g. registered by |
| 23 | // protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the |
| 24 | // registry will override any default extensions that are for the same extendee and have the |
| 25 | // same tag number and/or name. |
| 26 | func NewExtensionRegistryWithDefaults() *ExtensionRegistry { |
| 27 | return &ExtensionRegistry{includeDefault: true} |
| 28 | } |
| 29 | |
| 30 | // AddExtensionDesc adds the given extensions to the registry. |
| 31 | func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error { |
| 32 | flds := make([]*desc.FieldDescriptor, len(exts)) |
| 33 | for i, ext := range exts { |
| 34 | fd, err := desc.LoadFieldDescriptorForExtension(ext) |
| 35 | if err != nil { |
| 36 | return err |
| 37 | } |
| 38 | flds[i] = fd |
| 39 | } |
| 40 | r.mu.Lock() |
| 41 | defer r.mu.Unlock() |
| 42 | if r.exts == nil { |
| 43 | r.exts = map[string]map[int32]*desc.FieldDescriptor{} |
| 44 | } |
| 45 | for _, fd := range flds { |
| 46 | r.putExtensionLocked(fd) |
| 47 | } |
| 48 | return nil |
| 49 | } |
| 50 | |
| 51 | // AddExtension adds the given extensions to the registry. The given extensions |
| 52 | // will overwrite any previously added extensions that are for the same extendee |
| 53 | // message and same extension tag number. |
| 54 | func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error { |
| 55 | for _, ext := range exts { |
| 56 | if !ext.IsExtension() { |
| 57 | return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName()) |
| 58 | } |
| 59 | } |
| 60 | r.mu.Lock() |
| 61 | defer r.mu.Unlock() |
| 62 | if r.exts == nil { |
| 63 | r.exts = map[string]map[int32]*desc.FieldDescriptor{} |
| 64 | } |
| 65 | for _, ext := range exts { |
| 66 | r.putExtensionLocked(ext) |
| 67 | } |
| 68 | return nil |
| 69 | } |
| 70 | |
| 71 | // AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor. |
| 72 | func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) { |
| 73 | r.mu.Lock() |
| 74 | defer r.mu.Unlock() |
| 75 | r.addExtensionsFromFileLocked(fd, false, nil) |
| 76 | } |
| 77 | |
| 78 | // AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file |
| 79 | // descriptor and also recursively adds all extensions defined in that file's dependencies. This adds |
| 80 | // extensions from the entire transitive closure for the given file. |
| 81 | func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) { |
| 82 | r.mu.Lock() |
| 83 | defer r.mu.Unlock() |
| 84 | already := map[*desc.FileDescriptor]struct{}{} |
| 85 | r.addExtensionsFromFileLocked(fd, true, already) |
| 86 | } |
| 87 | |
| 88 | func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) { |
| 89 | if _, ok := alreadySeen[fd]; ok { |
| 90 | return |
| 91 | } |
| 92 | |
| 93 | if r.exts == nil { |
| 94 | r.exts = map[string]map[int32]*desc.FieldDescriptor{} |
| 95 | } |
| 96 | for _, ext := range fd.GetExtensions() { |
| 97 | r.putExtensionLocked(ext) |
| 98 | } |
| 99 | for _, msg := range fd.GetMessageTypes() { |
| 100 | r.addExtensionsFromMessageLocked(msg) |
| 101 | } |
| 102 | |
| 103 | if recursive { |
| 104 | alreadySeen[fd] = struct{}{} |
| 105 | for _, dep := range fd.GetDependencies() { |
| 106 | r.addExtensionsFromFileLocked(dep, recursive, alreadySeen) |
| 107 | } |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) { |
| 112 | for _, ext := range md.GetNestedExtensions() { |
| 113 | r.putExtensionLocked(ext) |
| 114 | } |
| 115 | for _, msg := range md.GetNestedMessageTypes() { |
| 116 | r.addExtensionsFromMessageLocked(msg) |
| 117 | } |
| 118 | } |
| 119 | |
| 120 | func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) { |
| 121 | msgName := fd.GetOwner().GetFullyQualifiedName() |
| 122 | m := r.exts[msgName] |
| 123 | if m == nil { |
| 124 | m = map[int32]*desc.FieldDescriptor{} |
| 125 | r.exts[msgName] = m |
| 126 | } |
| 127 | m[fd.GetNumber()] = fd |
| 128 | } |
| 129 | |
| 130 | // FindExtension queries for the extension field with the given extendee name (must be a fully-qualified |
| 131 | // message name) and tag number. If no extension is known, nil is returned. |
| 132 | func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor { |
| 133 | if r == nil { |
| 134 | return nil |
| 135 | } |
| 136 | r.mu.RLock() |
| 137 | defer r.mu.RUnlock() |
| 138 | fd := r.exts[messageName][tagNumber] |
| 139 | if fd == nil && r.includeDefault { |
| 140 | ext := getDefaultExtensions(messageName)[tagNumber] |
| 141 | if ext != nil { |
| 142 | fd, _ = desc.LoadFieldDescriptorForExtension(ext) |
| 143 | } |
| 144 | } |
| 145 | return fd |
| 146 | } |
| 147 | |
| 148 | // FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified |
| 149 | // message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil |
| 150 | // is returned. |
| 151 | func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor { |
| 152 | if r == nil { |
| 153 | return nil |
| 154 | } |
| 155 | r.mu.RLock() |
| 156 | defer r.mu.RUnlock() |
| 157 | for _, fd := range r.exts[messageName] { |
| 158 | if fd.GetFullyQualifiedName() == fieldName { |
| 159 | return fd |
| 160 | } |
| 161 | } |
| 162 | if r.includeDefault { |
| 163 | for _, ext := range getDefaultExtensions(messageName) { |
| 164 | fd, _ := desc.LoadFieldDescriptorForExtension(ext) |
| 165 | if fd.GetFullyQualifiedName() == fieldName { |
| 166 | return fd |
| 167 | } |
| 168 | } |
| 169 | } |
| 170 | return nil |
| 171 | } |
| 172 | |
| 173 | // FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified |
| 174 | // message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned. |
| 175 | // The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last |
| 176 | // component uses the field's JSON name (if present). |
| 177 | func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor { |
| 178 | if r == nil { |
| 179 | return nil |
| 180 | } |
| 181 | r.mu.RLock() |
| 182 | defer r.mu.RUnlock() |
| 183 | for _, fd := range r.exts[messageName] { |
| 184 | if fd.GetFullyQualifiedJSONName() == fieldName { |
| 185 | return fd |
| 186 | } |
| 187 | } |
| 188 | if r.includeDefault { |
| 189 | for _, ext := range getDefaultExtensions(messageName) { |
| 190 | fd, _ := desc.LoadFieldDescriptorForExtension(ext) |
| 191 | if fd.GetFullyQualifiedJSONName() == fieldName { |
| 192 | return fd |
| 193 | } |
| 194 | } |
| 195 | } |
| 196 | return nil |
| 197 | } |
| 198 | |
| 199 | func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc { |
| 200 | t := proto.MessageType(messageName) |
| 201 | if t != nil { |
| 202 | msg := reflect.Zero(t).Interface().(proto.Message) |
| 203 | return proto.RegisteredExtensions(msg) |
| 204 | } |
| 205 | return nil |
| 206 | } |
| 207 | |
| 208 | // AllExtensionsForType returns all known extension fields for the given extendee name (must be a |
| 209 | // fully-qualified message name). |
| 210 | func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor { |
| 211 | if r == nil { |
| 212 | return []*desc.FieldDescriptor(nil) |
| 213 | } |
| 214 | r.mu.RLock() |
| 215 | defer r.mu.RUnlock() |
| 216 | flds := r.exts[messageName] |
| 217 | var ret []*desc.FieldDescriptor |
| 218 | if r.includeDefault { |
| 219 | exts := getDefaultExtensions(messageName) |
| 220 | if len(exts) > 0 || len(flds) > 0 { |
| 221 | ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds)) |
| 222 | } |
| 223 | for tag, ext := range exts { |
| 224 | if _, ok := flds[tag]; ok { |
| 225 | // skip default extension and use the one explicitly registered instead |
| 226 | continue |
| 227 | } |
| 228 | fd, _ := desc.LoadFieldDescriptorForExtension(ext) |
| 229 | if fd != nil { |
| 230 | ret = append(ret, fd) |
| 231 | } |
| 232 | } |
| 233 | } else if len(flds) > 0 { |
| 234 | ret = make([]*desc.FieldDescriptor, 0, len(flds)) |
| 235 | } |
| 236 | |
| 237 | for _, ext := range flds { |
| 238 | ret = append(ret, ext) |
| 239 | } |
| 240 | return ret |
| 241 | } |