blob: f8f69bfb70fd9309875cfb567597ef02c0ed2e3a [file] [log] [blame]
Scott Baker2d897982019-09-24 11:50:08 -07001/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpc
20
import (
	"bufio"
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"time"
)
33
// proxyAuthHeaderKey is the request header used to send Basic credentials to
// the proxy during the HTTP CONNECT handshake.
const proxyAuthHeaderKey = "Proxy-Authorization"

// errDisabled indicates that no proxy should be used for the address.
var errDisabled = errors.New("proxy is disabled for the address")

// httpProxyFromEnvironment looks up the proxy for a request. It is a variable
// rather than a direct call so that tests can overwrite it.
var httpProxyFromEnvironment = http.ProxyFromEnvironment
42
43func mapAddress(ctx context.Context, address string) (*url.URL, error) {
44 req := &http.Request{
45 URL: &url.URL{
46 Scheme: "https",
47 Host: address,
48 },
49 }
50 url, err := httpProxyFromEnvironment(req)
51 if err != nil {
52 return nil, err
53 }
54 if url == nil {
55 return nil, errDisabled
56 }
57 return url, nil
58}
59
// bufConn is a net.Conn whose reads are served by r instead of the embedded
// Conn.
//
// http.ReadResponse needs a *bufio.Reader to parse a response from a net.Conn,
// and that reader may pull in more bytes than the response itself, keeping the
// surplus in its buffer. Routing all later reads through the same reader keeps
// those buffered bytes from being lost. Every other net.Conn method (Write,
// Close, deadlines, ...) goes straight to the embedded Conn.
type bufConn struct {
	// Conn is the underlying connection; it handles everything except Read.
	net.Conn
	// r is the source for Read, e.g. a bufio.Reader layered over Conn.
	r io.Reader
}

// Read satisfies io.Reader by delegating to the wrapped reader.
func (c *bufConn) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	return n, err
}
73
// basicAuth returns the base64 (standard encoding) form of "username:password",
// i.e. the credential part of an HTTP Basic authorization header value.
func basicAuth(username, password string) string {
	credentials := []byte(username + ":" + password)
	return base64.StdEncoding.EncodeToString(credentials)
}
78
79func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
80 defer func() {
81 if err != nil {
82 conn.Close()
83 }
84 }()
85
86 req := &http.Request{
87 Method: http.MethodConnect,
88 URL: &url.URL{Host: backendAddr},
89 Header: map[string][]string{"User-Agent": {grpcUA}},
90 }
91 if t := proxyURL.User; t != nil {
92 u := t.Username()
93 p, _ := t.Password()
94 req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
95 }
96
97 if err := sendHTTPRequest(ctx, req, conn); err != nil {
98 return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
99 }
100
101 r := bufio.NewReader(conn)
102 resp, err := http.ReadResponse(r, req)
103 if err != nil {
104 return nil, fmt.Errorf("reading server HTTP response: %v", err)
105 }
106 defer resp.Body.Close()
107 if resp.StatusCode != http.StatusOK {
108 dump, err := httputil.DumpResponse(resp, true)
109 if err != nil {
110 return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status)
111 }
112 return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
113 }
114
115 return &bufConn{Conn: conn, r: r}, nil
116}
117
118// newProxyDialer returns a dialer that connects to proxy first if necessary.
119// The returned dialer checks if a proxy is necessary, dial to the proxy with the
120// provided dialer, does HTTP CONNECT handshake and returns the connection.
121func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
122 return func(ctx context.Context, addr string) (conn net.Conn, err error) {
123 var newAddr string
124 proxyURL, err := mapAddress(ctx, addr)
125 if err != nil {
126 if err != errDisabled {
127 return nil, err
128 }
129 newAddr = addr
130 } else {
131 newAddr = proxyURL.Host
132 }
133
134 conn, err = dialer(ctx, newAddr)
135 if err != nil {
136 return
137 }
138 if proxyURL != nil {
139 // proxy is disabled if proxyURL is nil.
140 conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
141 }
142 return
143 }
144}
145
// sendHTTPRequest serializes req (with ctx attached) onto conn in HTTP/1.x
// wire format. A write failure is returned prefixed with
// "failed to write the HTTP request".
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
	if writeErr := req.WithContext(ctx).Write(conn); writeErr != nil {
		return fmt.Errorf("failed to write the HTTP request: %v", writeErr)
	}
	return nil
}