blob: 2f5ddf35f30a1c43bc0b87f912c9042075ade0b1 [file] [log] [blame]
// Copyright 2013 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.
3
4package utils
5
6import (
7 "crypto/tls"
8 "encoding/base64"
9 "fmt"
10 "net"
11 "net/http"
12 "strings"
13 "sync"
14)
15
var (
	// insecureClient starts out as a nil *http.Client.
	insecureClient *http.Client
	// insecureClientMutex is declared alongside insecureClient; its zero
	// value is an unlocked mutex.
	insecureClientMutex sync.Mutex
)
18
19func init() {
20 defaultTransport := http.DefaultTransport.(*http.Transport)
21 installHTTPDialShim(defaultTransport)
22 registerFileProtocol(defaultTransport)
23}
24
// registerFileProtocol adds a handler for the "file" URL scheme to the given
// transport. The handler serves files relative to the filesystem root "/".
func registerFileProtocol(transport *http.Transport) {
	rootFS := http.Dir("/")
	fileRoundTripper := http.NewFileTransport(rootFS)
	transport.RegisterProtocol("file", fileRoundTripper)
}
29
// SSLHostnameVerification is a switch for providers that might use
// self-signed credentials, in which case we should not try to verify the
// hostname on the TLS/SSL certificates.
type SSLHostnameVerification bool

const (
	// VerifySSLHostnames requests verification that the certificate is
	// signed and that its hostname matches the host we are connecting to.
	VerifySSLHostnames SSLHostnameVerification = true
	// NoVerifySSLHostnames requests that we skip verifying that the
	// hostname matches a valid certificate.
	NoVerifySSLHostnames SSLHostnameVerification = false
)
43
44// GetHTTPClient returns either a standard http client or
45// non validating client depending on the value of verify.
46func GetHTTPClient(verify SSLHostnameVerification) *http.Client {
47 if verify == VerifySSLHostnames {
48 return GetValidatingHTTPClient()
49 }
50 return GetNonValidatingHTTPClient()
51}
52
// GetValidatingHTTPClient returns a freshly allocated, zero-valued
// http.Client. With no Transport set, the client uses the default transport,
// which verifies the server's certificate chain and hostname.
func GetValidatingHTTPClient() *http.Client {
	return new(http.Client)
}
58
59// GetNonValidatingHTTPClient returns a new http.Client that
60// does not verify the server's certificate chain and hostname.
61func GetNonValidatingHTTPClient() *http.Client {
62 return &http.Client{
63 Transport: NewHttpTLSTransport(&tls.Config{
64 InsecureSkipVerify: true,
65 }),
66 }
67}
68
// BasicAuthHeader builds an http.Header whose only entry is an
// "Authorization" value carrying HTTP Basic credentials. The implementation
// was originally taken from net/http, but it is needed independently of the
// http request object so that it can be used with our websockets.
// See section 2 (end of page 4) of http://www.ietf.org/rfc/rfc2617.txt:
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64 encoded string
// in the credentials."
func BasicAuthHeader(username, password string) http.Header {
	credentials := []byte(username + ":" + password)
	header := make(http.Header, 1)
	header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString(credentials))
	return header
}
83
// ParseBasicAuthHeader looks up the Authorization header in the supplied
// http.Header and parses it as a Basic credential. The scheme name is matched
// case-insensitively, so "Basic", "basic" and "BASIC" are all accepted.
//
// It returns the userid and password on success. It returns an error when the
// header is missing or is not a two-field Basic header, when the credentials
// are not valid base64, or when the decoded credentials contain no colon.
//
// See section 2 (end of page 4) of http://www.ietf.org/rfc/rfc2617.txt:
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64 encoded string
// in the credentials."
func ParseBasicAuthHeader(h http.Header) (userid, password string, err error) {
	parts := strings.Fields(h.Get("Authorization"))
	// Authentication scheme names are case-insensitive, so the scheme must
	// not be compared byte-for-byte.
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
		return "", "", fmt.Errorf("invalid or missing HTTP auth header")
	}
	// The credentials are a base64-encoded "userid:password" string.
	// See RFC 2617, Section 2.
	challenge, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", "", fmt.Errorf("invalid HTTP auth encoding")
	}
	// Split on the first colon only: everything after it, including any
	// further colons, belongs to the password.
	tokens := strings.SplitN(string(challenge), ":", 2)
	if len(tokens) != 2 {
		return "", "", fmt.Errorf("invalid HTTP auth contents")
	}
	return tokens[0], tokens[1], nil
}
106
// OutgoingAccessAllowed determines whether connections to anything other
// than localhost can be dialled. It defaults to true.
var OutgoingAccessAllowed bool = true
110
111func isLocalAddr(addr string) bool {
112 host, _, err := net.SplitHostPort(addr)
113 if err != nil {
114 return false
115 }
116 return host == "localhost" || net.ParseIP(host).IsLoopback()
117}