| /* |
| Copyright 2016 The Kubernetes Authors. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| package net |
| |
| import ( |
| "bufio" |
| "bytes" |
| "context" |
| "crypto/tls" |
| "fmt" |
| "io" |
| "net" |
| "net/http" |
| "net/url" |
| "os" |
| "path" |
| "strconv" |
| "strings" |
| |
| "golang.org/x/net/http2" |
| "k8s.io/klog" |
| ) |
| |
// JoinPreservingTrailingSlash joins elem with path.Join and, when the last
// non-empty element ends in "/", makes sure the joined result ends in "/" too
// (path.Join would otherwise strip it).
func JoinPreservingTrailingSlash(elem ...string) string {
	joined := path.Join(elem...)

	// Walk backwards to the last element that has any content; it stays ""
	// when every element is empty.
	last := ""
	for i := len(elem) - 1; i >= 0 && last == ""; i-- {
		last = elem[i]
	}

	if strings.HasSuffix(last, "/") && !strings.HasSuffix(joined, "/") {
		joined += "/"
	}
	return joined
}
| |
// IsProbableEOF reports whether err looks like the connection was terminated,
// in which case a caller (for example a watch) may assume the stream simply
// ended rather than failed.
//
// A *url.Error, as returned by net/http client methods, is unwrapped once
// before classification. The following are treated as a probable EOF:
//   - io.EOF and io.ErrUnexpectedEOF (the connection closed mid-body);
//   - the net/http "can't write HTTP request on broken connection" error;
//   - an HTTP/2 GOAWAY followed by connection close;
//   - "connection reset by peer";
//   - "use of closed network connection" (matched case-insensitively).
//
// A nil err, a nil *url.Error, or a *url.Error with a nil Err reports false.
func IsProbableEOF(err error) bool {
	if err == nil {
		return false
	}
	if uerr, ok := err.(*url.Error); ok {
		if uerr == nil || uerr.Err == nil {
			// Nothing to classify. Without this check, Error() would be
			// called on a nil value below and panic.
			return false
		}
		err = uerr.Err
	}

	// A truncated read surfaces as io.ErrUnexpectedEOF, which is the same
	// connection-termination scenario as a clean io.EOF.
	if err == io.EOF || err == io.ErrUnexpectedEOF {
		return true
	}

	msg := err.Error()
	switch {
	case msg == "http: can't write HTTP request on broken connection":
		return true
	case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"):
		return true
	case strings.Contains(msg, "connection reset by peer"):
		return true
	case strings.Contains(strings.ToLower(msg), "use of closed network connection"):
		return true
	}
	return false
}
| |
// defaultTransport is http.DefaultTransport viewed as its concrete type, so
// that SetOldTransportDefaults can copy individual default field values from it.
var (
	defaultTransport = http.DefaultTransport.(*http.Transport)
)
| |
| // SetOldTransportDefaults applies the defaults from http.DefaultTransport |
| // for the Proxy, Dial, and TLSHandshakeTimeout fields if unset |
| func SetOldTransportDefaults(t *http.Transport) *http.Transport { |
| if t.Proxy == nil || isDefault(t.Proxy) { |
| // http.ProxyFromEnvironment doesn't respect CIDRs and that makes it impossible to exclude things like pod and service IPs from proxy settings |
| // ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY |
| t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) |
| } |
| // If no custom dialer is set, use the default context dialer |
| if t.DialContext == nil && t.Dial == nil { |
| t.DialContext = defaultTransport.DialContext |
| } |
| if t.TLSHandshakeTimeout == 0 { |
| t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout |
| } |
| return t |
| } |
| |
| // SetTransportDefaults applies the defaults from http.DefaultTransport |
| // for the Proxy, Dial, and TLSHandshakeTimeout fields if unset |
| func SetTransportDefaults(t *http.Transport) *http.Transport { |
| t = SetOldTransportDefaults(t) |
| // Allow clients to disable http2 if needed. |
| if s := os.Getenv("DISABLE_HTTP2"); len(s) > 0 { |
| klog.Infof("HTTP2 has been explicitly disabled") |
| } else { |
| if err := http2.ConfigureTransport(t); err != nil { |
| klog.Warningf("Transport failed http2 configuration: %v", err) |
| } |
| } |
| return t |
| } |
| |
// RoundTripperWrapper is implemented by http.RoundTripper decorators that can
// return the http.RoundTripper they wrap. Helpers such as DialerFor and
// TLSClientConfig use it to unwrap a chain of round trippers.
type RoundTripperWrapper interface {
	http.RoundTripper
	WrappedRoundTripper() http.RoundTripper
}

// DialFunc has the same signature as http.Transport.DialContext.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// DialerFor returns the dial function used by transport.
//
// For an *http.Transport, DialContext is returned when set. Otherwise a
// non-nil Dial is adapted to the DialFunc signature; the adapter ignores its
// context. A RoundTripperWrapper is unwrapped recursively.
//
// DialerFor returns (nil, nil) for a nil transport (including a typed-nil
// *http.Transport) and for a transport with neither dial hook set. Any other
// RoundTripper type yields an error.
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
	if transport == nil {
		return nil, nil
	}

	switch transport := transport.(type) {
	case *http.Transport:
		if transport == nil {
			// A typed-nil pointer has no dialer; treat it like a nil transport
			// instead of dereferencing it.
			return nil, nil
		}
		// DialContext takes precedence over the deprecated Dial hook.
		if transport.DialContext != nil {
			return transport.DialContext, nil
		}
		if transport.Dial != nil {
			// Adapt Dial to the DialFunc signature. The parameter is named
			// network so it does not shadow the net package.
			return func(ctx context.Context, network, addr string) (net.Conn, error) {
				return transport.Dial(network, addr)
			}, nil
		}
		return nil, nil
	case RoundTripperWrapper:
		return DialerFor(transport.WrappedRoundTripper())
	default:
		return nil, fmt.Errorf("unknown transport type: %T", transport)
	}
}
| |
| type TLSClientConfigHolder interface { |
| TLSClientConfig() *tls.Config |
| } |
| |
| func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) { |
| if transport == nil { |
| return nil, nil |
| } |
| |
| switch transport := transport.(type) { |
| case *http.Transport: |
| return transport.TLSClientConfig, nil |
| case TLSClientConfigHolder: |
| return transport.TLSClientConfig(), nil |
| case RoundTripperWrapper: |
| return TLSClientConfig(transport.WrappedRoundTripper()) |
| default: |
| return nil, fmt.Errorf("unknown transport type: %T", transport) |
| } |
| } |
| |
// FormatURL builds a *url.URL from scheme, host, port and path. Host and port
// are combined with net.JoinHostPort, which brackets IPv6 literal hosts.
func FormatURL(scheme string, host string, port int, path string) *url.URL {
	hostPort := net.JoinHostPort(host, strconv.Itoa(port))
	u := &url.URL{
		Path:   path,
		Host:   hostPort,
		Scheme: scheme,
	}
	return u
}
| |
// GetHTTPClient returns the request's User-Agent header, or "unknown" when
// that header is absent or empty.
func GetHTTPClient(req *http.Request) string {
	ua := req.UserAgent()
	if ua == "" {
		return "unknown"
	}
	return ua
}
| |
// SourceIPs returns the client IP addresses recorded for req. It checks these
// sources in order and returns the result of the first one that yields at
// least one valid IP:
//  1. X-Forwarded-For: split on commas, and every entry that parses as an IP
//     (after trimming whitespace) is returned in header order. Invalid entries
//     are skipped.
//  2. X-Real-Ip: returned if it parses as an IP after trimming whitespace.
//  3. req.RemoteAddr: in host:port form or as a bare IP.
//
// It returns nil if none of these yields a valid IP. Header values are used as
// given; this function does not check whether they were set by a trusted proxy.
func SourceIPs(req *http.Request) []net.IP {
	hdr := req.Header

	// First check the X-Forwarded-For header for requests via proxy. It may
	// hold a comma-separated chain of IPs when several proxies are involved;
	// every valid entry is kept, in order.
	var forwardedForIPs []net.IP
	if hdrForwardedFor := hdr.Get("X-Forwarded-For"); hdrForwardedFor != "" {
		for _, part := range strings.Split(hdrForwardedFor, ",") {
			if ip := net.ParseIP(strings.TrimSpace(part)); ip != nil {
				forwardedForIPs = append(forwardedForIPs, ip)
			}
		}
	}
	if len(forwardedForIPs) > 0 {
		return forwardedForIPs
	}

	// Try the X-Real-Ip header. Trim it the same way as the X-Forwarded-For
	// entries so surrounding whitespace does not cause a valid IP to be rejected.
	if hdrRealIP := strings.TrimSpace(hdr.Get("X-Real-Ip")); hdrRealIP != "" {
		if ip := net.ParseIP(hdrRealIP); ip != nil {
			return []net.IP{ip}
		}
	}

	// Fall back to the request's remote address, which is the correct client IP
	// when there is no proxy. Go's HTTP server sets it as host:port, so split
	// that first.
	if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if remoteIP := net.ParseIP(host); remoteIP != nil {
			return []net.IP{remoteIP}
		}
	}

	// Fall back to RemoteAddr being just an IP.
	if remoteIP := net.ParseIP(req.RemoteAddr); remoteIP != nil {
		return []net.IP{remoteIP}
	}

	return nil
}
| |
| // Extracts and returns the clients IP from the given request. |
| // Looks at X-Forwarded-For header, X-Real-Ip header and request.RemoteAddr in that order. |
| // Returns nil if none of them are set or is set to an invalid value. |
| func GetClientIP(req *http.Request) net.IP { |
| ips := SourceIPs(req) |
| if len(ips) == 0 { |
| return nil |
| } |
| return ips[0] |
| } |
| |
// AppendForwardedForHeader prepares req for another forwarding hop. It appends
// the host part of req.RemoteAddr to the X-Forwarded-For chain, folding any
// existing X-Forwarded-For header lines into a single ", "-separated value.
// When req.RemoteAddr is not in host:port form, req is left unchanged.
func AppendForwardedForHeader(req *http.Request) {
	// Mirrors the logic in net/http/httputil/reverseproxy.go.
	clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return
	}

	chain := clientIP
	// If we are not the first proxy, keep the prior X-Forwarded-For entries
	// ahead of our own.
	if prior, ok := req.Header["X-Forwarded-For"]; ok {
		chain = strings.Join(prior, ", ") + ", " + chain
	}
	req.Header.Set("X-Forwarded-For", chain)
}
| |
// defaultProxyFuncPointer is the %p rendering of http.ProxyFromEnvironment,
// computed once for isDefault to compare against. Go function values can only
// be compared with nil, so the formatted pointer stands in for identity.
var defaultProxyFuncPointer = fmt.Sprintf("%p", http.ProxyFromEnvironment)

// isDefault reports whether transportProxier is http.ProxyFromEnvironment,
// judged by comparing the %p rendering of both function values.
func isDefault(transportProxier func(*http.Request) (*url.URL, error)) bool {
	return defaultProxyFuncPointer == fmt.Sprintf("%p", transportProxier)
}
| |
// NewProxierWithNoProxyCIDR returns a proxy function suitable for
// http.Transport.Proxy. For requests whose URL host is a literal IP contained
// in a CIDR listed in NO_PROXY, it returns a nil *url.URL, meaning no proxy is
// used. For every other request it calls delegate.
//
// NO_PROXY is read once, when the proxier is constructed. The lowercase
// no_proxy is consulted only if NO_PROXY is empty or unset. Entries are
// comma-separated and surrounding whitespace is ignored, so "a, 10.0.0.0/8"
// works. Entries that are not valid CIDRs are ignored by this wrapper. If no
// CIDR entries are found, delegate itself is returned unchanged.
func NewProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) {
	// We wrap delegate, so our own check is only needed when NO_PROXY (or
	// no_proxy) contains a CIDR.
	noProxyEnv := os.Getenv("NO_PROXY")
	if noProxyEnv == "" {
		noProxyEnv = os.Getenv("no_proxy")
	}

	var cidrs []*net.IPNet
	for _, noProxyRule := range strings.Split(noProxyEnv, ",") {
		// Trim first: without it, " 10.0.0.0/8" fails to parse and the CIDR is
		// silently dropped.
		_, cidr, err := net.ParseCIDR(strings.TrimSpace(noProxyRule))
		if err != nil {
			// Not a CIDR (e.g. a host name); not handled by this wrapper.
			continue
		}
		cidrs = append(cidrs, cidr)
	}

	if len(cidrs) == 0 {
		return delegate
	}

	return func(req *http.Request) (*url.URL, error) {
		ip := net.ParseIP(req.URL.Hostname())
		if ip == nil {
			return delegate(req)
		}

		for _, cidr := range cidrs {
			if cidr.Contains(ip) {
				// A nil proxy URL makes the transport connect directly.
				return nil, nil
			}
		}

		return delegate(req)
	}
}
| |
// Dialer dials a host and writes a request to it.
type Dialer interface {
	// Dial connects to the host specified by req's URL, writes the request to
	// the connection, and returns the opened net.Conn.
	Dial(req *http.Request) (net.Conn, error)
}

// DialerFunc adapts an ordinary function to the Dialer interface.
type DialerFunc func(req *http.Request) (net.Conn, error)

// Dial implements Dialer by calling f(req).
func (f DialerFunc) Dial(req *http.Request) (net.Conn, error) {
	return f(req)
}
| |
// ConnectWithRedirects uses dialer to send a request (originalMethod to
// originalLocation, with header and originalBody) and follows 302 Found
// redirects. Per the Dialer contract, dialer.Dial also writes the request.
//
// It returns two values:
//   - the net.Conn on which the final response arrived, left open for the
//     caller;
//   - the raw bytes read from that connection while parsing the response, at
//     most maxResponseSize bytes.
//
// Redirect handling:
//   - Only http.StatusFound (302) is followed; any other status ends the loop.
//   - At most 9 redirects are followed; receiving a 10th redirect is an error.
//   - Each Location is resolved against the location that returned it, not
//     against originalLocation.
//   - Follow-up requests are always GET with no body, and reuse header as-is.
//   - If requireSameHostRedirects is true, a redirect whose host name differs
//     from originalLocation's is not followed, and that redirect response is
//     returned to the caller instead. Only Hostname() is compared.
//
// If a response cannot be parsed, the parse error is logged and the connection
// and raw bytes are returned with a nil error so the caller can handle them.
func ConnectWithRedirects(originalMethod string, originalLocation *url.URL, header http.Header, originalBody io.Reader, dialer Dialer, requireSameHostRedirects bool) (net.Conn, []byte, error) {
	const (
		maxRedirects    = 9     // Fail on the 10th redirect
		maxResponseSize = 16384 // play it safe to allow the potential for lots of / large headers
	)

	var (
		location         = originalLocation
		method           = originalMethod
		intermediateConn net.Conn
		rawResponse      = bytes.NewBuffer(make([]byte, 0, 256))
		body             = originalBody
	)

	// Close any connection still held when returning early with an error. The
	// success path sets intermediateConn to nil first, so the connection handed
	// back to the caller stays open.
	defer func() {
		if intermediateConn != nil {
			intermediateConn.Close()
		}
	}()

redirectLoop:
	for redirects := 0; ; redirects++ {
		// redirects counts the redirects already followed.
		if redirects > maxRedirects {
			return nil, nil, fmt.Errorf("too many redirects (%d)", redirects)
		}

		req, err := http.NewRequest(method, location.String(), body)
		if err != nil {
			return nil, nil, err
		}

		// The caller's header map is shared by every hop, not copied.
		req.Header = header

		intermediateConn, err = dialer.Dial(req)
		if err != nil {
			return nil, nil, err
		}

		// Peek at the backend response.
		rawResponse.Reset()
		respReader := bufio.NewReader(io.TeeReader(
			io.LimitReader(intermediateConn, maxResponseSize), // Don't read more than maxResponseSize bytes.
			rawResponse)) // Save the raw response.
		resp, err := http.ReadResponse(respReader, nil)
		if err != nil {
			// Unable to read the backend response; let the client handle it.
			klog.Warningf("Error reading backend response: %v", err)
			break redirectLoop
		}

		switch resp.StatusCode {
		case http.StatusFound:
			// Redirect, continue.
		default:
			// Don't redirect.
			break redirectLoop
		}

		// Redirected requests switch to "GET" according to the HTTP spec:
		// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3
		method = "GET"
		// don't send a body when following redirects
		body = nil

		resp.Body.Close() // not used

		// Prepare to follow the redirect.
		redirectStr := resp.Header.Get("Location")
		if redirectStr == "" {
			return nil, nil, fmt.Errorf("%d response missing Location header", resp.StatusCode)
		}
		// We have to parse relative to the current location, NOT originalLocation. For example,
		// if we request http://foo.com/a and get back "http://bar.com/b", the result should be
		// http://bar.com/b. If we then make that request and get back a redirect to "/c", the result
		// should be http://bar.com/c, not http://foo.com/c.
		location, err = location.Parse(redirectStr)
		if err != nil {
			return nil, nil, fmt.Errorf("malformed Location header: %v", err)
		}

		// Only follow redirects to the same host. Otherwise, propagate the redirect response back.
		// Only Hostname() is compared; scheme and port are not checked.
		if requireSameHostRedirects && location.Hostname() != originalLocation.Hostname() {
			break redirectLoop
		}

		// Reset the connection.
		intermediateConn.Close()
		intermediateConn = nil
	}

	connToReturn := intermediateConn
	intermediateConn = nil // Don't close the connection when we return it.
	return connToReturn, rawResponse.Bytes(), nil
}
| |
| // CloneRequest creates a shallow copy of the request along with a deep copy of the Headers. |
| func CloneRequest(req *http.Request) *http.Request { |
| r := new(http.Request) |
| |
| // shallow clone |
| *r = *req |
| |
| // deep copy headers |
| r.Header = CloneHeader(req.Header) |
| |
| return r |
| } |
| |
// CloneHeader returns a deep copy of in. Every value slice is copied into
// fresh, non-nil storage, so mutating the clone never affects in. A nil in
// yields an empty, non-nil header.
func CloneHeader(in http.Header) http.Header {
	cloned := make(http.Header, len(in))
	for name := range in {
		src := in[name]
		cloned[name] = append(make([]string, 0, len(src)), src...)
	}
	return cloned
}