blob: d88176cc3d590dbadebc106ba77f889e2e1be4ac [file] [log] [blame]
David K. Bainbridge215e0242017-09-05 23:18:24 -07001package libtrust
2
3import (
4 "bytes"
5 "crypto"
6 "crypto/elliptic"
7 "crypto/tls"
8 "crypto/x509"
9 "encoding/base32"
10 "encoding/base64"
11 "encoding/binary"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "math/big"
16 "net/url"
17 "os"
18 "path/filepath"
19 "strings"
20 "time"
21)
22
23// LoadOrCreateTrustKey will load a PrivateKey from the specified path
24func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) {
25 if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil {
26 return nil, err
27 }
28
29 trustKey, err := LoadKeyFile(trustKeyPath)
30 if err == ErrKeyFileDoesNotExist {
31 trustKey, err = GenerateECP256PrivateKey()
32 if err != nil {
33 return nil, fmt.Errorf("error generating key: %s", err)
34 }
35
36 if err := SaveKey(trustKeyPath, trustKey); err != nil {
37 return nil, fmt.Errorf("error saving key file: %s", err)
38 }
39
40 dir, file := filepath.Split(trustKeyPath)
41 if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil {
42 return nil, fmt.Errorf("error saving public key file: %s", err)
43 }
44 } else if err != nil {
45 return nil, fmt.Errorf("error loading key file: %s", err)
46 }
47 return trustKey, nil
48}
49
50// NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity
51// based authentication from the specified dockerUrl, the rootConfigPath and
52// the server name to which it is connecting.
53// If trustUnknownHosts is true it will automatically add the host to the
54// known-hosts.json in rootConfigPath.
55func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) {
56 tlsConfig := newTLSConfig()
57
58 trustKeyPath := filepath.Join(rootConfigPath, "key.json")
59 knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json")
60
61 u, err := url.Parse(dockerUrl)
62 if err != nil {
63 return nil, fmt.Errorf("unable to parse machine url")
64 }
65
66 if u.Scheme == "unix" {
67 return nil, nil
68 }
69
70 addr := u.Host
71 proto := "tcp"
72
73 trustKey, err := LoadOrCreateTrustKey(trustKeyPath)
74 if err != nil {
75 return nil, fmt.Errorf("unable to load trust key: %s", err)
76 }
77
78 knownHosts, err := LoadKeySetFile(knownHostsPath)
79 if err != nil {
80 return nil, fmt.Errorf("could not load trusted hosts file: %s", err)
81 }
82
83 allowedHosts, err := FilterByHosts(knownHosts, addr, false)
84 if err != nil {
85 return nil, fmt.Errorf("error filtering hosts: %s", err)
86 }
87
88 certPool, err := GenerateCACertPool(trustKey, allowedHosts)
89 if err != nil {
90 return nil, fmt.Errorf("Could not create CA pool: %s", err)
91 }
92
93 tlsConfig.ServerName = serverName
94 tlsConfig.RootCAs = certPool
95
96 x509Cert, err := GenerateSelfSignedClientCert(trustKey)
97 if err != nil {
98 return nil, fmt.Errorf("certificate generation error: %s", err)
99 }
100
101 tlsConfig.Certificates = []tls.Certificate{{
102 Certificate: [][]byte{x509Cert.Raw},
103 PrivateKey: trustKey.CryptoPrivateKey(),
104 Leaf: x509Cert,
105 }}
106
107 tlsConfig.InsecureSkipVerify = true
108
109 testConn, err := tls.Dial(proto, addr, tlsConfig)
110 if err != nil {
111 return nil, fmt.Errorf("tls Handshake error: %s", err)
112 }
113
114 opts := x509.VerifyOptions{
115 Roots: tlsConfig.RootCAs,
116 CurrentTime: time.Now(),
117 DNSName: tlsConfig.ServerName,
118 Intermediates: x509.NewCertPool(),
119 }
120
121 certs := testConn.ConnectionState().PeerCertificates
122 for i, cert := range certs {
123 if i == 0 {
124 continue
125 }
126 opts.Intermediates.AddCert(cert)
127 }
128
129 if _, err := certs[0].Verify(opts); err != nil {
130 if _, ok := err.(x509.UnknownAuthorityError); ok {
131 if trustUnknownHosts {
132 pubKey, err := FromCryptoPublicKey(certs[0].PublicKey)
133 if err != nil {
134 return nil, fmt.Errorf("error extracting public key from cert: %s", err)
135 }
136
137 pubKey.AddExtendedField("hosts", []string{addr})
138
139 if err := AddKeySetFile(knownHostsPath, pubKey); err != nil {
140 return nil, fmt.Errorf("error adding machine to known hosts: %s", err)
141 }
142 } else {
143 return nil, fmt.Errorf("unable to connect. unknown host: %s", addr)
144 }
145 }
146 }
147
148 testConn.Close()
149 tlsConfig.InsecureSkipVerify = false
150
151 return tlsConfig, nil
152}
153
// joseBase64UrlEncode returns the unpadded base64url encoding of b, as used by
// JOSE (see draft-ietf-jose-json-web-signature-31, section 2):
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
func joseBase64UrlEncode(b []byte) string {
	padded := base64.URLEncoding.EncodeToString(b)
	// '=' is not part of the URL-safe alphabet, so its first occurrence marks
	// the start of the trailing padding; cut everything from there on.
	if cut := strings.IndexByte(padded, '='); cut >= 0 {
		return padded[:cut]
	}
	return padded
}
160}
161
// joseBase64UrlDecode decodes an unpadded base64url string as used by JOSE
// (see draft-ietf-jose-json-web-signature-31, section 2):
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
//
// Newlines and spaces are stripped first, then the '=' padding the standard
// decoder needs is restored. A cleaned length of 1 modulo 4 can never be valid
// base64, so it is rejected up front.
func joseBase64UrlDecode(s string) ([]byte, error) {
	cleaned := strings.NewReplacer("\n", "", " ", "").Replace(s)

	switch rem := len(cleaned) % 4; rem {
	case 1:
		return nil, errors.New("illegal base64url string")
	case 2, 3:
		cleaned += strings.Repeat("=", 4-rem)
	}

	return base64.URLEncoding.DecodeString(cleaned)
}
179}
180
// keyIDEncode base32-encodes b (standard alphabet, padding removed) and splits
// the result into colon-separated 4-character groups. Every group but the last
// holds exactly 4 characters; the final group takes whatever remains after
// len/4-1 full groups, so it is 4 to 7 characters long for inputs of 4 or more
// characters. A 30-byte input yields 12 groups of 4.
func keyIDEncode(b []byte) string {
	encoded := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")

	fullGroups := len(encoded)/4 - 1
	if fullGroups < 0 {
		fullGroups = 0
	}

	var out bytes.Buffer
	for g := 0; g < fullGroups; g++ {
		out.WriteString(encoded[g*4 : g*4+4])
		out.WriteByte(':')
	}
	out.WriteString(encoded[fullGroups*4:])
	return out.String()
}
192}
193
194func keyIDFromCryptoKey(pubKey PublicKey) string {
195 // Generate and return a 'libtrust' fingerprint of the public key.
196 // For an RSA key this should be:
197 // SHA256(DER encoded ASN1)
198 // Then truncated to 240 bits and encoded into 12 base32 groups like so:
199 // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP
200 derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey())
201 if err != nil {
202 return ""
203 }
204 hasher := crypto.SHA256.New()
205 hasher.Write(derBytes)
206 return keyIDEncode(hasher.Sum(nil)[:30])
207}
208
// stringFromMap looks up key in m and returns its value if it is a string,
// removing the entry from m on success. A missing key or a non-string value
// yields an error and leaves m unchanged.
func stringFromMap(m map[string]interface{}, key string) (string, error) {
	raw, present := m[key]
	if !present {
		return "", fmt.Errorf("%q value not specified", key)
	}

	if s, isString := raw.(string); isString {
		delete(m, key)
		return s, nil
	}

	return "", fmt.Errorf("%q value must be a string", key)
}
222}
223
224func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) {
225 curveByteLen := (curve.Params().BitSize + 7) >> 3
226
227 cBytes, err := joseBase64UrlDecode(cB64Url)
228 if err != nil {
229 return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
230 }
231 cByteLength := len(cBytes)
232 if cByteLength != curveByteLen {
233 return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen)
234 }
235 return new(big.Int).SetBytes(cBytes), nil
236}
237
238func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) {
239 dBytes, err := joseBase64UrlDecode(dB64Url)
240 if err != nil {
241 return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
242 }
243
244 // The length of this octet string MUST be ceiling(log-base-2(n)/8)
245 // octets (where n is the order of the curve). This is because the private
246 // key d must be in the interval [1, n-1] so the bitlength of d should be
247 // no larger than the bitlength of n-1. The easiest way to find the octet
248 // length is to take bitlength(n-1), add 7 to force a carry, and shift this
249 // bit sequence right by 3, which is essentially dividing by 8 and adding
250 // 1 if there is any remainder. Thus, the private key value d should be
251 // output to (bitlength(n-1)+7)>>3 octets.
252 n := curve.Params().N
253 octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
254 dByteLength := len(dBytes)
255
256 if dByteLength != octetLength {
257 return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength)
258 }
259
260 return new(big.Int).SetBytes(dBytes), nil
261}
262
263func parseRSAModulusParam(nB64Url string) (*big.Int, error) {
264 nBytes, err := joseBase64UrlDecode(nB64Url)
265 if err != nil {
266 return nil, fmt.Errorf("invalid base64 URL encoding: %s", err)
267 }
268
269 return new(big.Int).SetBytes(nBytes), nil
270}
271
// serializeRSAPublicExponentParam encodes the RSA public exponent e as a
// big-endian octet string using the minimum number of octets.
//
// E is supposed to be 65537 for performance and security reasons, and that is
// what golang's rsa package generates, but it might differ if the key was
// imported from some other generator. e is truncated to 32 bits before encoding.
func serializeRSAPublicExponentParam(e int) []byte {
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, uint32(e))

	// Strip leading zero octets, but always keep at least one. The previous
	// loop bound of 8 read past this 4-byte buffer (and panicked) when e == 0.
	// Now e == 0 encodes as a single zero octet.
	i := 0
	for i < len(buf)-1 && buf[i] == 0 {
		i++
	}
	return buf[i:]
}
286}
287
288func parseRSAPublicExponentParam(eB64Url string) (int, error) {
289 eBytes, err := joseBase64UrlDecode(eB64Url)
290 if err != nil {
291 return 0, fmt.Errorf("invalid base64 URL encoding: %s", err)
292 }
293 // Only the minimum number of bytes were used to represent E, but
294 // binary.BigEndian.Uint32 expects at least 4 bytes, so we need
295 // to add zero padding if necassary.
296 byteLen := len(eBytes)
297 buf := make([]byte, 4-byteLen, 4)
298 eBytes = append(buf, eBytes...)
299
300 return int(binary.BigEndian.Uint32(eBytes)), nil
301}
302
303func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) {
304 b64Url, err := stringFromMap(m, key)
305 if err != nil {
306 return nil, err
307 }
308
309 paramBytes, err := joseBase64UrlDecode(b64Url)
310 if err != nil {
311 return nil, fmt.Errorf("invaled base64 URL encoding: %s", err)
312 }
313
314 return new(big.Int).SetBytes(paramBytes), nil
315}
316
// createPemBlock builds a PEM block of the given type around derBytes and
// copies the encodable entries of headers into its (always non-nil) Headers.
//
// String values are copied verbatim. A []string is accepted only under the
// "hosts" key and is joined with commas. Every other value is skipped
// silently, so the returned error is currently always nil.
func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
	encoded := make(map[string]string)

	for key, value := range headers {
		if s, ok := value.(string); ok {
			encoded[key] = s
			continue
		}
		if hosts, ok := value.([]string); ok && key == "hosts" {
			encoded[key] = strings.Join(hosts, ",")
		}
		// Any other value type is non-encodable and dropped.
	}

	return &pem.Block{Type: name, Bytes: derBytes, Headers: encoded}, nil
}
335}
336
337func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) {
338 cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
339 if err != nil {
340 return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err)
341 }
342
343 pubKey, err := FromCryptoPublicKey(cryptoPublicKey)
344 if err != nil {
345 return nil, err
346 }
347
348 addPEMHeadersToKey(pemBlock, pubKey)
349
350 return pubKey, nil
351}
352
353func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) {
354 for key, value := range pemBlock.Headers {
355 var safeVal interface{}
356 if key == "hosts" {
357 safeVal = strings.Split(value, ",")
358 } else {
359 safeVal = value
360 }
361 pubKey.AddExtendedField(key, safeVal)
362 }
363}