blob: d8247bcdf69249555649fa5a6179eae883bb0962 [file] [log] [blame]
khenaidoo5fc5cea2021-08-11 17:39:16 -04001/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package transport
20
21import (
22 "bufio"
23 "bytes"
24 "encoding/base64"
25 "fmt"
26 "io"
27 "math"
28 "net"
29 "net/http"
30 "net/url"
31 "strconv"
32 "strings"
33 "time"
34 "unicode/utf8"
35
36 "github.com/golang/protobuf/proto"
37 "golang.org/x/net/http2"
38 "golang.org/x/net/http2/hpack"
39 spb "google.golang.org/genproto/googleapis/rpc/status"
40 "google.golang.org/grpc/codes"
41 "google.golang.org/grpc/grpclog"
42 "google.golang.org/grpc/status"
43)
44
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame. newFramer
	// passes it to the framer's SetMaxReadFrameSize.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the table size newFramer hands to
	// hpack.NewDecoder. See http://http2.github.io/http2-spec/#SettingValues
	http2InitHeaderTableSize = 4096
	// NOTE(review): the paragraph below documents baseContentType, but no
	// constant of that name is declared in this block. Presumably the
	// declaration moved elsewhere and the comment was left behind; confirm,
	// then relocate or delete it.
	//
	// baseContentType is the base content-type for gRPC. This is a valid
	// content-type on its own, but can also include a content-subtype such as
	// "proto" as a suffix after "+" or ";". See
	// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
	// for more details.

)
57
var (
	// clientPreface is http2.ClientPreface as a byte slice.
	clientPreface = []byte(http2.ClientPreface)
	// http2ErrConvTab maps an HTTP/2 error code to the gRPC status code
	// associated with it in this package.
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
	// Only the statuses listed here have an entry; a lookup for any other
	// status yields the map's zero value, so callers should use the
	// two-value form to detect a miss.
	HTTPStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
	// logger is the grpclog component logger registered under the name
	// "transport".
	logger = grpclog.Component("transport")
)
97
// isReservedHeader reports whether hdr is one of the HTTP2 header names that
// the gRPC protocol keeps for itself. Every other name is treated as
// user-specified metadata. The comparison is exact (case-sensitive).
func isReservedHeader(hdr string) bool {
	// Pseudo-headers such as ":method" or ":path" are always reserved.
	if strings.HasPrefix(hdr, ":") {
		return true
	}
	switch hdr {
	case "content-type", "user-agent", "te":
		return true
	case "grpc-message-type",
		"grpc-encoding",
		"grpc-message",
		"grpc-status",
		"grpc-timeout",
		"grpc-status-details-bin":
		// grpc-previous-rpc-attempts and grpc-retry-pushback-ms are
		// deliberately absent: they are "reserved" by the protocol, but
		// their API intentionally works via metadata.
		return true
	}
	return false
}
123
// isWhitelistedHeader reports whether hdr, although classified as reserved by
// isReservedHeader, should still be surfaced to users through metadata.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
134
// binHdrSuffix marks metadata keys whose values are base64-encoded by
// encodeMetadataHeader and base64-decoded by decodeMetadataHeader.
const binHdrSuffix = "-bin"

// encodeBinHeader returns v encoded as standard base64 without padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	out := make([]byte, enc.EncodedLen(len(v)))
	enc.Encode(out, v)
	return string(out)
}

// decodeBinHeader decodes a standard-alphabet base64 string that may or may
// not carry '=' padding.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		// A length that is a multiple of four means the input was either
		// padded or needed no padding; the padded decoder handles both.
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}

// encodeMetadataHeader returns the wire form of metadata value v for key k:
// base64 for keys ending in binHdrSuffix, v itself otherwise.
func encodeMetadataHeader(k, v string) string {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v
	}
	return encodeBinHeader([]byte(v))
}

// decodeMetadataHeader reverses encodeMetadataHeader. For keys ending in
// binHdrSuffix it returns the decoder's error alongside whatever bytes were
// decoded; other values are returned untouched.
func decodeMetadataHeader(k, v string) (string, error) {
	if !strings.HasSuffix(k, binHdrSuffix) {
		return v, nil
	}
	raw, err := decodeBinHeader(v)
	return string(raw), err
}
163
164func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) {
165 v, err := decodeBinHeader(rawDetails)
166 if err != nil {
167 return nil, err
168 }
169 st := &spb.Status{}
170 if err = proto.Unmarshal(v, st); err != nil {
171 return nil, err
172 }
173 return status.FromProto(st), nil
174}
175
// timeoutUnit is the single-character unit suffix of an encoded timeout.
type timeoutUnit uint8

const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the duration of one unit u. ok is false, and
// d is zero, when u is not one of the timeoutUnit constants.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		return time.Hour, true
	case minute:
		return time.Minute, true
	case second:
		return time.Second, true
	case millisecond:
		return time.Millisecond, true
	case microsecond:
		return time.Microsecond, true
	case nanosecond:
		return time.Nanosecond, true
	default:
		return 0, false
	}
}

// decodeTimeout parses an encoded timeout of the form <digits><unit>, where
// <digits> is one to eight ASCII decimal digits and <unit> is one of
// H, M, S, m, u or n.
//
// It returns an error when s is shorter than two bytes or longer than nine,
// when the unit is not recognized, or when the numeric part is not an
// unsigned decimal integer. Signed values such as "-1S" or "+1S" are
// rejected. A value in hours that would overflow time.Duration is clamped to
// math.MaxInt64 nanoseconds rather than reported as an error.
func decodeTimeout(s string) (time.Duration, error) {
	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
	}
	if size > 9 {
		// Spec allows for 8 digits plus the unit.
		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
	}
	unit := timeoutUnit(s[size-1])
	d, ok := timeoutUnitToDuration(unit)
	if !ok {
		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
	}
	// ParseUint, unlike ParseInt, refuses a leading sign, so a negative or
	// explicitly-signed timeout is an error instead of a negative duration.
	t, err := strconv.ParseUint(s[:size-1], 10, 64)
	if err != nil {
		return 0, err
	}
	// With at most eight digits only the hour unit can exceed the range of
	// time.Duration, so that is the only unit that needs an overflow guard.
	const maxHours = math.MaxInt64 / uint64(time.Hour)
	if d == time.Hour && t > maxHours {
		// This timeout would overflow math.MaxInt64; clamp it.
		return time.Duration(math.MaxInt64), nil
	}
	return d * time.Duration(t), nil
}
231
// Bounds of the byte range that may appear unescaped in "grpc-message", and
// the escape introducer, which must itself always be escaped.
const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)

// encodeGrpcMessage is used to encode status code in header field
// "grpc-message". Bytes in the range [spaceByte, tildeByte] other than
// percentByte pass through unchanged; everything else is percent-encoded,
// and invalid utf-8 is replaced by the Unicode replacement character before
// being encoded.
//
// A message that needs no escaping is returned as-is without copying.
func encodeGrpcMessage(msg string) string {
	for i := 0; i < len(msg); i++ {
		if b := msg[i]; b < spaceByte || b > tildeByte || b == percentByte {
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

// encodeGrpcMessageUnchecked performs the encoding described on
// encodeGrpcMessage unconditionally, one rune at a time.
func encodeGrpcMessageUnchecked(msg string) string {
	const hexDigits = "0123456789ABCDEF"
	var out bytes.Buffer
	for rest := msg; len(rest) > 0; {
		r, width := utf8.DecodeRuneInString(rest)
		rest = rest[width:]
		// Re-encoding r (instead of copying the source bytes) is what turns
		// an invalid byte, decoded as utf8.RuneError with width 1, into the
		// three-byte replacement character. Every byte of a multi-byte rune
		// lies above tildeByte, so a single range check covers both the
		// ASCII and the non-ASCII case.
		for _, b := range []byte(string(r)) {
			if b >= spaceByte && b <= tildeByte && b != percentByte {
				out.WriteByte(b)
				continue
			}
			out.WriteByte(percentByte)
			out.WriteByte(hexDigits[b>>4])
			out.WriteByte(hexDigits[b&0x0F])
		}
	}
	return out.String()
}
284
285// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
286func decodeGrpcMessage(msg string) string {
287 if msg == "" {
288 return ""
289 }
290 lenMsg := len(msg)
291 for i := 0; i < lenMsg; i++ {
292 if msg[i] == percentByte && i+2 < lenMsg {
293 return decodeGrpcMessageUnchecked(msg)
294 }
295 }
296 return msg
297}
298
299func decodeGrpcMessageUnchecked(msg string) string {
300 var buf bytes.Buffer
301 lenMsg := len(msg)
302 for i := 0; i < lenMsg; i++ {
303 c := msg[i]
304 if c == percentByte && i+2 < lenMsg {
305 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
306 if err != nil {
307 buf.WriteByte(c)
308 } else {
309 buf.WriteByte(byte(parsed))
310 i += 2
311 }
312 } else {
313 buf.WriteByte(c)
314 }
315 }
316 return buf.String()
317}
318
// bufWriter accumulates writes in memory and forwards them to conn in
// batches. The first error returned by a flush is remembered in err and is
// returned by every later Write and Flush.
type bufWriter struct {
	buf       []byte // staging area; newBufWriter sizes it to 2*batchSize
	offset    int    // number of bytes in buf not yet written to conn
	batchSize int    // Write flushes once offset reaches this; 0 disables buffering
	conn      net.Conn
	err       error // sticky error from a failed flush

	onFlush func() // optional hook; Flush calls it just before writing to conn
}

// newBufWriter returns a bufWriter that forwards to conn. A batchSize of 0
// disables buffering, in which case Write goes straight to conn.
func newBufWriter(conn net.Conn, batchSize int) *bufWriter {
	return &bufWriter{
		buf:       make([]byte, batchSize*2),
		batchSize: batchSize,
		conn:      conn,
	}
}

// Write copies b into the internal buffer, flushing to conn every time at
// least batchSize bytes have accumulated. It returns the number of bytes
// taken from b.
//
// Write stops at the first flush error and returns it together with the
// number of bytes consumed so far. Continuing after a failure can never make
// progress: Flush refuses to run once err is set, the buffer fills up, and
// copy then moves zero bytes forever.
func (w *bufWriter) Write(b []byte) (n int, err error) {
	if w.err != nil {
		return 0, w.err
	}
	if w.batchSize == 0 { // Buffer has been disabled.
		return w.conn.Write(b)
	}
	for len(b) > 0 {
		nn := copy(w.buf[w.offset:], b)
		b = b[nn:]
		w.offset += nn
		n += nn
		if w.offset < w.batchSize {
			continue
		}
		if err = w.Flush(); err != nil {
			return n, err
		}
	}
	return n, nil
}

// Flush writes any buffered bytes to conn and empties the buffer. It is a
// no-op when nothing is buffered, and returns the sticky error without
// touching conn when an earlier flush has already failed.
func (w *bufWriter) Flush() error {
	if w.err != nil {
		return w.err
	}
	if w.offset == 0 {
		return nil
	}
	if w.onFlush != nil {
		w.onFlush()
	}
	_, w.err = w.conn.Write(w.buf[:w.offset])
	w.offset = 0
	return w.err
}
370
371type framer struct {
372 writer *bufWriter
373 fr *http2.Framer
374}
375
376func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer {
377 if writeBufferSize < 0 {
378 writeBufferSize = 0
379 }
380 var r io.Reader = conn
381 if readBufferSize > 0 {
382 r = bufio.NewReaderSize(r, readBufferSize)
383 }
384 w := newBufWriter(conn, writeBufferSize)
385 f := &framer{
386 writer: w,
387 fr: http2.NewFramer(w, r),
388 }
389 f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
390 // Opt-in to Frame reuse API on framer to reduce garbage.
391 // Frames aren't safe to read from after a subsequent call to ReadFrame.
392 f.fr.SetReuseFrames()
393 f.fr.MaxHeaderListSize = maxHeaderListSize
394 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
395 return f
396}
397
// parseDialTarget returns the network and address to pass to dialer.
//
// Targets using the "unix" scheme yield network "unix" and the socket
// address; every other target, including one that fails to parse as a URL,
// is returned unchanged with network "tcp".
func parseDialTarget(target string) (string, string) {
	const defaultNetwork = "tcp"

	colon := strings.Index(target, ":")
	colonSlash := strings.Index(target, ":/")

	switch {
	case colonSlash >= 0:
		// scheme:/path or scheme://host/path, which url.Parse understands.
		u, err := url.Parse(target)
		if err != nil || u.Scheme != "unix" {
			return defaultNetwork, target
		}
		if u.Path != "" {
			return "unix", u.Path
		}
		return "unix", u.Host
	case colon >= 0 && target[:colon] == "unix":
		// handle unix:addr which will fail with url.Parse
		return "unix", target[colon+1:]
	}
	return defaultNetwork, target
}
424}