blob: 523a9cb43c93065843cd590d0c574a81f0277f2e [file] [log] [blame]
khenaidooab1f7bd2019-11-14 14:00:27 -05001package runtime
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/textproto"
8 "strings"
9
10 "github.com/golang/protobuf/proto"
11 "google.golang.org/grpc/codes"
12 "google.golang.org/grpc/metadata"
13 "google.golang.org/grpc/status"
14)
15
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
//
// pathParams carries the values that the matching pattern captured from the
// request path (presumably keyed by the pattern's variable names).
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
18
19// ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when
20// a request is received with a URI path that does not match any registered
21// service method.
22//
23// Since gRPC servers return an "Unimplemented" code for requests with an
24// unrecognized URI path, this error also has a gRPC "Unimplemented" code.
25var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
26
27// ServeMux is a request multiplexer for grpc-gateway.
28// It matches http requests to patterns and invokes the corresponding handler.
29type ServeMux struct {
30 // handlers maps HTTP method to a list of handlers.
31 handlers map[string][]handler
32 forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
33 marshalers marshalerRegistry
34 incomingHeaderMatcher HeaderMatcherFunc
35 outgoingHeaderMatcher HeaderMatcherFunc
36 metadataAnnotators []func(context.Context, *http.Request) metadata.MD
37 streamErrorHandler StreamErrorHandlerFunc
38 protoErrorHandler ProtoErrorHandlerFunc
39 disablePathLengthFallback bool
40 lastMatchWins bool
41}
42
43// ServeMuxOption is an option that can be given to a ServeMux on construction.
44type ServeMuxOption func(*ServeMux)
45
46// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
47//
48// forwardResponseOption is an option that will be called on the relevant context.Context,
49// http.ResponseWriter, and proto.Message before every forwarded response.
50//
51// The message may be nil in the case where just a header is being sent.
52func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
53 return func(serveMux *ServeMux) {
54 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
55 }
56}
57
khenaidood948f772021-08-11 17:49:24 -040058// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
59// Configuring this will mean the generated swagger output is no longer correct, and it should be
60// done with careful consideration.
61func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
62 return func(serveMux *ServeMux) {
63 currentQueryParser = queryParameterParser
64 }
65}
66
// HeaderMatcherFunc decides whether a header key is forwarded to or from the
// gRPC context.
//
// It receives the header key. It returns the key to use on the other side
// (possibly rewritten) and true, or false when the header must not be
// forwarded.
type HeaderMatcherFunc func(string) (string, bool)
69
70// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
71// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
72// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
73func DefaultHeaderMatcher(key string) (string, bool) {
74 key = textproto.CanonicalMIMEHeaderKey(key)
75 if isPermanentHTTPHeader(key) {
76 return MetadataPrefix + key, true
77 } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
78 return key[len(MetadataHeaderPrefix):], true
79 }
80 return "", false
81}
82
83// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
84//
85// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
86// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
87func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
88 return func(mux *ServeMux) {
89 mux.incomingHeaderMatcher = fn
90 }
91}
92
93// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
94//
95// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
96// passed to http response returned from gateway. To transform the header before passing to response,
97// matcher should return modified header.
98func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
99 return func(mux *ServeMux) {
100 mux.outgoingHeaderMatcher = fn
101 }
102}
103
104// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
105//
106// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
107// is reading token from cookie and adding it in gRPC context.
108func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
109 return func(serveMux *ServeMux) {
110 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
111 }
112}
113
khenaidood948f772021-08-11 17:49:24 -0400114// WithProtoErrorHandler returns a ServeMuxOption for configuring a custom error handler.
khenaidooab1f7bd2019-11-14 14:00:27 -0500115//
116// This can be used to handle an error as general proto message defined by gRPC.
khenaidood948f772021-08-11 17:49:24 -0400117// When this option is used, the mux uses the configured error handler instead of HTTPError and
118// OtherErrorHandler.
khenaidooab1f7bd2019-11-14 14:00:27 -0500119func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
120 return func(serveMux *ServeMux) {
121 serveMux.protoErrorHandler = fn
122 }
123}
124
125// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
126func WithDisablePathLengthFallback() ServeMuxOption {
127 return func(serveMux *ServeMux) {
128 serveMux.disablePathLengthFallback = true
129 }
130}
131
132// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
133// error handler, which allows for customizing the error trailer for server-streaming
134// calls.
135//
136// For stream errors that occur before any response has been written, the mux's
137// ProtoErrorHandler will be invoked. However, once data has been written, the errors must
138// be handled differently: they must be included in the response body. The response body's
139// final message will include the error details returned by the stream error handler.
140func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
141 return func(serveMux *ServeMux) {
142 serveMux.streamErrorHandler = fn
143 }
144}
145
146// WithLastMatchWins returns a ServeMuxOption that will enable "last
147// match wins" behavior, where if multiple path patterns match a
148// request path, the last one defined in the .proto file will be used.
149func WithLastMatchWins() ServeMuxOption {
150 return func(serveMux *ServeMux) {
151 serveMux.lastMatchWins = true
152 }
153}
154
155// NewServeMux returns a new ServeMux whose internal mapping is empty.
156func NewServeMux(opts ...ServeMuxOption) *ServeMux {
157 serveMux := &ServeMux{
158 handlers: make(map[string][]handler),
159 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
160 marshalers: makeMarshalerMIMERegistry(),
161 streamErrorHandler: DefaultHTTPStreamErrorHandler,
162 }
163
164 for _, opt := range opts {
165 opt(serveMux)
166 }
167
khenaidooab1f7bd2019-11-14 14:00:27 -0500168 if serveMux.incomingHeaderMatcher == nil {
169 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
170 }
171
172 if serveMux.outgoingHeaderMatcher == nil {
173 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
174 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
175 }
176 }
177
178 return serveMux
179}
180
181// Handle associates "h" to the pair of HTTP method and path pattern.
182func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
183 if s.lastMatchWins {
184 s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
185 } else {
186 s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
187 }
188}
189
190// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
191func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
192 ctx := r.Context()
193
194 path := r.URL.Path
195 if !strings.HasPrefix(path, "/") {
196 if s.protoErrorHandler != nil {
197 _, outboundMarshaler := MarshalerForRequest(s, r)
198 sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
199 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
200 } else {
201 OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
202 }
203 return
204 }
205
206 components := strings.Split(path[1:], "/")
207 l := len(components)
208 var verb string
209 if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
210 if s.protoErrorHandler != nil {
211 _, outboundMarshaler := MarshalerForRequest(s, r)
212 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
213 } else {
214 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
215 }
216 return
217 } else if idx > 0 {
218 c := components[l-1]
219 components[l-1], verb = c[:idx], c[idx+1:]
220 }
221
222 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
223 r.Method = strings.ToUpper(override)
224 if err := r.ParseForm(); err != nil {
225 if s.protoErrorHandler != nil {
226 _, outboundMarshaler := MarshalerForRequest(s, r)
227 sterr := status.Error(codes.InvalidArgument, err.Error())
228 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
229 } else {
230 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
231 }
232 return
233 }
234 }
235 for _, h := range s.handlers[r.Method] {
236 pathParams, err := h.pat.Match(components, verb)
237 if err != nil {
238 continue
239 }
240 h.h(w, r, pathParams)
241 return
242 }
243
244 // lookup other methods to handle fallback from GET to POST and
245 // to determine if it is MethodNotAllowed or NotFound.
246 for m, handlers := range s.handlers {
247 if m == r.Method {
248 continue
249 }
250 for _, h := range handlers {
251 pathParams, err := h.pat.Match(components, verb)
252 if err != nil {
253 continue
254 }
255 // X-HTTP-Method-Override is optional. Always allow fallback to POST.
256 if s.isPathLengthFallback(r) {
257 if err := r.ParseForm(); err != nil {
258 if s.protoErrorHandler != nil {
259 _, outboundMarshaler := MarshalerForRequest(s, r)
260 sterr := status.Error(codes.InvalidArgument, err.Error())
261 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
262 } else {
263 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
264 }
265 return
266 }
267 h.h(w, r, pathParams)
268 return
269 }
270 if s.protoErrorHandler != nil {
271 _, outboundMarshaler := MarshalerForRequest(s, r)
272 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
273 } else {
274 OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
275 }
276 return
277 }
278 }
279
280 if s.protoErrorHandler != nil {
281 _, outboundMarshaler := MarshalerForRequest(s, r)
282 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
283 } else {
284 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
285 }
286}
287
288// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
289func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
290 return s.forwardResponseOptions
291}
292
293func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
294 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
295}
296
297type handler struct {
298 pat Pattern
299 h HandlerFunc
300}