blob: c1509362de6c6d8622fc1514b1c58c9bc07f8eb1 [file] [log] [blame]
package protoparse
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/internal"
)
type linker struct {
files map[string]*parseResult
descriptorPool map[*dpb.FileDescriptorProto]map[string]proto.Message
extensions map[string]map[int32]string
}
func newLinker(files map[string]*parseResult) *linker {
return &linker{files: files}
}
func (l *linker) linkFiles() (map[string]*desc.FileDescriptor, error) {
// First, we put all symbols into a single pool, which lets us ensure there
// are no duplicate symbols and will also let us resolve and revise all type
// references in next step.
if err := l.createDescriptorPool(); err != nil {
return nil, err
}
// After we've populated the pool, we can now try to resolve all type
// references. All references must be checked for correct type, any fields
// with enum types must be corrected (since we parse them as if they are
// message references since we don't actually know message or enum until
// link time), and references will be re-written to be fully-qualified
// references (e.g. start with a dot ".").
if err := l.resolveReferences(); err != nil {
return nil, err
}
// Now we've validated the descriptors, so we can link them into rich
// descriptors. This is a little redundant since that step does similar
// checking of symbols. But, without breaking encapsulation (e.g. exporting
// a lot of fields from desc package that are currently unexported) or
// merging this into the same package, we can't really prevent it.
linked, err := l.createdLinkedDescriptors()
if err != nil {
return nil, err
}
// Now that we have linked descriptors, we can interpret any uninterpreted
// options that remain.
for _, r := range l.files {
fd := linked[r.fd.GetName()]
if err := interpretFileOptions(r, richFileDescriptorish{FileDescriptor: fd}); err != nil {
return nil, err
}
}
return linked, nil
}
func (l *linker) createDescriptorPool() error {
l.descriptorPool = map[*dpb.FileDescriptorProto]map[string]proto.Message{}
for _, r := range l.files {
fd := r.fd
pool := map[string]proto.Message{}
l.descriptorPool[fd] = pool
prefix := fd.GetPackage()
if prefix != "" {
prefix += "."
}
for _, md := range fd.MessageType {
if err := addMessageToPool(r, pool, prefix, md); err != nil {
return err
}
}
for _, fld := range fd.Extension {
if err := addFieldToPool(r, pool, prefix, fld); err != nil {
return err
}
}
for _, ed := range fd.EnumType {
if err := addEnumToPool(r, pool, prefix, ed); err != nil {
return err
}
}
for _, sd := range fd.Service {
if err := addServiceToPool(r, pool, prefix, sd); err != nil {
return err
}
}
}
// try putting everything into a single pool, to ensure there are no duplicates
// across files (e.g. same symbol, but declared in two different files)
type entry struct {
file string
msg proto.Message
}
pool := map[string]entry{}
for f, p := range l.descriptorPool {
for k, v := range p {
if e, ok := pool[k]; ok {
desc1 := e.msg
file1 := e.file
desc2 := v
file2 := f.GetName()
if file2 < file1 {
file1, file2 = file2, file1
desc1, desc2 = desc2, desc1
}
node := l.files[file2].nodes[desc2]
return ErrorWithSourcePos{Pos: node.start(), Underlying: fmt.Errorf("duplicate symbol %s: already defined as %s in %q", k, descriptorType(desc1), file1)}
}
pool[k] = entry{file: f.GetName(), msg: v}
}
}
return nil
}
func addMessageToPool(r *parseResult, pool map[string]proto.Message, prefix string, md *dpb.DescriptorProto) error {
fqn := prefix + md.GetName()
if err := addToPool(r, pool, fqn, md); err != nil {
return err
}
prefix = fqn + "."
for _, fld := range md.Field {
if err := addFieldToPool(r, pool, prefix, fld); err != nil {
return err
}
}
for _, fld := range md.Extension {
if err := addFieldToPool(r, pool, prefix, fld); err != nil {
return err
}
}
for _, nmd := range md.NestedType {
if err := addMessageToPool(r, pool, prefix, nmd); err != nil {
return err
}
}
for _, ed := range md.EnumType {
if err := addEnumToPool(r, pool, prefix, ed); err != nil {
return err
}
}
return nil
}
func addFieldToPool(r *parseResult, pool map[string]proto.Message, prefix string, fld *dpb.FieldDescriptorProto) error {
fqn := prefix + fld.GetName()
return addToPool(r, pool, fqn, fld)
}
func addEnumToPool(r *parseResult, pool map[string]proto.Message, prefix string, ed *dpb.EnumDescriptorProto) error {
fqn := prefix + ed.GetName()
if err := addToPool(r, pool, fqn, ed); err != nil {
return err
}
for _, evd := range ed.Value {
vfqn := fqn + "." + evd.GetName()
if err := addToPool(r, pool, vfqn, evd); err != nil {
return err
}
}
return nil
}
func addServiceToPool(r *parseResult, pool map[string]proto.Message, prefix string, sd *dpb.ServiceDescriptorProto) error {
fqn := prefix + sd.GetName()
if err := addToPool(r, pool, fqn, sd); err != nil {
return err
}
for _, mtd := range sd.Method {
mfqn := fqn + "." + mtd.GetName()
if err := addToPool(r, pool, mfqn, mtd); err != nil {
return err
}
}
return nil
}
func addToPool(r *parseResult, pool map[string]proto.Message, fqn string, dsc proto.Message) error {
if d, ok := pool[fqn]; ok {
node := r.nodes[dsc]
return ErrorWithSourcePos{Pos: node.start(), Underlying: fmt.Errorf("duplicate symbol %s: already defined as %s", fqn, descriptorType(d))}
}
pool[fqn] = dsc
return nil
}
func descriptorType(m proto.Message) string {
switch m := m.(type) {
case *dpb.DescriptorProto:
return "message"
case *dpb.DescriptorProto_ExtensionRange:
return "extension range"
case *dpb.FieldDescriptorProto:
if m.GetExtendee() == "" {
return "field"
} else {
return "extension"
}
case *dpb.EnumDescriptorProto:
return "enum"
case *dpb.EnumValueDescriptorProto:
return "enum value"
case *dpb.ServiceDescriptorProto:
return "service"
case *dpb.MethodDescriptorProto:
return "method"
case *dpb.FileDescriptorProto:
return "file"
default:
// shouldn't be possible
return fmt.Sprintf("%T", m)
}
}
func (l *linker) resolveReferences() error {
l.extensions = map[string]map[int32]string{}
for _, r := range l.files {
fd := r.fd
prefix := fd.GetPackage()
scopes := []scope{fileScope(fd, l)}
if prefix != "" {
prefix += "."
}
if fd.Options != nil {
if err := l.resolveOptions(r, fd, "file", fd.GetName(), proto.MessageName(fd.Options), fd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, md := range fd.MessageType {
if err := l.resolveMessageTypes(r, fd, prefix, md, scopes); err != nil {
return err
}
}
for _, fld := range fd.Extension {
if err := l.resolveFieldTypes(r, fd, prefix, fld, scopes); err != nil {
return err
}
}
for _, ed := range fd.EnumType {
if err := l.resolveEnumTypes(r, fd, prefix, ed, scopes); err != nil {
return err
}
}
for _, sd := range fd.Service {
if err := l.resolveServiceTypes(r, fd, prefix, sd, scopes); err != nil {
return err
}
}
}
return nil
}
func (l *linker) resolveEnumTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, ed *dpb.EnumDescriptorProto, scopes []scope) error {
enumFqn := prefix + ed.GetName()
if ed.Options != nil {
if err := l.resolveOptions(r, fd, "enum", enumFqn, proto.MessageName(ed.Options), ed.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, evd := range ed.Value {
if evd.Options != nil {
evFqn := enumFqn + "." + evd.GetName()
if err := l.resolveOptions(r, fd, "enum value", evFqn, proto.MessageName(evd.Options), evd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
}
return nil
}
func (l *linker) resolveMessageTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, md *dpb.DescriptorProto, scopes []scope) error {
fqn := prefix + md.GetName()
scope := messageScope(fqn, isProto3(fd), l.descriptorPool[fd])
scopes = append(scopes, scope)
prefix = fqn + "."
if md.Options != nil {
if err := l.resolveOptions(r, fd, "message", fqn, proto.MessageName(md.Options), md.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, nmd := range md.NestedType {
if err := l.resolveMessageTypes(r, fd, prefix, nmd, scopes); err != nil {
return err
}
}
for _, ned := range md.EnumType {
if err := l.resolveEnumTypes(r, fd, prefix, ned, scopes); err != nil {
return err
}
}
for _, fld := range md.Field {
if err := l.resolveFieldTypes(r, fd, prefix, fld, scopes); err != nil {
return err
}
}
for _, fld := range md.Extension {
if err := l.resolveFieldTypes(r, fd, prefix, fld, scopes); err != nil {
return err
}
}
for _, er := range md.ExtensionRange {
if er.Options != nil {
erName := fmt.Sprintf("%s:%d-%d", fqn, er.GetStart(), er.GetEnd()-1)
if err := l.resolveOptions(r, fd, "extension range", erName, proto.MessageName(er.Options), er.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
}
return nil
}
func (l *linker) resolveFieldTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto, scopes []scope) error {
thisName := prefix + fld.GetName()
scope := fmt.Sprintf("field %s", thisName)
node := r.getFieldNode(fld)
elemType := "field"
if fld.GetExtendee() != "" {
fqn, dsc, _ := l.resolve(fd, fld.GetExtendee(), isMessage, scopes)
if dsc == nil {
return ErrorWithSourcePos{Pos: node.fieldExtendee().start(), Underlying: fmt.Errorf("unknown extendee type %s", fld.GetExtendee())}
}
extd, ok := dsc.(*dpb.DescriptorProto)
if !ok {
otherType := descriptorType(dsc)
return ErrorWithSourcePos{Pos: node.fieldExtendee().start(), Underlying: fmt.Errorf("extendee is invalid: %s is a %s, not a message", fqn, otherType)}
}
fld.Extendee = proto.String("." + fqn)
// make sure the tag number is in range
found := false
tag := fld.GetNumber()
for _, rng := range extd.ExtensionRange {
if tag >= rng.GetStart() && tag < rng.GetEnd() {
found = true
break
}
}
if !found {
return ErrorWithSourcePos{Pos: node.fieldTag().start(), Underlying: fmt.Errorf("%s: tag %d is not in valid range for extended type %s", scope, tag, fqn)}
}
// make sure tag is not a duplicate
usedExtTags := l.extensions[fqn]
if usedExtTags == nil {
usedExtTags = map[int32]string{}
l.extensions[fqn] = usedExtTags
}
if other := usedExtTags[fld.GetNumber()]; other != "" {
return ErrorWithSourcePos{Pos: node.fieldTag().start(), Underlying: fmt.Errorf("%s: duplicate extension: %s and %s are both using tag %d", scope, other, thisName, fld.GetNumber())}
}
usedExtTags[fld.GetNumber()] = thisName
elemType = "extension"
}
if fld.Options != nil {
if err := l.resolveOptions(r, fd, elemType, thisName, proto.MessageName(fld.Options), fld.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
if fld.GetTypeName() == "" {
// scalar type; no further resolution required
return nil
}
fqn, dsc, proto3 := l.resolve(fd, fld.GetTypeName(), isType, scopes)
if dsc == nil {
return ErrorWithSourcePos{Pos: node.fieldType().start(), Underlying: fmt.Errorf("%s: unknown type %s", scope, fld.GetTypeName())}
}
switch dsc := dsc.(type) {
case *dpb.DescriptorProto:
fld.TypeName = proto.String("." + fqn)
case *dpb.EnumDescriptorProto:
if fld.GetExtendee() == "" && isProto3(fd) && !proto3 {
// fields in a proto3 message cannot refer to proto2 enums
return ErrorWithSourcePos{Pos: node.fieldType().start(), Underlying: fmt.Errorf("%s: cannot use proto2 enum %s in a proto3 message", scope, fld.GetTypeName())}
}
fld.TypeName = proto.String("." + fqn)
// the type was tentatively set to message, but now we know it's actually an enum
fld.Type = dpb.FieldDescriptorProto_TYPE_ENUM.Enum()
default:
otherType := descriptorType(dsc)
return ErrorWithSourcePos{Pos: node.fieldType().start(), Underlying: fmt.Errorf("%s: invalid type: %s is a %s, not a message or enum", scope, fqn, otherType)}
}
return nil
}
func (l *linker) resolveServiceTypes(r *parseResult, fd *dpb.FileDescriptorProto, prefix string, sd *dpb.ServiceDescriptorProto, scopes []scope) error {
thisName := prefix + sd.GetName()
if sd.Options != nil {
if err := l.resolveOptions(r, fd, "service", thisName, proto.MessageName(sd.Options), sd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
for _, mtd := range sd.Method {
if mtd.Options != nil {
if err := l.resolveOptions(r, fd, "method", thisName+"."+mtd.GetName(), proto.MessageName(mtd.Options), mtd.Options.UninterpretedOption, scopes); err != nil {
return err
}
}
scope := fmt.Sprintf("method %s.%s", thisName, mtd.GetName())
node := r.getMethodNode(mtd)
fqn, dsc, _ := l.resolve(fd, mtd.GetInputType(), isMessage, scopes)
if dsc == nil {
return ErrorWithSourcePos{Pos: node.getInputType().start(), Underlying: fmt.Errorf("%s: unknown request type %s", scope, mtd.GetInputType())}
}
if _, ok := dsc.(*dpb.DescriptorProto); !ok {
otherType := descriptorType(dsc)
return ErrorWithSourcePos{Pos: node.getInputType().start(), Underlying: fmt.Errorf("%s: invalid request type: %s is a %s, not a message", scope, fqn, otherType)}
}
mtd.InputType = proto.String("." + fqn)
fqn, dsc, _ = l.resolve(fd, mtd.GetOutputType(), isMessage, scopes)
if dsc == nil {
return ErrorWithSourcePos{Pos: node.getOutputType().start(), Underlying: fmt.Errorf("%s: unknown response type %s", scope, mtd.GetOutputType())}
}
if _, ok := dsc.(*dpb.DescriptorProto); !ok {
otherType := descriptorType(dsc)
return ErrorWithSourcePos{Pos: node.getOutputType().start(), Underlying: fmt.Errorf("%s: invalid response type: %s is a %s, not a message", scope, fqn, otherType)}
}
mtd.OutputType = proto.String("." + fqn)
}
return nil
}
func (l *linker) resolveOptions(r *parseResult, fd *dpb.FileDescriptorProto, elemType, elemName, optType string, opts []*dpb.UninterpretedOption, scopes []scope) error {
var scope string
if elemType != "file" {
scope = fmt.Sprintf("%s %s: ", elemType, elemName)
}
for _, opt := range opts {
for _, nm := range opt.Name {
if nm.GetIsExtension() {
node := r.getOptionNamePartNode(nm)
fqn, dsc, _ := l.resolve(fd, nm.GetNamePart(), isField, scopes)
if dsc == nil {
return ErrorWithSourcePos{Pos: node.start(), Underlying: fmt.Errorf("%sunknown extension %s", scope, nm.GetNamePart())}
}
if ext, ok := dsc.(*dpb.FieldDescriptorProto); !ok {
otherType := descriptorType(dsc)
return ErrorWithSourcePos{Pos: node.start(), Underlying: fmt.Errorf("%sinvalid extension: %s is a %s, not an extension", scope, nm.GetNamePart(), otherType)}
} else if ext.GetExtendee() == "" {
return ErrorWithSourcePos{Pos: node.start(), Underlying: fmt.Errorf("%sinvalid extension: %s is a field but not an extension", scope, nm.GetNamePart())}
}
nm.NamePart = proto.String("." + fqn)
}
}
}
return nil
}
func (l *linker) resolve(fd *dpb.FileDescriptorProto, name string, allowed func(proto.Message) bool, scopes []scope) (fqn string, element proto.Message, proto3 bool) {
if strings.HasPrefix(name, ".") {
// already fully-qualified
d, proto3 := l.findSymbol(fd, name[1:], false, map[*dpb.FileDescriptorProto]struct{}{})
if d != nil {
return name[1:], d, proto3
}
} else {
// unqualified, so we look in the enclosing (last) scope first and move
// towards outermost (first) scope, trying to resolve the symbol
var bestGuess proto.Message
var bestGuessFqn string
var bestGuessProto3 bool
for i := len(scopes) - 1; i >= 0; i-- {
fqn, d, proto3 := scopes[i](name)
if d != nil {
if allowed(d) {
return fqn, d, proto3
} else if bestGuess == nil {
bestGuess = d
bestGuessFqn = fqn
bestGuessProto3 = proto3
}
}
}
// we return best guess, even though it was not an allowed kind of
// descriptor, so caller can print a better error message (e.g.
// indicating that the name was found but that it's the wrong type)
return bestGuessFqn, bestGuess, bestGuessProto3
}
return "", nil, false
}
func isField(m proto.Message) bool {
_, ok := m.(*dpb.FieldDescriptorProto)
return ok
}
func isMessage(m proto.Message) bool {
_, ok := m.(*dpb.DescriptorProto)
return ok
}
func isType(m proto.Message) bool {
switch m.(type) {
case *dpb.DescriptorProto, *dpb.EnumDescriptorProto:
return true
}
return false
}
// scope represents a lexical scope in a proto file in which messages and enums
// can be declared.
type scope func(symbol string) (fqn string, element proto.Message, proto3 bool)
func fileScope(fd *dpb.FileDescriptorProto, l *linker) scope {
// we search symbols in this file, but also symbols in other files that have
// the same package as this file or a "parent" package (in protobuf,
// packages are a hierarchy like C++ namespaces)
prefixes := internal.CreatePrefixList(fd.GetPackage())
return func(name string) (string, proto.Message, bool) {
for _, prefix := range prefixes {
var n string
if prefix == "" {
n = name
} else {
n = prefix + "." + name
}
d, proto3 := l.findSymbol(fd, n, false, map[*dpb.FileDescriptorProto]struct{}{})
if d != nil {
return n, d, proto3
}
}
return "", nil, false
}
}
func messageScope(messageName string, proto3 bool, filePool map[string]proto.Message) scope {
return func(name string) (string, proto.Message, bool) {
n := messageName + "." + name
if d, ok := filePool[n]; ok {
return n, d, proto3
}
return "", nil, false
}
}
func (l *linker) findSymbol(fd *dpb.FileDescriptorProto, name string, public bool, checked map[*dpb.FileDescriptorProto]struct{}) (element proto.Message, proto3 bool) {
if _, ok := checked[fd]; ok {
// already checked this one
return nil, false
}
checked[fd] = struct{}{}
d := l.descriptorPool[fd][name]
if d != nil {
return d, isProto3(fd)
}
// When public = false, we are searching only directly imported symbols. But we
// also need to search transitive public imports due to semantics of public imports.
if public {
for _, depIndex := range fd.PublicDependency {
dep := fd.Dependency[depIndex]
depres := l.files[dep]
if depres == nil {
// we'll catch this error later
continue
}
if d, proto3 := l.findSymbol(depres.fd, name, true, checked); d != nil {
return d, proto3
}
}
} else {
for _, dep := range fd.Dependency {
depres := l.files[dep]
if depres == nil {
// we'll catch this error later
continue
}
if d, proto3 := l.findSymbol(depres.fd, name, true, checked); d != nil {
return d, proto3
}
}
}
return nil, false
}
func isProto3(fd *dpb.FileDescriptorProto) bool {
return fd.GetSyntax() == "proto3"
}
func (l *linker) createdLinkedDescriptors() (map[string]*desc.FileDescriptor, error) {
names := make([]string, 0, len(l.files))
for name := range l.files {
names = append(names, name)
}
sort.Strings(names)
linked := map[string]*desc.FileDescriptor{}
for _, name := range names {
if _, err := l.linkFile(name, nil, linked); err != nil {
return nil, err
}
}
return linked, nil
}
func (l *linker) linkFile(name string, seen []string, linked map[string]*desc.FileDescriptor) (*desc.FileDescriptor, error) {
// check for import cycle
for _, s := range seen {
if name == s {
var msg bytes.Buffer
first := true
for _, s := range seen {
if first {
first = false
} else {
msg.WriteString(" -> ")
}
fmt.Fprintf(&msg, "%q", s)
}
fmt.Fprintf(&msg, " -> %q", name)
return nil, fmt.Errorf("cycle found in imports: %s", msg.String())
}
}
seen = append(seen, name)
if lfd, ok := linked[name]; ok {
// already linked
return lfd, nil
}
r := l.files[name]
if r == nil {
importer := seen[len(seen)-2] // len-1 is *this* file, before that is the one that imported it
return nil, fmt.Errorf("no descriptor found for %q, imported by %q", name, importer)
}
var deps []*desc.FileDescriptor
for _, dep := range r.fd.Dependency {
ldep, err := l.linkFile(dep, seen, linked)
if err != nil {
return nil, err
}
deps = append(deps, ldep)
}
lfd, err := desc.CreateFileDescriptor(r.fd, deps...)
if err != nil {
return nil, fmt.Errorf("error linking %q: %s", name, err)
}
linked[name] = lfd
return lfd, nil
}