blob: ec81e55b5ef994e8f6c5d518ddca9d52dc77434b [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001package runtime
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/textproto"
8 "strings"
9
10 "github.com/golang/protobuf/proto"
11 "google.golang.org/grpc/codes"
12 "google.golang.org/grpc/metadata"
13 "google.golang.org/grpc/status"
14)
15
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
//
// It receives the response writer, the incoming request, and the path
// parameters produced by matching the request path against the pattern the
// handler was registered with.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
18
// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
//
// Construct one with NewServeMux so that the handler map and the default
// header matchers are initialized.
type ServeMux struct {
	// handlers maps HTTP method to a list of handlers.
	handlers map[string][]handler
	// forwardResponseOptions collects the callbacks added through
	// WithForwardResponseOption; exposed via GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
	// marshalers is the marshaler registry; NewServeMux fills it with
	// makeMarshalerMIMERegistry().
	marshalers marshalerRegistry
	// incomingHeaderMatcher is set by WithIncomingHeaderMatcher and defaults
	// to DefaultHeaderMatcher.
	incomingHeaderMatcher HeaderMatcherFunc
	// outgoingHeaderMatcher is set by WithOutgoingHeaderMatcher and defaults
	// to a matcher that prepends MetadataHeaderPrefix to every key.
	outgoingHeaderMatcher HeaderMatcherFunc
	// metadataAnnotators collects the annotators added through WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD
	// protoErrorHandler is set by WithProtoErrorHandler. When non-nil,
	// ServeHTTP reports routing errors through it instead of OtherErrorHandler.
	protoErrorHandler ProtoErrorHandlerFunc
	// disablePathLengthFallback is set by WithDisablePathLengthFallback and
	// makes isPathLengthFallback always report false.
	disablePathLengthFallback bool
}
32
// ServeMuxOption is an option that can be given to a ServeMux on construction.
//
// NewServeMux applies each option, in the order given, to the freshly
// allocated ServeMux before filling in defaults.
type ServeMuxOption func(*ServeMux)
35
36// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
37//
38// forwardResponseOption is an option that will be called on the relevant context.Context,
39// http.ResponseWriter, and proto.Message before every forwarded response.
40//
41// The message may be nil in the case where just a header is being sent.
42func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
43 return func(serveMux *ServeMux) {
44 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
45 }
46}
47
// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
//
// It takes the header key and returns the (possibly rewritten) key to use
// together with a boolean reporting whether the header should be forwarded.
type HeaderMatcherFunc func(string) (string, bool)
50
51// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
52// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
53// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
54func DefaultHeaderMatcher(key string) (string, bool) {
55 key = textproto.CanonicalMIMEHeaderKey(key)
56 if isPermanentHTTPHeader(key) {
57 return MetadataPrefix + key, true
58 } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
59 return key[len(MetadataHeaderPrefix):], true
60 }
61 return "", false
62}
63
64// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
65//
66// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
67// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
68func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
69 return func(mux *ServeMux) {
70 mux.incomingHeaderMatcher = fn
71 }
72}
73
74// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
75//
76// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
77// passed to http response returned from gateway. To transform the header before passing to response,
78// matcher should return modified header.
79func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
80 return func(mux *ServeMux) {
81 mux.outgoingHeaderMatcher = fn
82 }
83}
84
85// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
86//
87// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
88// is reading token from cookie and adding it in gRPC context.
89func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
90 return func(serveMux *ServeMux) {
91 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
92 }
93}
94
95// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
96//
97// This can be used to handle an error as general proto message defined by gRPC.
98// The response including body and status is not backward compatible with the default error handler.
99// When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
100func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
101 return func(serveMux *ServeMux) {
102 serveMux.protoErrorHandler = fn
103 }
104}
105
106// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
107func WithDisablePathLengthFallback() ServeMuxOption {
108 return func(serveMux *ServeMux) {
109 serveMux.disablePathLengthFallback = true
110 }
111}
112
113// NewServeMux returns a new ServeMux whose internal mapping is empty.
114func NewServeMux(opts ...ServeMuxOption) *ServeMux {
115 serveMux := &ServeMux{
116 handlers: make(map[string][]handler),
117 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
118 marshalers: makeMarshalerMIMERegistry(),
119 }
120
121 for _, opt := range opts {
122 opt(serveMux)
123 }
124
125 if serveMux.protoErrorHandler != nil {
126 HTTPError = serveMux.protoErrorHandler
127 // OtherErrorHandler is no longer used when protoErrorHandler is set.
128 // Overwritten by a special error handler to return Unknown.
129 OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
130 ctx := context.Background()
131 _, outboundMarshaler := MarshalerForRequest(serveMux, r)
132 sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
133 serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
134 }
135 }
136
137 if serveMux.incomingHeaderMatcher == nil {
138 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
139 }
140
141 if serveMux.outgoingHeaderMatcher == nil {
142 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
143 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
144 }
145 }
146
147 return serveMux
148}
149
150// Handle associates "h" to the pair of HTTP method and path pattern.
151func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
152 s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
153}
154
155// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
156func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
157 ctx := r.Context()
158
159 path := r.URL.Path
160 if !strings.HasPrefix(path, "/") {
161 if s.protoErrorHandler != nil {
162 _, outboundMarshaler := MarshalerForRequest(s, r)
163 sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
164 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
165 } else {
166 OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
167 }
168 return
169 }
170
171 components := strings.Split(path[1:], "/")
172 l := len(components)
173 var verb string
174 if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
175 if s.protoErrorHandler != nil {
176 _, outboundMarshaler := MarshalerForRequest(s, r)
177 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
178 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
179 } else {
180 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
181 }
182 return
183 } else if idx > 0 {
184 c := components[l-1]
185 components[l-1], verb = c[:idx], c[idx+1:]
186 }
187
188 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
189 r.Method = strings.ToUpper(override)
190 if err := r.ParseForm(); err != nil {
191 if s.protoErrorHandler != nil {
192 _, outboundMarshaler := MarshalerForRequest(s, r)
193 sterr := status.Error(codes.InvalidArgument, err.Error())
194 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
195 } else {
196 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
197 }
198 return
199 }
200 }
201 for _, h := range s.handlers[r.Method] {
202 pathParams, err := h.pat.Match(components, verb)
203 if err != nil {
204 continue
205 }
206 h.h(w, r, pathParams)
207 return
208 }
209
210 // lookup other methods to handle fallback from GET to POST and
211 // to determine if it is MethodNotAllowed or NotFound.
212 for m, handlers := range s.handlers {
213 if m == r.Method {
214 continue
215 }
216 for _, h := range handlers {
217 pathParams, err := h.pat.Match(components, verb)
218 if err != nil {
219 continue
220 }
221 // X-HTTP-Method-Override is optional. Always allow fallback to POST.
222 if s.isPathLengthFallback(r) {
223 if err := r.ParseForm(); err != nil {
224 if s.protoErrorHandler != nil {
225 _, outboundMarshaler := MarshalerForRequest(s, r)
226 sterr := status.Error(codes.InvalidArgument, err.Error())
227 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
228 } else {
229 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
230 }
231 return
232 }
233 h.h(w, r, pathParams)
234 return
235 }
236 if s.protoErrorHandler != nil {
237 _, outboundMarshaler := MarshalerForRequest(s, r)
238 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusMethodNotAllowed))
239 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
240 } else {
241 OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
242 }
243 return
244 }
245 }
246
247 if s.protoErrorHandler != nil {
248 _, outboundMarshaler := MarshalerForRequest(s, r)
249 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
250 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
251 } else {
252 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
253 }
254}
255
256// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
257func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
258 return s.forwardResponseOptions
259}
260
261func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
262 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
263}
264
// handler pairs a path pattern with the function that serves requests
// matching it. Entries are created by ServeMux.Handle.
type handler struct {
	// pat is the path pattern matched against the request's path components and verb.
	pat Pattern
	// h is invoked with the path parameters returned by a successful match.
	h HandlerFunc
}