blob: f8083821f3d4b8b3de68521a5055d7d17988e8f7 [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package runtime
2
3import (
4 "context"
5 "encoding/base64"
6 "fmt"
7 "net"
8 "net/http"
9 "net/textproto"
10 "strconv"
11 "strings"
12 "time"
13
14 "google.golang.org/grpc/codes"
15 "google.golang.org/grpc/grpclog"
16 "google.golang.org/grpc/metadata"
17 "google.golang.org/grpc/status"
18)
19
const (
	// MetadataHeaderPrefix is the HTTP header prefix that marks custom
	// metadata parameters passed to or from a gRPC call.
	MetadataHeaderPrefix = "Grpc-Metadata-"

	// MetadataPrefix is prepended to permanent HTTP header keys (as
	// specified by the IANA) when they are added to the gRPC context.
	MetadataPrefix = "grpcgateway-"

	// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted
	// to HTTP headers in a response handled by grpc-gateway.
	MetadataTrailerPrefix = "Grpc-Trailer-"
)

// Internal header names and suffixes consulted when annotating a request
// context with gRPC metadata.
const (
	metadataGrpcTimeout        = "Grpc-Timeout"
	metadataHeaderBinarySuffix = "-Bin"

	xForwardedFor  = "X-Forwarded-For"
	xForwardedHost = "X-Forwarded-Host"
)

// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a
// Grpc-Timeout inbound header isn't present. If the value is 0 the sent
// `context` will not have a timeout.
var DefaultContextTimeout = 0 * time.Second
43
// decodeBinHeader base64-decodes the value of a "-Bin" header. A value whose
// length is a multiple of four is decoded as standard (padded) base64; any
// other length is decoded as unpadded base64.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// Either the sender padded the value or no padding was needed.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
51
52/*
53AnnotateContext adds context information such as metadata from the request.
54
55At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
56except that the forwarded destination is not another HTTP service but rather
57a gRPC service.
58*/
59func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
60 ctx, md, err := annotateContext(ctx, mux, req)
61 if err != nil {
62 return nil, err
63 }
64 if md == nil {
65 return ctx, nil
66 }
67
68 return metadata.NewOutgoingContext(ctx, md), nil
69}
70
71// AnnotateIncomingContext adds context information such as metadata from the request.
72// Attach metadata as incoming context.
73func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
74 ctx, md, err := annotateContext(ctx, mux, req)
75 if err != nil {
76 return nil, err
77 }
78 if md == nil {
79 return ctx, nil
80 }
81
82 return metadata.NewIncomingContext(ctx, md), nil
83}
84
85func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, metadata.MD, error) {
86 var pairs []string
87 timeout := DefaultContextTimeout
88 if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
89 var err error
90 timeout, err = timeoutDecode(tm)
91 if err != nil {
92 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
93 }
94 }
95
96 for key, vals := range req.Header {
97 for _, val := range vals {
98 key = textproto.CanonicalMIMEHeaderKey(key)
99 // For backwards-compatibility, pass through 'authorization' header with no prefix.
100 if key == "Authorization" {
101 pairs = append(pairs, "authorization", val)
102 }
103 if h, ok := mux.incomingHeaderMatcher(key); ok {
104 // Handles "-bin" metadata in grpc, since grpc will do another base64
105 // encode before sending to server, we need to decode it first.
106 if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
107 b, err := decodeBinHeader(val)
108 if err != nil {
109 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
110 }
111
112 val = string(b)
113 }
114 pairs = append(pairs, h, val)
115 }
116 }
117 }
118 if host := req.Header.Get(xForwardedHost); host != "" {
119 pairs = append(pairs, strings.ToLower(xForwardedHost), host)
120 } else if req.Host != "" {
121 pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
122 }
123
124 if addr := req.RemoteAddr; addr != "" {
125 if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
126 if fwd := req.Header.Get(xForwardedFor); fwd == "" {
127 pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
128 } else {
129 pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
130 }
131 } else {
132 grpclog.Infof("invalid remote addr: %s", addr)
133 }
134 }
135
136 if timeout != 0 {
137 ctx, _ = context.WithTimeout(ctx, timeout)
138 }
139 if len(pairs) == 0 {
140 return ctx, nil, nil
141 }
142 md := metadata.Pairs(pairs...)
143 for _, mda := range mux.metadataAnnotators {
144 md = metadata.Join(md, mda(ctx, req))
145 }
146 return ctx, md, nil
147}
148
149// ServerMetadata consists of metadata sent from gRPC server.
150type ServerMetadata struct {
151 HeaderMD metadata.MD
152 TrailerMD metadata.MD
153}
154
155type serverMetadataKey struct{}
156
157// NewServerMetadataContext creates a new context with ServerMetadata
158func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
159 return context.WithValue(ctx, serverMetadataKey{}, md)
160}
161
162// ServerMetadataFromContext returns the ServerMetadata in ctx
163func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
164 md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
165 return
166}
167
// timeoutDecode parses a Grpc-Timeout header value. The value is a decimal
// amount immediately followed by a single unit character (see
// timeoutUnitToDuration). For example, "5S" is five seconds and "100m" is
// 100 milliseconds.
//
// Negative amounts are rejected, since they would yield an already-expired
// deadline. Amounts too large to represent as a time.Duration are clamped to
// the maximum Duration instead of overflowing.
func timeoutDecode(s string) (time.Duration, error) {
	// Largest representable time.Duration (math.MaxInt64 nanoseconds).
	const maxDuration = time.Duration(1<<63 - 1)

	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("timeout string is too short: %q", s)
	}
	unit, ok := timeoutUnitToDuration(s[size-1])
	if !ok {
		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
	}
	amount, err := strconv.ParseInt(s[:size-1], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("timeout value is not an integer: %q: %w", s, err)
	}
	if amount < 0 {
		return 0, fmt.Errorf("timeout value is negative: %q", s)
	}
	if amount > int64(maxDuration/unit) {
		// unit*amount would overflow int64 and wrap to a bogus, possibly
		// negative, duration.
		return maxDuration, nil
	}
	return unit * time.Duration(amount), nil
}

// timeoutUnitToDuration maps a Grpc-Timeout unit character to the length of
// one unit: 'H' hour, 'M' minute, 'S' second, 'm' millisecond,
// 'u' microsecond, 'n' nanosecond. ok is false for any other byte.
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
	switch u {
	case 'H':
		return time.Hour, true
	case 'M':
		return time.Minute, true
	case 'S':
		return time.Second, true
	case 'm':
		return time.Millisecond, true
	case 'u':
		return time.Microsecond, true
	case 'n':
		return time.Nanosecond, true
	}
	return 0, false
}
202
// isPermanentHTTPHeader checks whether hdr belongs to the list of
// permanent request headers maintained by IANA:
// https://www.iana.org/assignments/message-headers/message-headers.xml
//
// The comparison is exact and case-sensitive, so hdr is expected in canonical
// MIME form (e.g. "User-Agent", not "user-agent").
func isPermanentHTTPHeader(hdr string) bool {
	switch hdr {
	case
		"Accept",
		"Accept-Charset",
		"Accept-Language",
		"Accept-Ranges",
		"Authorization",
		"Cache-Control",
		"Content-Type",
		"Cookie",
		"Date",
		"Expect",
		"From",
		"Host",
		"If-Match",
		"If-Modified-Since",
		"If-None-Match",
		"If-Schedule-Tag-Match",
		"If-Unmodified-Since",
		"Max-Forwards",
		"Origin",
		"Pragma",
		"Referer",
		"User-Agent",
		"Via",
		"Warning":
		return true
	}
	return false
}