blob: 7c2a5e6286d1c8ec903e3c4ac17e263b1f970d14 [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001/*
2Copyright 2016 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package net
18
19import (
20 "bufio"
21 "bytes"
22 "context"
23 "crypto/tls"
24 "fmt"
25 "io"
26 "net"
27 "net/http"
28 "net/url"
29 "os"
30 "path"
31 "strconv"
32 "strings"
33
34 "github.com/golang/glog"
35 "golang.org/x/net/http2"
36)
37
// JoinPreservingTrailingSlash joins elem with path.Join and then restores a
// trailing slash if the last non-empty element ended with one (path.Join
// always strips it). Empty trailing elements are ignored when deciding.
func JoinPreservingTrailingSlash(elem ...string) string {
	joined := path.Join(elem...)

	// Walk backwards until we land on an element that has content; if every
	// element is empty, last stays "" and no slash is added.
	last := ""
	for i := len(elem) - 1; i >= 0 && last == ""; i-- {
		last = elem[i]
	}

	if strings.HasSuffix(last, "/") && !strings.HasSuffix(joined, "/") {
		joined += "/"
	}
	return joined
}
57
// IsProbableEOF reports whether err looks like a connection termination, which
// callers (for example a watch) may treat as an ordinary end of stream rather
// than a failure. A *url.Error is unwrapped one level before matching. It
// recognizes io.EOF and these connection-closure messages from the Go
// network/HTTP stack:
//   - "http: can't write HTTP request on broken connection" (exact match)
//   - "connection reset by peer" (substring)
//   - "use of closed network connection" (case-insensitive substring)
//
// A nil error, a nil *url.Error, or a *url.Error wrapping nil is not treated as
// EOF.
func IsProbableEOF(err error) bool {
	if err == nil {
		return false
	}
	if uerr, ok := err.(*url.Error); ok {
		// Guard against a typed-nil *url.Error or one with a nil Err: both would
		// make the Error() call below panic.
		if uerr == nil || uerr.Err == nil {
			return false
		}
		err = uerr.Err
	}
	msg := err.Error()
	switch {
	case err == io.EOF:
		return true
	case msg == "http: can't write HTTP request on broken connection":
		return true
	case strings.Contains(msg, "connection reset by peer"):
		return true
	case strings.Contains(strings.ToLower(msg), "use of closed network connection"):
		return true
	}
	return false
}
83
// defaultTransport is http.DefaultTransport asserted to its concrete
// *http.Transport type. SetOldTransportDefaults copies its DialContext and
// TLSHandshakeTimeout into transports that leave them unset. The unchecked
// type assertion panics during package initialization if http.DefaultTransport
// is not a *http.Transport at that point.
var defaultTransport = http.DefaultTransport.(*http.Transport)
85
86// SetOldTransportDefaults applies the defaults from http.DefaultTransport
87// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
88func SetOldTransportDefaults(t *http.Transport) *http.Transport {
89 if t.Proxy == nil || isDefault(t.Proxy) {
90 // http.ProxyFromEnvironment doesn't respect CIDRs and that makes it impossible to exclude things like pod and service IPs from proxy settings
91 // ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
92 t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
93 }
94 // If no custom dialer is set, use the default context dialer
95 if t.DialContext == nil && t.Dial == nil {
96 t.DialContext = defaultTransport.DialContext
97 }
98 if t.TLSHandshakeTimeout == 0 {
99 t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
100 }
101 return t
102}
103
104// SetTransportDefaults applies the defaults from http.DefaultTransport
105// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
106func SetTransportDefaults(t *http.Transport) *http.Transport {
107 t = SetOldTransportDefaults(t)
108 // Allow clients to disable http2 if needed.
109 if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 {
110 glog.Infof("HTTP2 has been explicitly disabled")
111 } else {
112 if err := http2.ConfigureTransport(t); err != nil {
113 glog.Warningf("Transport failed http2 configuration: %v", err)
114 }
115 }
116 return t
117}
118
// RoundTripperWrapper is implemented by http.RoundTripper decorators that can
// expose the round tripper they wrap. DialerFor and TLSClientConfig use it to
// unwrap layered transports recursively.
type RoundTripperWrapper interface {
	http.RoundTripper
	// WrappedRoundTripper returns the round tripper this one delegates to.
	WrappedRoundTripper() http.RoundTripper
}
123
// DialFunc is a context-aware dial function with the same shape as
// http.Transport.DialContext: given a network name and an address, it returns
// an opened connection or an error.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
125
126func DialerFor(transport http.RoundTripper) (DialFunc, error) {
127 if transport == nil {
128 return nil, nil
129 }
130
131 switch transport := transport.(type) {
132 case *http.Transport:
133 // transport.DialContext takes precedence over transport.Dial
134 if transport.DialContext != nil {
135 return transport.DialContext, nil
136 }
137 // adapt transport.Dial to the DialWithContext signature
138 if transport.Dial != nil {
139 return func(ctx context.Context, net, addr string) (net.Conn, error) {
140 return transport.Dial(net, addr)
141 }, nil
142 }
143 // otherwise return nil
144 return nil, nil
145 case RoundTripperWrapper:
146 return DialerFor(transport.WrappedRoundTripper())
147 default:
148 return nil, fmt.Errorf("unknown transport type: %T", transport)
149 }
150}
151
// TLSClientConfigHolder is implemented by round trippers that can report the
// TLS client configuration they use. TLSClientConfig checks for it after
// *http.Transport and before unwrapping a RoundTripperWrapper.
type TLSClientConfigHolder interface {
	// TLSClientConfig returns the TLS client configuration in use.
	TLSClientConfig() *tls.Config
}
155
156func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
157 if transport == nil {
158 return nil, nil
159 }
160
161 switch transport := transport.(type) {
162 case *http.Transport:
163 return transport.TLSClientConfig, nil
164 case TLSClientConfigHolder:
165 return transport.TLSClientConfig(), nil
166 case RoundTripperWrapper:
167 return TLSClientConfig(transport.WrappedRoundTripper())
168 default:
169 return nil, fmt.Errorf("unknown transport type: %T", transport)
170 }
171}
172
// FormatURL builds a URL from scheme and path whose host component is
// host:port as produced by net.JoinHostPort (IPv6 literals are bracketed).
func FormatURL(scheme string, host string, port int, path string) *url.URL {
	hostPort := net.JoinHostPort(host, strconv.Itoa(port))
	u := &url.URL{Scheme: scheme, Path: path}
	u.Host = hostPort
	return u
}
180
// GetHTTPClient returns the request's User-Agent header value, or "unknown"
// when the request has no (or an empty) User-Agent.
func GetHTTPClient(req *http.Request) string {
	ua := req.UserAgent()
	if ua == "" {
		return "unknown"
	}
	return ua
}
187
// SourceIPs returns the client addresses recorded for req, consulting in order:
//
//  1. X-Forwarded-For, a comma-separated list with one entry per proxy hop:
//     every entry that parses as an IP is returned, in header order;
//  2. X-Real-Ip, used only if X-Forwarded-For yielded no valid IP;
//  3. req.RemoteAddr, either in host:port form or as a bare IP.
//
// Surrounding whitespace is ignored and entries that are not valid IPs are
// skipped. It returns nil if no source yields a valid IP. Header values are
// taken as-is from the request; nothing here verifies them.
func SourceIPs(req *http.Request) []net.IP {
	hdr := req.Header

	// First check the X-Forwarded-For header for requests via proxy. It may hold
	// several IPs when the request passed through multiple proxies; collect all
	// of the valid ones, preserving their order.
	if hdrForwardedFor := hdr.Get("X-Forwarded-For"); hdrForwardedFor != "" {
		var forwardedForIPs []net.IP
		for _, part := range strings.Split(hdrForwardedFor, ",") {
			if ip := net.ParseIP(strings.TrimSpace(part)); ip != nil {
				forwardedForIPs = append(forwardedForIPs, ip)
			}
		}
		if len(forwardedForIPs) > 0 {
			return forwardedForIPs
		}
	}

	// Try the X-Real-Ip header, trimmed the same way as X-Forwarded-For entries.
	if hdrRealIP := strings.TrimSpace(hdr.Get("X-Real-Ip")); hdrRealIP != "" {
		if ip := net.ParseIP(hdrRealIP); ip != nil {
			return []net.IP{ip}
		}
	}

	// Fall back to the remote address, which gives the correct client IP when
	// there is no proxy. Go's HTTP server records it as host:port, so split it
	// first.
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if remoteIP := net.ParseIP(host); remoteIP != nil {
			return []net.IP{remoteIP}
		}
	}

	// Fallback if the remote address was just an IP.
	if remoteIP := net.ParseIP(req.RemoteAddr); remoteIP != nil {
		return []net.IP{remoteIP}
	}

	return nil
}
235
236// Extracts and returns the clients IP from the given request.
237// Looks at X-Forwarded-For header, X-Real-Ip header and request.RemoteAddr in that order.
238// Returns nil if none of them are set or is set to an invalid value.
239func GetClientIP(req *http.Request) net.IP {
240 ips := SourceIPs(req)
241 if len(ips) == 0 {
242 return nil
243 }
244 return ips[0]
245}
246
// AppendForwardedForHeader prepares req for another proxy hop by appending the
// host part of req.RemoteAddr to the X-Forwarded-For header. Existing
// X-Forwarded-For values are folded into a single ", "-separated value that
// ends with the new address. If req.RemoteAddr is not in host:port form, req
// is left unchanged.
func AppendForwardedForHeader(req *http.Request) {
	// Mirrors the equivalent logic in net/http/httputil/reverseproxy.go.
	senderIP, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return
	}

	value := senderIP
	// Earlier proxies already recorded their hops; keep them, oldest first.
	if prior, present := req.Header["X-Forwarded-For"]; present {
		value = strings.Join(prior, ", ") + ", " + senderIP
	}
	req.Header.Set("X-Forwarded-For", value)
}
261
// defaultProxyFuncPointer is the "%p" rendering of http.ProxyFromEnvironment,
// computed once at package initialization for comparison in isDefault.
var defaultProxyFuncPointer = fmt.Sprintf("%p", http.ProxyFromEnvironment)

// isDefault reports whether transportProxier is http.ProxyFromEnvironment.
// Func values cannot be compared with ==, so their "%p" formatting (the
// function's address) is compared instead.
func isDefault(transportProxier func(*http.Request) (*url.URL, error)) bool {
	return defaultProxyFuncPointer == fmt.Sprintf("%p", transportProxier)
}
269
// NewProxierWithNoProxyCIDR constructs a Proxier function that respects CIDRs
// in NO_PROXY and delegates if no matching CIDRs are found.
//
// NO_PROXY (or no_proxy, if NO_PROXY is empty) is read once, when this
// function is called. Its comma-separated entries are trimmed of surrounding
// whitespace, and those that parse as CIDRs are kept. If none do, delegate
// itself is returned. Otherwise the returned function returns (nil, nil), which
// means no proxy, for requests whose URL host is an IP literal inside one of
// those CIDRs. It calls delegate for everything else, including hostnames,
// which are not resolved.
func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) {
	// we wrap the default method, so we only need to perform our check if the NO_PROXY (or no_proxy) envvar has a CIDR in it
	noProxyEnv := os.Getenv("NO_PROXY")
	if noProxyEnv == "" {
		noProxyEnv = os.Getenv("no_proxy")
	}

	var cidrs []*net.IPNet
	for _, noProxyRule := range strings.Split(noProxyEnv, ",") {
		// Lists are commonly written as "a, b"; trim so " 10.0.0.0/8" still
		// parses (http.ProxyFromEnvironment trims entries the same way).
		_, cidr, err := net.ParseCIDR(strings.TrimSpace(noProxyRule))
		if err != nil {
			// Not a CIDR (hostname, domain suffix, plain IP, ...): leave it to delegate.
			continue
		}
		cidrs = append(cidrs, cidr)
	}

	if len(cidrs) == 0 {
		return delegate
	}

	return func(req *http.Request) (*url.URL, error) {
		ip := net.ParseIP(req.URL.Hostname())
		if ip == nil {
			return delegate(req)
		}

		for _, cidr := range cidrs {
			if cidr.Contains(ip) {
				return nil, nil
			}
		}

		return delegate(req)
	}
}
307
// DialerFunc adapts an ordinary function to the Dialer interface.
type DialerFunc func(req *http.Request) (net.Conn, error)

// Dial calls f(req) and returns its results unchanged.
func (f DialerFunc) Dial(req *http.Request) (net.Conn, error) {
	return f(req)
}
314
// Dialer dials a host and writes a request to it. ConnectWithRedirects uses a
// Dialer for the initial request and for every redirect it follows, then reads
// the backend response from the returned connection. DialerFunc adapts a plain
// function to this interface.
type Dialer interface {
	// Dial connects to the host specified by req's URL, writes the request to the connection, and
	// returns the opened net.Conn.
	Dial(req *http.Request) (net.Conn, error)
}
321
// ConnectWithRedirects uses dialer to send an originalMethod request for
// originalLocation (with header and originalBody) and follows "302 Found"
// redirects. Each redirect target is resolved relative to the location that
// returned it and requested with GET and no body. It is dialed on a new
// connection after the previous one has been closed.
//
// On success it returns the connection carrying the final response, plus the
// raw bytes read from that connection while parsing the response. Those bytes
// are everything the bufio.Reader pulled in, capped at 16384 bytes, and they
// may extend past the headers. The final response is one of:
//   - the first response that is not a 302;
//   - a response that could not be parsed (this is logged, not returned as an
//     error);
//   - when requireSameHostRedirects is true, a 302 whose Location hostname
//     differs from originalLocation's. It is handed back instead of being
//     followed.
//
// At most 9 redirects are followed: a 10th followable 302 produces a
// "too many redirects" error. A 302 with a missing or unparsable Location
// header is also an error. Every error return is (nil, nil, err), and the
// deferred cleanup closes whatever connection is still open.
//
// NOTE(review): header is assigned to every request without copying, so any
// changes the dialer makes to req.Header are visible to the caller and to later
// redirected requests.
// NOTE(review): only 302 is followed; 301, 303, 307 and 308 are returned as
// final responses. Confirm this is intended.
// NOTE(review): handing a cross-host redirect back to the caller lets a backend
// steer the client elsewhere. Upstream Kubernetes later made this case return
// an error instead (CVE-2020-8559). Confirm whether that fix is needed here.
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
	const (
		maxRedirects    = 9     // Fail on the 10th redirect
		maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
	)

	var (
		location         = originalLocation
		method           = originalMethod
		intermediateConn net.Conn
		rawResponse      = bytes.NewBuffer(make([]byte, 0, 256))
		body             = originalBody
	)

	// Close any connection still owned by this function on an error return.
	// The success path clears intermediateConn first so the caller keeps it.
	defer func() {
		if intermediateConn != nil {
			intermediateConn.Close()
		}
	}()

redirectLoop:
	for redirects := 0; ; redirects++ {
		if redirects > maxRedirects {
			return nil, nil, fmt.Errorf("too many redirects (%d)", redirects)
		}

		req, err := http.NewRequest(method, location.String(), body)
		if err != nil {
			return nil, nil, err
		}

		// The same header map is reused for every request in the chain.
		req.Header = header

		intermediateConn, err = dialer.Dial(req)
		if err != nil {
			return nil, nil, err
		}

		// Peek at the backend response.
		// Drop bytes captured from the previous (redirect) response first.
		rawResponse.Reset()
		respReader := bufio.NewReader(io.TeeReader(
			io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes.
			rawResponse)) // Save the raw response.
		resp, err := http.ReadResponse(respReader, nil)
		if err != nil {
			// Unable to read the backend response; let the client handle it.
			glog.Warningf("Error reading backend response: %v", err)
			break redirectLoop
		}

		switch resp.StatusCode {
		case http.StatusFound:
			// Redirect, continue.
		default:
			// Don't redirect.
			break redirectLoop
		}

		// Redirected requests switch to "GET" according to the HTTP spec:
		// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3
		method = "GET"
		// don't send a body when following redirects
		body = nil

		resp.Body.Close() // not used

		// Prepare to follow the redirect.
		redirectStr := resp.Header.Get("Location")
		if redirectStr == "" {
			return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode)
		}
		// We have to parse relative to the current location, NOT originalLocation. For example,
		// if we request http://foo.com/a and get back "http://bar.com/b", the result should be
		// http://bar.com/b. If we then make that request and get back a redirect to "/c", the result
		// should be http://bar.com/c, not http://foo.com/c.
		location, err = location.Parse(redirectStr)
		if err != nil {
			return nil, nil, fmt.Errorf("malformed Location header: %v", err)
		}

		// Only follow redirects to the same host. Otherwise, propagate the redirect response back.
		if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
			break redirectLoop
		}

		// Reset the connection.
		intermediateConn.Close()
		intermediateConn = nil
	}

	// Hand the final connection to the caller; clearing intermediateConn keeps
	// the deferred cleanup from closing it.
	connToReturn := intermediateConn
	intermediateConn = nil // Don't close the connection when we return it.
	return connToReturn, rawResponse.Bytes(), nil
}
419
// CloneRequest returns a copy of req in which only the Header map is
// deep-copied (via CloneHeader). Every other field, including URL and Body, is
// shared with req.
func CloneRequest(req *http.Request) *http.Request {
	clone := *req
	clone.Header = CloneHeader(req.Header)
	return &clone
}

// CloneHeader returns a deep copy of in: a new map whose value slices do not
// share backing arrays with in. A nil input yields an empty, non-nil header,
// and an empty value slice is copied as an empty, non-nil slice.
func CloneHeader(in http.Header) http.Header {
	out := make(http.Header, len(in))
	for name, vals := range in {
		out[name] = append(make([]string, 0, len(vals)), vals...)
	}
	return out
}