/*
Copyright 2016 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17package net
18
19import (
20 "bufio"
21 "bytes"
22 "context"
23 "crypto/tls"
24 "fmt"
25 "io"
26 "net"
27 "net/http"
28 "net/url"
29 "os"
30 "path"
31 "strconv"
32 "strings"
33
34 "golang.org/x/net/http2"
35 "k8s.io/klog"
36)
37
// JoinPreservingTrailingSlash joins the given elements with path.Join and then
// restores the trailing slash that path.Join strips, provided the last
// non-empty element ended in one.
func JoinPreservingTrailingSlash(elem ...string) string {
	joined := path.Join(elem...)

	// Locate the last non-empty element; it alone decides whether a trailing
	// slash is wanted. If every element is empty, last stays "".
	last := ""
	for i := len(elem) - 1; i >= 0; i-- {
		if elem[i] != "" {
			last = elem[i]
			break
		}
	}

	wantsSlash := strings.HasSuffix(last, "/")
	hasSlash := strings.HasSuffix(joined, "/")
	if wantsSlash && !hasSlash {
		joined += "/"
	}
	return joined
}
57
// IsProbableEOF returns true if the given error resembles a connection
// termination scenario that would justify assuming that the watch is empty.
// The errors matched here are the general connection-closure errors the Go
// HTTP stack hands back; callers that need to tell an ordinary "the peer
// disconnected" condition apart from other failures should use this function.
//
// A *url.Error is unwrapped one level and its underlying cause is classified
// instead. A nil error, a nil *url.Error, or a *url.Error without an
// underlying cause is never treated as a probable EOF.
func IsProbableEOF(err error) bool {
	if err == nil {
		return false
	}
	if uerr, ok := err.(*url.Error); ok {
		// Without this guard a typed-nil *url.Error or one whose Err is nil
		// would panic on the err.Error() call below.
		if uerr == nil || uerr.Err == nil {
			return false
		}
		err = uerr.Err
	}
	msg := err.Error()
	switch {
	case err == io.EOF:
		return true
	case msg == "http: can't write HTTP request on broken connection":
		return true
	case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"):
		return true
	case strings.Contains(msg, "connection reset by peer"):
		return true
	case strings.Contains(strings.ToLower(msg), "use of closed network connection"):
		return true
	}
	return false
}
86
87var defaultTransport = http.DefaultTransport.(*http.Transport)
88
89// SetOldTransportDefaults applies the defaults from http.DefaultTransport
90// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
91func SetOldTransportDefaults(t *http.Transport) *http.Transport {
92 if t.Proxy == nil || isDefault(t.Proxy) {
93 // http.ProxyFromEnvironment doesn't respect CIDRs and that makes it impossible to exclude things like pod and service IPs from proxy settings
94 // ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
95 t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
96 }
97 // If no custom dialer is set, use the default context dialer
98 if t.DialContext == nil && t.Dial == nil {
99 t.DialContext = defaultTransport.DialContext
100 }
101 if t.TLSHandshakeTimeout == 0 {
102 t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
103 }
104 return t
105}
106
107// SetTransportDefaults applies the defaults from http.DefaultTransport
108// for the Proxy, Dial, and TLSHandshakeTimeout fields if unset
109func SetTransportDefaults(t *http.Transport) *http.Transport {
110 t = SetOldTransportDefaults(t)
111 // Allow clients to disable http2 if needed.
112 if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 {
113 klog.Infof("HTTP2 has been explicitly disabled")
114 } else {
115 if err := http2.ConfigureTransport(t); err != nil {
116 klog.Warningf("Transport failed http2 configuration: %v", err)
117 }
118 }
119 return t
120}
121
// RoundTripperWrapper is an http.RoundTripper that wraps another
// http.RoundTripper and can hand back the one it wraps.
type RoundTripperWrapper interface {
	http.RoundTripper
	WrappedRoundTripper() http.RoundTripper
}

// DialFunc is the signature of a context-aware dial function.
type DialFunc func(ctx context.Context, net, addr string) (net.Conn, error)

// DialerFor returns the dial function used by transport.
//
// For an *http.Transport it returns DialContext if set, otherwise Dial adapted
// to the DialFunc signature, otherwise nil. A RoundTripperWrapper is unwrapped
// and the lookup repeated on the wrapped round tripper. A nil transport yields
// (nil, nil); any other transport type yields an error.
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
	if transport == nil {
		return nil, nil
	}

	if httpTransport, ok := transport.(*http.Transport); ok {
		switch {
		case httpTransport.DialContext != nil:
			// DialContext takes precedence over Dial.
			return httpTransport.DialContext, nil
		case httpTransport.Dial != nil:
			// Adapt the context-less Dial; the context is dropped.
			adapted := func(ctx context.Context, network, address string) (net.Conn, error) {
				return httpTransport.Dial(network, address)
			}
			return adapted, nil
		default:
			return nil, nil
		}
	}

	if wrapper, ok := transport.(RoundTripperWrapper); ok {
		return DialerFor(wrapper.WrappedRoundTripper())
	}

	return nil, fmt.Errorf("unknown transport type: %T", transport)
}
154
155type TLSClientConfigHolder interface {
156 TLSClientConfig() *tls.Config
157}
158
159func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
160 if transport == nil {
161 return nil, nil
162 }
163
164 switch transport := transport.(type) {
165 case *http.Transport:
166 return transport.TLSClientConfig, nil
167 case TLSClientConfigHolder:
168 return transport.TLSClientConfig(), nil
169 case RoundTripperWrapper:
170 return TLSClientConfig(transport.WrappedRoundTripper())
171 default:
172 return nil, fmt.Errorf("unknown transport type: %T", transport)
173 }
174}
175
// FormatURL builds a *url.URL from the given scheme, host, port and path.
// Host and port are combined with net.JoinHostPort, so an IPv6 literal host is
// bracketed.
func FormatURL(scheme string, host string, port int, path string) *url.URL {
	hostPort := net.JoinHostPort(host, strconv.Itoa(port))

	var u url.URL
	u.Scheme = scheme
	u.Host = hostPort
	u.Path = path
	return &u
}
183
// GetHTTPClient returns the User-Agent reported by req, or "unknown" when the
// request carries none.
func GetHTTPClient(req *http.Request) string {
	userAgent := req.UserAgent()
	if userAgent == "" {
		return "unknown"
	}
	return userAgent
}
190
// SourceIPs reports the source IPs of req, consulting three sources in order
// and returning as soon as one yields a valid address:
//
//  1. the comma-separated X-Forwarded-For header (every valid entry is
//     returned, in header order; invalid entries are skipped),
//  2. the X-Real-Ip header,
//  3. req.RemoteAddr, first as host:port and then as a bare IP.
//
// It returns nil if none of them holds a valid IP.
func SourceIPs(req *http.Request) []net.IP {
	// Requests that passed through proxies carry the chain in
	// X-Forwarded-For; it may list several IPs when there were several hops.
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" {
		var chain []net.IP
		for _, candidate := range strings.Split(forwarded, ",") {
			if parsed := net.ParseIP(strings.TrimSpace(candidate)); parsed != nil {
				chain = append(chain, parsed)
			}
		}
		if len(chain) > 0 {
			return chain
		}
	}

	// Next, the X-Real-Ip header. An absent header parses as an invalid IP
	// and simply falls through.
	if realIP := net.ParseIP(req.Header.Get("X-Real-Ip")); realIP != nil {
		return []net.IP{realIP}
	}

	// Finally the remote address of the request, which is the client itself
	// when there is no proxy. Go's HTTP server reports it as host:port, so
	// try that form before treating the whole value as an IP.
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if remote := net.ParseIP(host); remote != nil {
			return []net.IP{remote}
		}
	}
	if remote := net.ParseIP(req.RemoteAddr); remote != nil {
		return []net.IP{remote}
	}

	return nil
}
238
239// Extracts and returns the clients IP from the given request.
240// Looks at X-Forwarded-For header, X-Real-Ip header and request.RemoteAddr in that order.
241// Returns nil if none of them are set or is set to an invalid value.
242func GetClientIP(req *http.Request) net.IP {
243 ips := SourceIPs(req)
244 if len(ips) == 0 {
245 return nil
246 }
247 return ips[0]
248}
249
// AppendForwardedForHeader prepares the X-Forwarded-For header for another
// forwarding hop by appending the previous sender's IP address (the host part
// of req.RemoteAddr) to the X-Forwarded-For chain. If req.RemoteAddr cannot be
// split into host and port, the request is left untouched.
func AppendForwardedForHeader(req *http.Request) {
	// Logic mirrors net/http/httputil/reverseproxy.go.
	senderIP, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return
	}

	chain := senderIP
	// If we aren't the first proxy, retain the prior X-Forwarded-For
	// information as a comma+space separated list, folding multiple header
	// values into one.
	if earlier, present := req.Header["X-Forwarded-For"]; present {
		chain = strings.Join(earlier, ", ") + ", " + senderIP
	}
	req.Header.Set("X-Forwarded-For", chain)
}
264
// defaultProxyFuncPointer is the %p rendering of http.ProxyFromEnvironment.
// Func values cannot be compared with ==, so their formatted pointers are
// compared instead.
var defaultProxyFuncPointer = fmt.Sprintf("%p", http.ProxyFromEnvironment)

// isDefault reports whether transportProxier is http.ProxyFromEnvironment,
// judged by comparing the formatted pointers of the two functions.
func isDefault(transportProxier func(*http.Request) (*url.URL, error)) bool {
	return fmt.Sprintf("%p", transportProxier) == defaultProxyFuncPointer
}
272
// NewProxierWithNoProxyCIDR constructs a proxier function that respects CIDRs
// in NO_PROXY and delegates if no matching CIDRs are found.
//
// The NO_PROXY environment variable (or no_proxy, when NO_PROXY is empty) is
// read once, when this function is called. It is split on commas, and each
// entry is trimmed of surrounding whitespace before being tried as a CIDR, so
// a list written as "example.com, 10.0.0.0/8" works. If no entry is a CIDR,
// delegate itself is returned. Otherwise the returned function answers
// (nil, nil) — meaning "no proxy" — for requests whose URL host is an IP
// literal inside one of the CIDRs, and calls delegate for everything else.
func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) {
	// We wrap the default method, so we only need to perform our check if the
	// NO_PROXY (or no_proxy) envvar has a CIDR in it.
	noProxyEnv := os.Getenv("NO_PROXY")
	if noProxyEnv == "" {
		noProxyEnv = os.Getenv("no_proxy")
	}

	var cidrs []*net.IPNet
	for _, noProxyRule := range strings.Split(noProxyEnv, ",") {
		// Entries that are not CIDRs (host names, plain IPs, empty strings)
		// are expected here and are left for delegate to interpret, so a
		// parse failure is deliberately not treated as an error.
		_, cidr, err := net.ParseCIDR(strings.TrimSpace(noProxyRule))
		if err != nil {
			continue
		}
		cidrs = append(cidrs, cidr)
	}

	if len(cidrs) == 0 {
		return delegate
	}

	return func(req *http.Request) (*url.URL, error) {
		// Only IP-literal hosts can match a CIDR; names go to delegate.
		ip := net.ParseIP(req.URL.Hostname())
		if ip == nil {
			return delegate(req)
		}

		for _, cidr := range cidrs {
			if cidr.Contains(ip) {
				return nil, nil
			}
		}

		return delegate(req)
	}
}
310
// Dialer dials a host and writes a request to it.
type Dialer interface {
	// Dial connects to the host specified by req's URL, writes the request
	// to the connection, and returns the opened net.Conn.
	Dial(req *http.Request) (net.Conn, error)
}

// DialerFunc implements Dialer for the provided function.
type DialerFunc func(req *http.Request) (net.Conn, error)

// Dial invokes the underlying function with req and returns its results
// unchanged.
func (fn DialerFunc) Dial(req *http.Request) (net.Conn, error) {
	conn, err := fn(req)
	return conn, err
}
324
325// ConnectWithRedirects uses dialer to send req, following up to 10 redirects (relative to
326// originalLocation). It returns the opened net.Conn and the raw response bytes.
327// If requireSameHostRedirects is true, only redirects to the same host are permitted.
328func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
329 const (
330 maxRedirects = 9 // Fail on the 10th redirect
331 maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
332 )
333
334 var (
335 location = originalLocation
336 method = originalMethod
337 intermediateConn net.Conn
338 rawResponse = bytes.NewBuffer(make([]byte, 0, 256))
339 body = originalBody
340 )
341
342 defer func() {
343 if intermediateConn != nil {
344 intermediateConn.Close()
345 }
346 }()
347
348redirectLoop:
349 for redirects := 0; ; redirects++ {
350 if redirects > maxRedirects {
351 return nil, nil, fmt.Errorf("too many redirects (%d)", redirects)
352 }
353
354 req, err := http.NewRequest(method, location.String(), body)
355 if err != nil {
356 return nil, nil, err
357 }
358
359 req.Header = header
360
361 intermediateConn, err = dialer.Dial(req)
362 if err != nil {
363 return nil, nil, err
364 }
365
366 // Peek at the backend response.
367 rawResponse.Reset()
368 respReader := bufio.NewReader(io.TeeReader(
369 io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes.
370 rawResponse)) // Save the raw response.
371 resp, err := http.ReadResponse(respReader, nil)
372 if err != nil {
373 // Unable to read the backend response; let the client handle it.
374 klog.Warningf("Error reading backend response: %v", err)
375 break redirectLoop
376 }
377
378 switch resp.StatusCode {
379 case http.StatusFound:
380 // Redirect, continue.
381 default:
382 // Don't redirect.
383 break redirectLoop
384 }
385
386 // Redirected requests switch to "GET" according to the HTTP spec:
387 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3
388 method = "GET"
389 // don't send a body when following redirects
390 body = nil
391
392 resp.Body.Close() // not used
393
394 // Prepare to follow the redirect.
395 redirectStr := resp.Header.Get("Location")
396 if redirectStr == "" {
397 return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode)
398 }
399 // We have to parse relative to the current location, NOT originalLocation. For example,
400 // if we request http://foo.com/a and get back "http://bar.com/b", the result should be
401 // http://bar.com/b. If we then make that request and get back a redirect to "/c", the result
402 // should be http://bar.com/c, not http://foo.com/c.
403 location, err = location.Parse(redirectStr)
404 if err != nil {
405 return nil, nil, fmt.Errorf("malformed Location header: %v", err)
406 }
407
408 // Only follow redirects to the same host. Otherwise, propagate the redirect response back.
409 if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
410 break redirectLoop
411 }
412
413 // Reset the connection.
414 intermediateConn.Close()
415 intermediateConn = nil
416 }
417
418 connToReturn := intermediateConn
419 intermediateConn = nil // Don't close the connection when we return it.
420 return connToReturn, rawResponse.Bytes(), nil
421}
422
423// CloneRequest creates a shallow copy of the request along with a deep copy of the Headers.
424func CloneRequest(req *http.Request) *http.Request {
425 r := new(http.Request)
426
427 // shallow clone
428 *r = *req
429
430 // deep copy headers
431 r.Header = CloneHeader(req.Header)
432
433 return r
434}
435
// CloneHeader creates a deep copy of an http.Header: the returned map and
// every value slice in it are freshly allocated. The result is never nil,
// even when in is.
func CloneHeader(in http.Header) http.Header {
	clone := make(http.Header, len(in))
	for name, vals := range in {
		// Append onto a fresh, exactly-sized slice so the copy shares no
		// backing array with the original.
		clone[name] = append(make([]string, 0, len(vals)), vals...)
	}
	return clone
}
445}