blob: 8916d316d18c6bba4696f219dabd6fee6bc6b888 [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package descriptor
2
3import (
4 "fmt"
5 "strings"
6
7 "github.com/golang/glog"
8 "github.com/golang/protobuf/proto"
9 descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
10 "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
11 options "google.golang.org/genproto/googleapis/api/annotations"
12)
13
14// loadServices registers services and their methods from "targetFile" to "r".
15// It must be called after loadFile is called for all files so that loadServices
16// can resolve names of message types and their fields.
17func (r *Registry) loadServices(file *File) error {
18 glog.V(1).Infof("Loading services from %s", file.GetName())
19 var svcs []*Service
20 for _, sd := range file.GetService() {
21 glog.V(2).Infof("Registering %s", sd.GetName())
22 svc := &Service{
23 File: file,
24 ServiceDescriptorProto: sd,
25 }
26 for _, md := range sd.GetMethod() {
27 glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName())
28 opts, err := extractAPIOptions(md)
29 if err != nil {
30 glog.Errorf("Failed to extract HttpRule from %s.%s: %v", svc.GetName(), md.GetName(), err)
31 return err
32 }
33 optsList := r.LookupExternalHTTPRules((&Method{Service: svc, MethodDescriptorProto: md}).FQMN())
34 if opts != nil {
35 optsList = append(optsList, opts)
36 }
37 if len(optsList) == 0 {
38 glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName())
39 }
40 meth, err := r.newMethod(svc, md, optsList)
41 if err != nil {
42 return err
43 }
44 svc.Methods = append(svc.Methods, meth)
45 }
46 if len(svc.Methods) == 0 {
47 continue
48 }
49 glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods))
50 svcs = append(svcs, svc)
51 }
52 file.Services = svcs
53 return nil
54}
55
// newMethod builds the Method for md within svc. It derives one Binding for
// each HTTP rule in optsList that selects an HTTP pattern, plus one for each
// additional_binding nested in those rules that selects a pattern.
//
// Request and response message types are resolved through the registry
// relative to svc's package. A rule (primary or additional) that specifies no
// HTTP pattern produces no binding. Binding indices are assigned sequentially
// in the order bindings are added to meth.Bindings.
//
// An error is returned when:
//   - a message type cannot be resolved;
//   - a rule sets a body on GET, or on DELETE without allowDeleteBody;
//   - a path template fails to parse;
//   - a client-streaming method uses path parameters;
//   - an additional_binding nests further additional_bindings;
//   - a path, body, or response_body field path cannot be resolved.
func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, optsList []*options.HttpRule) (*Method, error) {
	requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
	if err != nil {
		return nil, err
	}
	responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType())
	if err != nil {
		return nil, err
	}
	meth := &Method{
		Service:               svc,
		MethodDescriptorProto: md,
		RequestType:           requestType,
		ResponseType:          responseType,
	}

	// newBinding converts a single HttpRule into a Binding carrying index idx.
	// It returns (nil, nil) when the rule selects no HTTP pattern.
	newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
		var (
			httpMethod   string
			pathTemplate string
		)
		switch {
		case opts.GetGet() != "":
			httpMethod = "GET"
			pathTemplate = opts.GetGet()
			if opts.Body != "" {
				return nil, fmt.Errorf("must not set request body when http method is GET: %s", md.GetName())
			}

		case opts.GetPut() != "":
			httpMethod = "PUT"
			pathTemplate = opts.GetPut()

		case opts.GetPost() != "":
			httpMethod = "POST"
			pathTemplate = opts.GetPost()

		case opts.GetDelete() != "":
			httpMethod = "DELETE"
			pathTemplate = opts.GetDelete()
			if opts.Body != "" && !r.allowDeleteBody {
				return nil, fmt.Errorf("must not set request body when http method is DELETE except allow_delete_body option is true: %s", md.GetName())
			}

		case opts.GetPatch() != "":
			httpMethod = "PATCH"
			pathTemplate = opts.GetPatch()

		case opts.GetCustom() != nil:
			custom := opts.GetCustom()
			httpMethod = custom.Kind
			pathTemplate = custom.Path

		default:
			glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName())
			return nil, nil
		}

		parsed, err := httprule.Parse(pathTemplate)
		if err != nil {
			return nil, err
		}
		tmpl := parsed.Compile()

		if md.GetClientStreaming() && len(tmpl.Fields) > 0 {
			return nil, fmt.Errorf("cannot use path parameter in client streaming: %s", md.GetName())
		}

		b := &Binding{
			Method:     meth,
			Index:      idx,
			PathTmpl:   tmpl,
			HTTPMethod: httpMethod,
		}

		for _, f := range tmpl.Fields {
			param, err := r.newParam(meth, f)
			if err != nil {
				return nil, err
			}
			b.PathParams = append(b.PathParams, param)
		}

		// TODO(yugui) Handle query params

		b.Body, err = r.newBody(meth, opts.Body)
		if err != nil {
			return nil, err
		}

		b.ResponseBody, err = r.newResponse(meth, opts.ResponseBody)
		if err != nil {
			return nil, err
		}

		return b, nil
	}

	// applyOpts adds the bindings for one top-level rule and its
	// additional_bindings to meth.Bindings.
	applyOpts := func(opts *options.HttpRule) error {
		b, err := newBinding(opts, len(meth.Bindings))
		if err != nil {
			return err
		}
		if b != nil {
			meth.Bindings = append(meth.Bindings, b)
		}

		for _, additional := range opts.GetAdditionalBindings() {
			if len(additional.AdditionalBindings) > 0 {
				return fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
			}
			b, err := newBinding(additional, len(meth.Bindings))
			if err != nil {
				return err
			}
			// A pattern-less additional binding yields (nil, nil) just like a
			// primary rule; skip it so meth.Bindings never holds a nil entry.
			if b != nil {
				meth.Bindings = append(meth.Bindings, b)
			}
		}

		return nil
	}

	for _, opts := range optsList {
		if err := applyOpts(opts); err != nil {
			return nil, err
		}
	}

	return meth, nil
}
185
// extractAPIOptions returns the google.api.http annotation (options.E_Http)
// attached to meth. It returns (nil, nil) when meth has no options or the
// extension is absent. An error is returned if the extension cannot be read
// or does not hold an *options.HttpRule.
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
	if meth.Options == nil || !proto.HasExtension(meth.Options, options.E_Http) {
		return nil, nil
	}
	ext, err := proto.GetExtension(meth.Options, options.E_Http)
	if err != nil {
		return nil, err
	}
	if rule, ok := ext.(*options.HttpRule); ok {
		return rule, nil
	}
	return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
}
203
// newParam resolves path, a dot-separated field path taken from a URL path
// template, against meth's request message and returns it as a Parameter.
//
// Repeated fields are permitted along the path. The final field must be
// either a non-aggregate field, or a message/group whose type name is
// accepted by IsWellKnownType. Any other aggregate is rejected with an
// error, as is a path that resolves to no fields.
func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
	fields, err := r.resolveFieldPath(meth.RequestType, path, true)
	if err != nil {
		return Parameter{}, err
	}
	if len(fields) == 0 {
		return Parameter{}, fmt.Errorf("invalid field access list for %s", path)
	}
	target := fields[len(fields)-1].Target
	switch target.GetType() {
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_GROUP:
		// Use the nil-safe getter: dereferencing TypeName directly panics when
		// it is unset, and logging the raw *string prints an address rather
		// than the type name.
		typeName := target.GetTypeName()
		glog.V(2).Infoln("found aggregate type:", target, typeName)
		if !IsWellKnownType(typeName) {
			return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", target.GetType(), meth.Service.GetName(), meth.GetName(), path)
		}
		glog.V(2).Infoln("found well known aggregate type:", target)
	}
	return Parameter{
		FieldPath: FieldPath(fields),
		Method:    meth,
		Target:    target,
	}, nil
}
230
231func (r *Registry) newBody(meth *Method, path string) (*Body, error) {
232 msg := meth.RequestType
233 switch path {
234 case "":
235 return nil, nil
236 case "*":
237 return &Body{FieldPath: nil}, nil
238 }
239 fields, err := r.resolveFieldPath(msg, path, false)
240 if err != nil {
241 return nil, err
242 }
243 return &Body{FieldPath: FieldPath(fields)}, nil
244}
245
246func (r *Registry) newResponse(meth *Method, path string) (*Body, error) {
247 msg := meth.ResponseType
248 switch path {
249 case "", "*":
250 return nil, nil
251 }
252 fields, err := r.resolveFieldPath(msg, path, false)
253 if err != nil {
254 return nil, err
255 }
256 return &Body{FieldPath: FieldPath(fields)}, nil
257}
258
259// lookupField looks up a field named "name" within "msg".
260// It returns nil if no such field found.
261func lookupField(msg *Message, name string) *Field {
262 for _, f := range msg.Fields {
263 if f.GetName() == name {
264 return f
265 }
266 }
267 return nil
268}
269
// resolveFieldPath walks the dot-separated field path "path", starting at
// msg, and returns one FieldPathComponent per segment. An empty path yields
// (nil, nil).
//
// Every segment except the last must name a message or group field. The
// following segment is then looked up in that field's message type, which is
// resolved through the registry. Repeated fields anywhere along the path are
// rejected unless isPathParam is true or the registry's
// allowRepeatedFieldsInBody is set.
func (r *Registry) resolveFieldPath(msg *Message, path string, isPathParam bool) ([]FieldPathComponent, error) {
	if path == "" {
		return nil, nil
	}

	root := msg
	var components []FieldPathComponent
	for depth, segment := range strings.Split(path, ".") {
		if depth > 0 {
			// Step into the message type of the previously resolved field.
			parent := components[depth-1].Target
			kind := parent.GetType()
			if kind != descriptor.FieldDescriptorProto_TYPE_MESSAGE && kind != descriptor.FieldDescriptorProto_TYPE_GROUP {
				return nil, fmt.Errorf("not an aggregate type: %s in %s", parent.GetName(), path)
			}
			next, err := r.LookupMsg(msg.FQMN(), parent.GetTypeName())
			if err != nil {
				return nil, err
			}
			msg = next
		}

		glog.V(2).Infof("Lookup %s in %s", segment, msg.FQMN())
		field := lookupField(msg, segment)
		if field == nil {
			return nil, fmt.Errorf("no field %q found in %s", path, root.GetName())
		}
		if !isPathParam && !r.allowRepeatedFieldsInBody && field.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED {
			return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", field.GetName(), path)
		}
		components = append(components, FieldPathComponent{Name: segment, Target: field})
	}
	return components, nil
}