| package descriptor |
| |
| import ( |
| "fmt" |
| "strings" |
| |
| "github.com/golang/glog" |
| "github.com/golang/protobuf/proto" |
| descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" |
| "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule" |
| options "google.golang.org/genproto/googleapis/api/annotations" |
| ) |
| |
| // loadServices registers services and their methods from "targetFile" to "r". |
| // It must be called after loadFile is called for all files so that loadServices |
| // can resolve names of message types and their fields. |
| func (r *Registry) loadServices(file *File) error { |
| glog.V(1).Infof("Loading services from %s", file.GetName()) |
| var svcs []*Service |
| for _, sd := range file.GetService() { |
| glog.V(2).Infof("Registering %s", sd.GetName()) |
| svc := &Service{ |
| File: file, |
| ServiceDescriptorProto: sd, |
| } |
| for _, md := range sd.GetMethod() { |
| glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName()) |
| opts, err := extractAPIOptions(md) |
| if err != nil { |
| glog.Errorf("Failed to extract HttpRule from %s.%s: %v", svc.GetName(), md.GetName(), err) |
| return err |
| } |
| optsList := r.LookupExternalHTTPRules((&Method{Service: svc, MethodDescriptorProto: md}).FQMN()) |
| if opts != nil { |
| optsList = append(optsList, opts) |
| } |
| if len(optsList) == 0 { |
| glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName()) |
| } |
| meth, err := r.newMethod(svc, md, optsList) |
| if err != nil { |
| return err |
| } |
| svc.Methods = append(svc.Methods, meth) |
| } |
| if len(svc.Methods) == 0 { |
| continue |
| } |
| glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods)) |
| svcs = append(svcs, svc) |
| } |
| file.Services = svcs |
| return nil |
| } |
| |
| func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, optsList []*options.HttpRule) (*Method, error) { |
| requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType()) |
| if err != nil { |
| return nil, err |
| } |
| responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType()) |
| if err != nil { |
| return nil, err |
| } |
| meth := &Method{ |
| Service: svc, |
| MethodDescriptorProto: md, |
| RequestType: requestType, |
| ResponseType: responseType, |
| } |
| |
| newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) { |
| var ( |
| httpMethod string |
| pathTemplate string |
| ) |
| switch { |
| case opts.GetGet() != "": |
| httpMethod = "GET" |
| pathTemplate = opts.GetGet() |
| if opts.Body != "" { |
| return nil, fmt.Errorf("must not set request body when http method is GET: %s", md.GetName()) |
| } |
| |
| case opts.GetPut() != "": |
| httpMethod = "PUT" |
| pathTemplate = opts.GetPut() |
| |
| case opts.GetPost() != "": |
| httpMethod = "POST" |
| pathTemplate = opts.GetPost() |
| |
| case opts.GetDelete() != "": |
| httpMethod = "DELETE" |
| pathTemplate = opts.GetDelete() |
| if opts.Body != "" && !r.allowDeleteBody { |
| return nil, fmt.Errorf("must not set request body when http method is DELETE except allow_delete_body option is true: %s", md.GetName()) |
| } |
| |
| case opts.GetPatch() != "": |
| httpMethod = "PATCH" |
| pathTemplate = opts.GetPatch() |
| |
| case opts.GetCustom() != nil: |
| custom := opts.GetCustom() |
| httpMethod = custom.Kind |
| pathTemplate = custom.Path |
| |
| default: |
| glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName()) |
| return nil, nil |
| } |
| |
| parsed, err := httprule.Parse(pathTemplate) |
| if err != nil { |
| return nil, err |
| } |
| tmpl := parsed.Compile() |
| |
| if md.GetClientStreaming() && len(tmpl.Fields) > 0 { |
| return nil, fmt.Errorf("cannot use path parameter in client streaming") |
| } |
| |
| b := &Binding{ |
| Method: meth, |
| Index: idx, |
| PathTmpl: tmpl, |
| HTTPMethod: httpMethod, |
| } |
| |
| for _, f := range tmpl.Fields { |
| param, err := r.newParam(meth, f) |
| if err != nil { |
| return nil, err |
| } |
| b.PathParams = append(b.PathParams, param) |
| } |
| |
| // TODO(yugui) Handle query params |
| |
| b.Body, err = r.newBody(meth, opts.Body) |
| if err != nil { |
| return nil, err |
| } |
| |
| b.ResponseBody, err = r.newResponse(meth, opts.ResponseBody) |
| if err != nil { |
| return nil, err |
| } |
| |
| return b, nil |
| } |
| |
| applyOpts := func(opts *options.HttpRule) error { |
| b, err := newBinding(opts, len(meth.Bindings)) |
| if err != nil { |
| return err |
| } |
| |
| if b != nil { |
| meth.Bindings = append(meth.Bindings, b) |
| } |
| for _, additional := range opts.GetAdditionalBindings() { |
| if len(additional.AdditionalBindings) > 0 { |
| return fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName()) |
| } |
| b, err := newBinding(additional, len(meth.Bindings)) |
| if err != nil { |
| return err |
| } |
| meth.Bindings = append(meth.Bindings, b) |
| } |
| |
| return nil |
| } |
| |
| for _, opts := range optsList { |
| if err := applyOpts(opts); err != nil { |
| return nil, err |
| } |
| } |
| |
| return meth, nil |
| } |
| |
| func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) { |
| if meth.Options == nil { |
| return nil, nil |
| } |
| if !proto.HasExtension(meth.Options, options.E_Http) { |
| return nil, nil |
| } |
| ext, err := proto.GetExtension(meth.Options, options.E_Http) |
| if err != nil { |
| return nil, err |
| } |
| opts, ok := ext.(*options.HttpRule) |
| if !ok { |
| return nil, fmt.Errorf("extension is %T; want an HttpRule", ext) |
| } |
| return opts, nil |
| } |
| |
| func (r *Registry) newParam(meth *Method, path string) (Parameter, error) { |
| msg := meth.RequestType |
| fields, err := r.resolveFieldPath(msg, path, true) |
| if err != nil { |
| return Parameter{}, err |
| } |
| l := len(fields) |
| if l == 0 { |
| return Parameter{}, fmt.Errorf("invalid field access list for %s", path) |
| } |
| target := fields[l-1].Target |
| switch target.GetType() { |
| case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: |
| glog.V(2).Infoln("found aggregate type:", target, target.TypeName) |
| if IsWellKnownType(*target.TypeName) { |
| glog.V(2).Infoln("found well known aggregate type:", target) |
| } else { |
| return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.Type, meth.Service.GetName(), meth.GetName(), path) |
| } |
| } |
| return Parameter{ |
| FieldPath: FieldPath(fields), |
| Method: meth, |
| Target: fields[l-1].Target, |
| }, nil |
| } |
| |
| func (r *Registry) newBody(meth *Method, path string) (*Body, error) { |
| msg := meth.RequestType |
| switch path { |
| case "": |
| return nil, nil |
| case "*": |
| return &Body{FieldPath: nil}, nil |
| } |
| fields, err := r.resolveFieldPath(msg, path, false) |
| if err != nil { |
| return nil, err |
| } |
| return &Body{FieldPath: FieldPath(fields)}, nil |
| } |
| |
| func (r *Registry) newResponse(meth *Method, path string) (*Body, error) { |
| msg := meth.ResponseType |
| switch path { |
| case "", "*": |
| return nil, nil |
| } |
| fields, err := r.resolveFieldPath(msg, path, false) |
| if err != nil { |
| return nil, err |
| } |
| return &Body{FieldPath: FieldPath(fields)}, nil |
| } |
| |
| // lookupField looks up a field named "name" within "msg". |
| // It returns nil if no such field found. |
| func lookupField(msg *Message, name string) *Field { |
| for _, f := range msg.Fields { |
| if f.GetName() == name { |
| return f |
| } |
| } |
| return nil |
| } |
| |
| // resolveFieldPath resolves "path" into a list of fieldDescriptor, starting from "msg". |
| func (r *Registry) resolveFieldPath(msg *Message, path string, isPathParam bool) ([]FieldPathComponent, error) { |
| if path == "" { |
| return nil, nil |
| } |
| |
| root := msg |
| var result []FieldPathComponent |
| for i, c := range strings.Split(path, ".") { |
| if i > 0 { |
| f := result[i-1].Target |
| switch f.GetType() { |
| case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP: |
| var err error |
| msg, err = r.LookupMsg(msg.FQMN(), f.GetTypeName()) |
| if err != nil { |
| return nil, err |
| } |
| default: |
| return nil, fmt.Errorf("not an aggregate type: %s in %s", f.GetName(), path) |
| } |
| } |
| |
| glog.V(2).Infof("Lookup %s in %s", c, msg.FQMN()) |
| f := lookupField(msg, c) |
| if f == nil { |
| return nil, fmt.Errorf("no field %q found in %s", path, root.GetName()) |
| } |
| if !(isPathParam || r.allowRepeatedFieldsInBody) && f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { |
| return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path) |
| } |
| result = append(result, FieldPathComponent{Name: c, Target: f}) |
| } |
| return result, nil |
| } |