blob: 463084a7022b6eb8fac9dd51c7bc2e9a49a4e572 [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001package runtime
2
import (
	"context"
	"fmt"
	"mime"
	"net/http"
	"strings"

	"github.com/golang/protobuf/proto"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)
14
// HandlerFunc serves one registered (HTTP method, path pattern) pair.
// pathParams holds the values the pattern extracted from the request path.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
17
18// ServeMux is a request multiplexer for grpc-gateway.
19// It matches http requests to patterns and invokes the corresponding handler.
20type ServeMux struct {
21 // handlers maps HTTP method to a list of handlers.
22 handlers map[string][]handler
23 forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
24 marshalers marshalerRegistry
25 incomingHeaderMatcher HeaderMatcherFunc
26 outgoingHeaderMatcher HeaderMatcherFunc
27 metadataAnnotators []func(context.Context, *http.Request) metadata.MD
28 protoErrorHandler ProtoErrorHandlerFunc
29}
30
31// ServeMuxOption is an option that can be given to a ServeMux on construction.
32type ServeMuxOption func(*ServeMux)
33
34// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
35//
36// forwardResponseOption is an option that will be called on the relevant context.Context,
37// http.ResponseWriter, and proto.Message before every forwarded response.
38//
39// The message may be nil in the case where just a header is being sent.
40func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
41 return func(serveMux *ServeMux) {
42 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
43 }
44}
45
// HeaderMatcherFunc decides whether a header key is forwarded to or from the
// gRPC context. It returns the (possibly renamed) key and true to forward the
// header, or false to drop it.
type HeaderMatcherFunc func(key string) (string, bool)
48
49// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
50// keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
51// 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
52func DefaultHeaderMatcher(key string) (string, bool) {
53 if isPermanentHTTPHeader(key) {
54 return MetadataPrefix + key, true
55 } else if strings.HasPrefix(key, MetadataHeaderPrefix) {
56 return key[len(MetadataHeaderPrefix):], true
57 }
58 return "", false
59}
60
61// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
62//
63// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
64// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
65func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
66 return func(mux *ServeMux) {
67 mux.incomingHeaderMatcher = fn
68 }
69}
70
71// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
72//
73// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
74// passed to http response returned from gateway. To transform the header before passing to response,
75// matcher should return modified header.
76func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
77 return func(mux *ServeMux) {
78 mux.outgoingHeaderMatcher = fn
79 }
80}
81
82// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
83//
84// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
85// is reading token from cookie and adding it in gRPC context.
86func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
87 return func(serveMux *ServeMux) {
88 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
89 }
90}
91
92// WithProtoErrorHandler returns a ServeMuxOption for passing metadata to a gRPC context.
93//
94// This can be used to handle an error as general proto message defined by gRPC.
95// The response including body and status is not backward compatible with the default error handler.
96// When this option is used, HTTPError and OtherErrorHandler are overwritten on initialization.
97func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption {
98 return func(serveMux *ServeMux) {
99 serveMux.protoErrorHandler = fn
100 }
101}
102
103// NewServeMux returns a new ServeMux whose internal mapping is empty.
104func NewServeMux(opts ...ServeMuxOption) *ServeMux {
105 serveMux := &ServeMux{
106 handlers: make(map[string][]handler),
107 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
108 marshalers: makeMarshalerMIMERegistry(),
109 }
110
111 for _, opt := range opts {
112 opt(serveMux)
113 }
114
115 if serveMux.protoErrorHandler != nil {
116 HTTPError = serveMux.protoErrorHandler
117 // OtherErrorHandler is no longer used when protoErrorHandler is set.
118 // Overwritten by a special error handler to return Unknown.
119 OtherErrorHandler = func(w http.ResponseWriter, r *http.Request, _ string, _ int) {
120 ctx := context.Background()
121 _, outboundMarshaler := MarshalerForRequest(serveMux, r)
122 sterr := status.Error(codes.Unknown, "unexpected use of OtherErrorHandler")
123 serveMux.protoErrorHandler(ctx, serveMux, outboundMarshaler, w, r, sterr)
124 }
125 }
126
127 if serveMux.incomingHeaderMatcher == nil {
128 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
129 }
130
131 if serveMux.outgoingHeaderMatcher == nil {
132 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
133 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
134 }
135 }
136
137 return serveMux
138}
139
140// Handle associates "h" to the pair of HTTP method and path pattern.
141func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
142 s.handlers[meth] = append(s.handlers[meth], handler{pat: pat, h: h})
143}
144
145// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
146func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
147 ctx := r.Context()
148
149 path := r.URL.Path
150 if !strings.HasPrefix(path, "/") {
151 if s.protoErrorHandler != nil {
152 _, outboundMarshaler := MarshalerForRequest(s, r)
153 sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest))
154 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
155 } else {
156 OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
157 }
158 return
159 }
160
161 components := strings.Split(path[1:], "/")
162 l := len(components)
163 var verb string
164 if idx := strings.LastIndex(components[l-1], ":"); idx == 0 {
165 if s.protoErrorHandler != nil {
166 _, outboundMarshaler := MarshalerForRequest(s, r)
167 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
168 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
169 } else {
170 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
171 }
172 return
173 } else if idx > 0 {
174 c := components[l-1]
175 components[l-1], verb = c[:idx], c[idx+1:]
176 }
177
178 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) {
179 r.Method = strings.ToUpper(override)
180 if err := r.ParseForm(); err != nil {
181 if s.protoErrorHandler != nil {
182 _, outboundMarshaler := MarshalerForRequest(s, r)
183 sterr := status.Error(codes.InvalidArgument, err.Error())
184 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
185 } else {
186 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
187 }
188 return
189 }
190 }
191 for _, h := range s.handlers[r.Method] {
192 pathParams, err := h.pat.Match(components, verb)
193 if err != nil {
194 continue
195 }
196 h.h(w, r, pathParams)
197 return
198 }
199
200 // lookup other methods to handle fallback from GET to POST and
201 // to determine if it is MethodNotAllowed or NotFound.
202 for m, handlers := range s.handlers {
203 if m == r.Method {
204 continue
205 }
206 for _, h := range handlers {
207 pathParams, err := h.pat.Match(components, verb)
208 if err != nil {
209 continue
210 }
211 // X-HTTP-Method-Override is optional. Always allow fallback to POST.
212 if isPathLengthFallback(r) {
213 if err := r.ParseForm(); err != nil {
214 if s.protoErrorHandler != nil {
215 _, outboundMarshaler := MarshalerForRequest(s, r)
216 sterr := status.Error(codes.InvalidArgument, err.Error())
217 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
218 } else {
219 OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest)
220 }
221 return
222 }
223 h.h(w, r, pathParams)
224 return
225 }
226 if s.protoErrorHandler != nil {
227 _, outboundMarshaler := MarshalerForRequest(s, r)
228 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusMethodNotAllowed))
229 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
230 } else {
231 OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
232 }
233 return
234 }
235 }
236
237 if s.protoErrorHandler != nil {
238 _, outboundMarshaler := MarshalerForRequest(s, r)
239 sterr := status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))
240 s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr)
241 } else {
242 OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound)
243 }
244}
245
246// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
247func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
248 return s.forwardResponseOptions
249}
250
// isPathLengthFallback reports whether r is a POST whose body is an
// application/x-www-form-urlencoded form. ServeHTTP lets such requests reach
// handlers registered for another method, and honours X-HTTP-Method-Override
// only for them.
//
// The Content-Type is compared by media type, matching how
// http.Request.ParseForm interprets the header. Parameters such as
// "; charset=UTF-8" (sent by default by common browser HTTP libraries) and
// differences in letter case are accepted. A missing or unparsable
// Content-Type yields false.
func isPathLengthFallback(r *http.Request) bool {
	if r.Method != "POST" {
		return false
	}
	mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	if err != nil {
		return false
	}
	return mediaType == "application/x-www-form-urlencoded"
}
254
255type handler struct {
256 pat Pattern
257 h HandlerFunc
258}