blob: 1da3a58854db35c3e0ace2085e3cf4dd2bbeec5d [file] [log] [blame]
khenaidoo59ce9dd2019-11-11 13:05:32 -05001package runtime
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/textproto"
8 "strings"
9
10 "github.com/golang/protobuf/proto"
11 "google.golang.org/grpc/codes"
12 "google.golang.org/grpc/metadata"
13 "google.golang.org/grpc/status"
14)
15
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
//
// It is the type of handler registered with (*ServeMux).Handle. When ServeHTTP
// dispatches a request to it, pathParams holds the values returned by the
// matching Pattern's Match for the request path.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
18
19// ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when
20// a request is received with a URI path that does not match any registered
21// service method.
22//
23// Since gRPC servers return an "Unimplemented" code for requests with an
24// unrecognized URI path, this error also has a gRPC "Unimplemented" code.
25var ErrUnknownURI = status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
26
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
//
// Construct it with NewServeMux: the handlers map is only allocated there, so
// calling Handle on a zero-value ServeMux would panic on the nil map.
type ServeMux struct {
	// handlers maps HTTP method to a list of handlers. ServeHTTP tries each
	// list front to back; Handle appends to it, or prepends when lastMatchWins
	// is set.
	handlers map[string][]handler

	// forwardResponseOptions holds the hooks added by WithForwardResponseOption,
	// in registration order; exposed through GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error

	// marshalers is the registry built by makeMarshalerMIMERegistry.
	// NOTE(review): presumably consulted by MarshalerForRequest — confirm.
	marshalers marshalerRegistry

	// incomingHeaderMatcher decides which HTTP request headers are passed to
	// the gRPC context (and under which key); NewServeMux defaults it to
	// DefaultHeaderMatcher.
	incomingHeaderMatcher HeaderMatcherFunc

	// outgoingHeaderMatcher decides which response header metadata keys are
	// written to the HTTP response; NewServeMux defaults it to prefixing every
	// key with MetadataHeaderPrefix.
	outgoingHeaderMatcher HeaderMatcherFunc

	// metadataAnnotators are the functions added by WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD

	// streamErrorHandler is set by WithStreamErrorHandler; NewServeMux presets
	// it to DefaultHTTPStreamErrorHandler.
	streamErrorHandler StreamErrorHandlerFunc

	// protoErrorHandler is set by WithProtoErrorHandler. When non-nil,
	// ServeHTTP reports routing failures through it instead of through
	// OtherErrorHandler.
	protoErrorHandler ProtoErrorHandlerFunc

	// disablePathLengthFallback is set by WithDisablePathLengthFallback and
	// makes isPathLengthFallback always report false.
	disablePathLengthFallback bool

	// lastMatchWins is set by WithLastMatchWins; see Handle.
	lastMatchWins bool
}
42
// ServeMuxOption is an option that can be given to a ServeMux on construction.
//
// NewServeMux applies the options in the order they are passed, before it
// fills in default header matchers for any left unset.
type ServeMuxOption func(*ServeMux)
45
46// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
47//
48// forwardResponseOption is an option that will be called on the relevant context.Context,
49// http.ResponseWriter, and proto.Message before every forwarded response.
50//
51// The message may be nil in the case where just a header is being sent.
52func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
53 return func(serveMux *ServeMux) {
54 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
55 }
56}
57
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
//
// It receives a header key and returns the (possibly transformed) key to use
// on the other side together with true when the header should be forwarded,
// or false when it should be dropped.
type HeaderMatcherFunc func(string) (string, bool)
60
61// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
62// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
63// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
64func DefaultHeaderMatcher(key string) (string, bool) {
65 key = textproto.CanonicalMIMEHeaderKey(key)
66 if isPermanentHTTPHeader(key) {
67 return MetadataPrefix + key, true
68 } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
69 return key[len(MetadataHeaderPrefix):], true
70 }
71 return "", false
72}
73
74// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
75//
76// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
77// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
78func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
79 return func(mux *ServeMux) {
80 mux.incomingHeaderMatcher = fn
81 }
82}
83
84// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
85//
86// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
87// passed to http response returned from gateway. To transform the header before passing to response,
88// matcher should return modified header.
89func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
90 return func(mux *ServeMux) {
91 mux.outgoingHeaderMatcher = fn
92 }
93}
94
95// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
96//
97// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
98// is reading token from cookie and adding it in gRPC context.
99func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
100 return func(serveMux *ServeMux) {
101 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
102 }
103}
104
105// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
106//
107// This can be used to handle an error as general proto message defined by gRPC.
108// The response including body and status is not backward compatible with the default error handler.
109// When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
110func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
111 return func(serveMux *ServeMux) {
112 serveMux.protoErrorHandler = fn
113 }
114}
115
116// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
117func WithDisablePathLengthFallback() ServeMuxOption {
118 return func(serveMux *ServeMux) {
119 serveMux.disablePathLengthFallback = true
120 }
121}
122
123// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
124// error handler, which allows for customizing the error trailer for server-streaming
125// calls.
126//
127// For stream errors that occur before any response has been written, the mux's
128// ProtoErrorHandler will be invoked. However, once data has been written, the errors must
129// be handled differently: they must be included in the response body. The response body's
130// final message will include the error details returned by the stream error handler.
131func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
132 return func(serveMux *ServeMux) {
133 serveMux.streamErrorHandler = fn
134 }
135}
136
137// WithLastMatchWins returns a ServeMuxOption that will enable "last
138// match wins" behavior, where if multiple path patterns match a
139// request path, the last one defined in the .proto file will be used.
140func WithLastMatchWins() ServeMuxOption {
141 return func(serveMux *ServeMux) {
142 serveMux.lastMatchWins = true
143 }
144}
145
146// NewServeMux returns a new ServeMux whose internal mapping is empty.
147func NewServeMux(opts ...ServeMuxOption) *ServeMux {
148 serveMux := &ServeMux{
149 handlers: make(map[string][]handler),
150 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
151 marshalers: makeMarshalerMIMERegistry(),
152 streamErrorHandler: DefaultHTTPStreamErrorHandler,
153 }
154
155 for _, opt := range opts {
156 opt(serveMux)
157 }
158
159 if serveMux.protoErrorHandler != nil {
160 HTTPError = serveMux.protoErrorHandler
161 // OtherErrorHandler is no longer used when protoErrorHandler is set.
162 // Overwritten by a special error handler to return Unknown.
163 OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
164 ctx := context.Background()
165 _, outboundMarshaler := MarshalerForRequest(serveMux, r)
166 sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
167 serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
168 }
169 }
170
171 if serveMux.incomingHeaderMatcher == nil {
172 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
173 }
174
175 if serveMux.outgoingHeaderMatcher == nil {
176 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
177 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
178 }
179 }
180
181 return serveMux
182}
183
184// Handle associates "h" to the pair of HTTP method and path pattern.
185func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
186 if s.lastMatchWins {
187 s.handlers[meth] = append([]handler{handler{pat: pat, h: h}}, s.handlers[meth]...)
188 } else {
189 s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
190 }
191}
192
193// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
194func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
195 ctx := r.Context()
196
197 path := r.URL.Path
198 if !strings.HasPrefix(path, "/") {
199 if s.protoErrorHandler != nil {
200 _, outboundMarshaler := MarshalerForRequest(s, r)
201 sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
202 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
203 } else {
204 OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
205 }
206 return
207 }
208
209 components := strings.Split(path[1:], "/")
210 l := len(components)
211 var verb string
212 if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
213 if s.protoErrorHandler != nil {
214 _, outboundMarshaler := MarshalerForRequest(s, r)
215 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
216 } else {
217 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
218 }
219 return
220 } else if idx > 0 {
221 c := components[l-1]
222 components[l-1], verb = c[:idx], c[idx+1:]
223 }
224
225 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
226 r.Method = strings.ToUpper(override)
227 if err := r.ParseForm(); err != nil {
228 if s.protoErrorHandler != nil {
229 _, outboundMarshaler := MarshalerForRequest(s, r)
230 sterr := status.Error(codes.InvalidArgument, err.Error())
231 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
232 } else {
233 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
234 }
235 return
236 }
237 }
238 for _, h := range s.handlers[r.Method] {
239 pathParams, err := h.pat.Match(components, verb)
240 if err != nil {
241 continue
242 }
243 h.h(w, r, pathParams)
244 return
245 }
246
247 // lookup other methods to handle fallback from GET to POST and
248 // to determine if it is MethodNotAllowed or NotFound.
249 for m, handlers := range s.handlers {
250 if m == r.Method {
251 continue
252 }
253 for _, h := range handlers {
254 pathParams, err := h.pat.Match(components, verb)
255 if err != nil {
256 continue
257 }
258 // X-HTTP-Method-Override is optional. Always allow fallback to POST.
259 if s.isPathLengthFallback(r) {
260 if err := r.ParseForm(); err != nil {
261 if s.protoErrorHandler != nil {
262 _, outboundMarshaler := MarshalerForRequest(s, r)
263 sterr := status.Error(codes.InvalidArgument, err.Error())
264 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
265 } else {
266 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
267 }
268 return
269 }
270 h.h(w, r, pathParams)
271 return
272 }
273 if s.protoErrorHandler != nil {
274 _, outboundMarshaler := MarshalerForRequest(s, r)
275 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
276 } else {
277 OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
278 }
279 return
280 }
281 }
282
283 if s.protoErrorHandler != nil {
284 _, outboundMarshaler := MarshalerForRequest(s, r)
285 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI)
286 } else {
287 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
288 }
289}
290
291// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
292func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
293 return s.forwardResponseOptions
294}
295
296func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
297 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
298}
299
// handler pairs a path pattern with the HandlerFunc registered for it via
// (*ServeMux).Handle.
type handler struct {
	// pat is matched against the request's path components and verb in
	// ServeHTTP.
	pat Pattern

	// h is invoked with the path parameters produced by a successful match.
	h HandlerFunc
}