David K. Bainbridge | 215e024 | 2017-09-05 23:18:24 -0700 | [diff] [blame] | 1 | package libtrust |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "crypto" |
| 6 | "crypto/elliptic" |
| 7 | "crypto/tls" |
| 8 | "crypto/x509" |
| 9 | "encoding/base32" |
| 10 | "encoding/base64" |
| 11 | "encoding/binary" |
| 12 | "encoding/pem" |
| 13 | "errors" |
| 14 | "fmt" |
| 15 | "math/big" |
| 16 | "net/url" |
| 17 | "os" |
| 18 | "path/filepath" |
| 19 | "strings" |
| 20 | "time" |
| 21 | ) |
| 22 | |
| 23 | // LoadOrCreateTrustKey will load a PrivateKey from the specified path |
| 24 | func LoadOrCreateTrustKey(trustKeyPath string) (PrivateKey, error) { |
| 25 | if err := os.MkdirAll(filepath.Dir(trustKeyPath), 0700); err != nil { |
| 26 | return nil, err |
| 27 | } |
| 28 | |
| 29 | trustKey, err := LoadKeyFile(trustKeyPath) |
| 30 | if err == ErrKeyFileDoesNotExist { |
| 31 | trustKey, err = GenerateECP256PrivateKey() |
| 32 | if err != nil { |
| 33 | return nil, fmt.Errorf("error generating key: %s", err) |
| 34 | } |
| 35 | |
| 36 | if err := SaveKey(trustKeyPath, trustKey); err != nil { |
| 37 | return nil, fmt.Errorf("error saving key file: %s", err) |
| 38 | } |
| 39 | |
| 40 | dir, file := filepath.Split(trustKeyPath) |
| 41 | if err := SavePublicKey(filepath.Join(dir, "public-"+file), trustKey.PublicKey()); err != nil { |
| 42 | return nil, fmt.Errorf("error saving public key file: %s", err) |
| 43 | } |
| 44 | } else if err != nil { |
| 45 | return nil, fmt.Errorf("error loading key file: %s", err) |
| 46 | } |
| 47 | return trustKey, nil |
| 48 | } |
| 49 | |
| 50 | // NewIdentityAuthTLSClientConfig returns a tls.Config configured to use identity |
| 51 | // based authentication from the specified dockerUrl, the rootConfigPath and |
| 52 | // the server name to which it is connecting. |
| 53 | // If trustUnknownHosts is true it will automatically add the host to the |
| 54 | // known-hosts.json in rootConfigPath. |
| 55 | func NewIdentityAuthTLSClientConfig(dockerUrl string, trustUnknownHosts bool, rootConfigPath string, serverName string) (*tls.Config, error) { |
| 56 | tlsConfig := newTLSConfig() |
| 57 | |
| 58 | trustKeyPath := filepath.Join(rootConfigPath, "key.json") |
| 59 | knownHostsPath := filepath.Join(rootConfigPath, "known-hosts.json") |
| 60 | |
| 61 | u, err := url.Parse(dockerUrl) |
| 62 | if err != nil { |
| 63 | return nil, fmt.Errorf("unable to parse machine url") |
| 64 | } |
| 65 | |
| 66 | if u.Scheme == "unix" { |
| 67 | return nil, nil |
| 68 | } |
| 69 | |
| 70 | addr := u.Host |
| 71 | proto := "tcp" |
| 72 | |
| 73 | trustKey, err := LoadOrCreateTrustKey(trustKeyPath) |
| 74 | if err != nil { |
| 75 | return nil, fmt.Errorf("unable to load trust key: %s", err) |
| 76 | } |
| 77 | |
| 78 | knownHosts, err := LoadKeySetFile(knownHostsPath) |
| 79 | if err != nil { |
| 80 | return nil, fmt.Errorf("could not load trusted hosts file: %s", err) |
| 81 | } |
| 82 | |
| 83 | allowedHosts, err := FilterByHosts(knownHosts, addr, false) |
| 84 | if err != nil { |
| 85 | return nil, fmt.Errorf("error filtering hosts: %s", err) |
| 86 | } |
| 87 | |
| 88 | certPool, err := GenerateCACertPool(trustKey, allowedHosts) |
| 89 | if err != nil { |
| 90 | return nil, fmt.Errorf("Could not create CA pool: %s", err) |
| 91 | } |
| 92 | |
| 93 | tlsConfig.ServerName = serverName |
| 94 | tlsConfig.RootCAs = certPool |
| 95 | |
| 96 | x509Cert, err := GenerateSelfSignedClientCert(trustKey) |
| 97 | if err != nil { |
| 98 | return nil, fmt.Errorf("certificate generation error: %s", err) |
| 99 | } |
| 100 | |
| 101 | tlsConfig.Certificates = []tls.Certificate{{ |
| 102 | Certificate: [][]byte{x509Cert.Raw}, |
| 103 | PrivateKey: trustKey.CryptoPrivateKey(), |
| 104 | Leaf: x509Cert, |
| 105 | }} |
| 106 | |
| 107 | tlsConfig.InsecureSkipVerify = true |
| 108 | |
| 109 | testConn, err := tls.Dial(proto, addr, tlsConfig) |
| 110 | if err != nil { |
| 111 | return nil, fmt.Errorf("tls Handshake error: %s", err) |
| 112 | } |
| 113 | |
| 114 | opts := x509.VerifyOptions{ |
| 115 | Roots: tlsConfig.RootCAs, |
| 116 | CurrentTime: time.Now(), |
| 117 | DNSName: tlsConfig.ServerName, |
| 118 | Intermediates: x509.NewCertPool(), |
| 119 | } |
| 120 | |
| 121 | certs := testConn.ConnectionState().PeerCertificates |
| 122 | for i, cert := range certs { |
| 123 | if i == 0 { |
| 124 | continue |
| 125 | } |
| 126 | opts.Intermediates.AddCert(cert) |
| 127 | } |
| 128 | |
| 129 | if _, err := certs[0].Verify(opts); err != nil { |
| 130 | if _, ok := err.(x509.UnknownAuthorityError); ok { |
| 131 | if trustUnknownHosts { |
| 132 | pubKey, err := FromCryptoPublicKey(certs[0].PublicKey) |
| 133 | if err != nil { |
| 134 | return nil, fmt.Errorf("error extracting public key from cert: %s", err) |
| 135 | } |
| 136 | |
| 137 | pubKey.AddExtendedField("hosts", []string{addr}) |
| 138 | |
| 139 | if err := AddKeySetFile(knownHostsPath, pubKey); err != nil { |
| 140 | return nil, fmt.Errorf("error adding machine to known hosts: %s", err) |
| 141 | } |
| 142 | } else { |
| 143 | return nil, fmt.Errorf("unable to connect. unknown host: %s", addr) |
| 144 | } |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | testConn.Close() |
| 149 | tlsConfig.InsecureSkipVerify = false |
| 150 | |
| 151 | return tlsConfig, nil |
| 152 | } |
| 153 | |
// joseBase64UrlEncode encodes b with the URL-safe base64 alphabet and no
// trailing '=' padding, which is the form the JOSE specification requires.
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
func joseBase64UrlEncode(b []byte) string {
	// RawURLEncoding is URLEncoding without padding, so this equals
	// URLEncoding's output with the trailing '=' characters removed.
	return base64.RawURLEncoding.EncodeToString(b)
}
| 161 | |
// joseBase64UrlDecode decodes a JOSE-style base64url string.
//
// Newlines and spaces are stripped first. The '=' padding that JOSE omits is
// then restored before decoding with the padded URL alphabet. A cleaned
// length of 1 mod 4 can never be valid base64 and is rejected.
// http://tools.ietf.org/html/draft-ietf-jose-json-web-signature-31#section-2
func joseBase64UrlDecode(s string) ([]byte, error) {
	cleaned := strings.Replace(strings.Replace(s, "\n", "", -1), " ", "", -1)

	rem := len(cleaned) % 4
	if rem == 1 {
		return nil, errors.New("illegal base64url string")
	}
	if rem != 0 {
		// rem 2 needs "==", rem 3 needs "=".
		cleaned += strings.Repeat("=", 4-rem)
	}
	return base64.URLEncoding.DecodeString(cleaned)
}
| 180 | |
// keyIDEncode renders b as unpadded standard base32 split into 4-character
// groups joined by ':'. Full groups are emitted while at least 8 characters
// remain. Everything left over (4 to 7 characters, or fewer if the whole
// string is that short) becomes the final group. A 30-byte input therefore
// yields 12 groups of 4.
func keyIDEncode(b []byte) string {
	rest := strings.TrimRight(base32.StdEncoding.EncodeToString(b), "=")

	var out bytes.Buffer
	for len(rest) >= 8 {
		out.WriteString(rest[:4])
		out.WriteByte(':')
		rest = rest[4:]
	}
	out.WriteString(rest)
	return out.String()
}
| 193 | |
| 194 | func keyIDFromCryptoKey(pubKey PublicKey) string { |
| 195 | // Generate and return a 'libtrust' fingerprint of the public key. |
| 196 | // For an RSA key this should be: |
| 197 | // SHA256(DER encoded ASN1) |
| 198 | // Then truncated to 240 bits and encoded into 12 base32 groups like so: |
| 199 | // ABCD:EFGH:IJKL:MNOP:QRST:UVWX:YZ23:4567:ABCD:EFGH:IJKL:MNOP |
| 200 | derBytes, err := x509.MarshalPKIXPublicKey(pubKey.CryptoPublicKey()) |
| 201 | if err != nil { |
| 202 | return "" |
| 203 | } |
| 204 | hasher := crypto.SHA256.New() |
| 205 | hasher.Write(derBytes) |
| 206 | return keyIDEncode(hasher.Sum(nil)[:30]) |
| 207 | } |
| 208 | |
// stringFromMap extracts m[key] as a string and removes it from m.
//
// An error is returned if key is absent or its value is not a string. In
// both error cases m is left unmodified.
func stringFromMap(m map[string]interface{}, key string) (string, error) {
	raw, present := m[key]
	if !present {
		return "", fmt.Errorf("%q value not specified", key)
	}

	if str, isString := raw.(string); isString {
		delete(m, key)
		return str, nil
	}
	return "", fmt.Errorf("%q value must be a string", key)
}
| 223 | |
| 224 | func parseECCoordinate(cB64Url string, curve elliptic.Curve) (*big.Int, error) { |
| 225 | curveByteLen := (curve.Params().BitSize + 7) >> 3 |
| 226 | |
| 227 | cBytes, err := joseBase64UrlDecode(cB64Url) |
| 228 | if err != nil { |
| 229 | return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) |
| 230 | } |
| 231 | cByteLength := len(cBytes) |
| 232 | if cByteLength != curveByteLen { |
| 233 | return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", cByteLength, curveByteLen) |
| 234 | } |
| 235 | return new(big.Int).SetBytes(cBytes), nil |
| 236 | } |
| 237 | |
| 238 | func parseECPrivateParam(dB64Url string, curve elliptic.Curve) (*big.Int, error) { |
| 239 | dBytes, err := joseBase64UrlDecode(dB64Url) |
| 240 | if err != nil { |
| 241 | return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) |
| 242 | } |
| 243 | |
| 244 | // The length of this octet string MUST be ceiling(log-base-2(n)/8) |
| 245 | // octets (where n is the order of the curve). This is because the private |
| 246 | // key d must be in the interval [1, n-1] so the bitlength of d should be |
| 247 | // no larger than the bitlength of n-1. The easiest way to find the octet |
| 248 | // length is to take bitlength(n-1), add 7 to force a carry, and shift this |
| 249 | // bit sequence right by 3, which is essentially dividing by 8 and adding |
| 250 | // 1 if there is any remainder. Thus, the private key value d should be |
| 251 | // output to (bitlength(n-1)+7)>>3 octets. |
| 252 | n := curve.Params().N |
| 253 | octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3 |
| 254 | dByteLength := len(dBytes) |
| 255 | |
| 256 | if dByteLength != octetLength { |
| 257 | return nil, fmt.Errorf("invalid number of octets: got %d, should be %d", dByteLength, octetLength) |
| 258 | } |
| 259 | |
| 260 | return new(big.Int).SetBytes(dBytes), nil |
| 261 | } |
| 262 | |
| 263 | func parseRSAModulusParam(nB64Url string) (*big.Int, error) { |
| 264 | nBytes, err := joseBase64UrlDecode(nB64Url) |
| 265 | if err != nil { |
| 266 | return nil, fmt.Errorf("invalid base64 URL encoding: %s", err) |
| 267 | } |
| 268 | |
| 269 | return new(big.Int).SetBytes(nBytes), nil |
| 270 | } |
| 271 | |
// serializeRSAPublicExponentParam encodes the RSA public exponent e as a
// big-endian octet string using the minimum number of octets.
//
// E is usually 65537 (what Go's rsa package generates) but may differ for
// keys imported from other generators. Only the low 32 bits of e are
// encoded. At least one octet is always returned, so e == 0 encodes as
// []byte{0}.
func serializeRSAPublicExponentParam(e int) []byte {
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, uint32(e))

	// Skip leading zero octets but always keep the last one. The previous
	// bound of 8 ran past this 4-byte buffer and panicked when e was 0.
	i := 0
	for i < len(buf)-1 && buf[i] == 0 {
		i++
	}
	return buf[i:]
}
| 287 | |
| 288 | func parseRSAPublicExponentParam(eB64Url string) (int, error) { |
| 289 | eBytes, err := joseBase64UrlDecode(eB64Url) |
| 290 | if err != nil { |
| 291 | return 0, fmt.Errorf("invalid base64 URL encoding: %s", err) |
| 292 | } |
| 293 | // Only the minimum number of bytes were used to represent E, but |
| 294 | // binary.BigEndian.Uint32 expects at least 4 bytes, so we need |
| 295 | // to add zero padding if necassary. |
| 296 | byteLen := len(eBytes) |
| 297 | buf := make([]byte, 4-byteLen, 4) |
| 298 | eBytes = append(buf, eBytes...) |
| 299 | |
| 300 | return int(binary.BigEndian.Uint32(eBytes)), nil |
| 301 | } |
| 302 | |
| 303 | func parseRSAPrivateKeyParamFromMap(m map[string]interface{}, key string) (*big.Int, error) { |
| 304 | b64Url, err := stringFromMap(m, key) |
| 305 | if err != nil { |
| 306 | return nil, err |
| 307 | } |
| 308 | |
| 309 | paramBytes, err := joseBase64UrlDecode(b64Url) |
| 310 | if err != nil { |
| 311 | return nil, fmt.Errorf("invaled base64 URL encoding: %s", err) |
| 312 | } |
| 313 | |
| 314 | return new(big.Int).SetBytes(paramBytes), nil |
| 315 | } |
| 316 | |
// createPemBlock wraps derBytes in a PEM block of type name, converting
// headers to PEM headers as follows:
//   - a string value is copied verbatim;
//   - under the "hosts" key, a []string value is joined with ",";
//   - any other value (including a []string under a different key) cannot be
//     represented as a PEM header and is silently skipped.
//
// The returned Headers map is never nil. The error result is currently
// always nil.
func createPemBlock(name string, derBytes []byte, headers map[string]interface{}) (*pem.Block, error) {
	encoded := make(map[string]string)
	for key, value := range headers {
		if s, ok := value.(string); ok {
			encoded[key] = s
			continue
		}
		if hosts, ok := value.([]string); ok && key == "hosts" {
			encoded[key] = strings.Join(hosts, ",")
		}
		// NOTE(review): non-encodable values are dropped rather than reported;
		// returning an error here would change behaviour for existing callers.
	}

	return &pem.Block{Type: name, Bytes: derBytes, Headers: encoded}, nil
}
| 336 | |
| 337 | func pubKeyFromPEMBlock(pemBlock *pem.Block) (PublicKey, error) { |
| 338 | cryptoPublicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) |
| 339 | if err != nil { |
| 340 | return nil, fmt.Errorf("unable to decode Public Key PEM data: %s", err) |
| 341 | } |
| 342 | |
| 343 | pubKey, err := FromCryptoPublicKey(cryptoPublicKey) |
| 344 | if err != nil { |
| 345 | return nil, err |
| 346 | } |
| 347 | |
| 348 | addPEMHeadersToKey(pemBlock, pubKey) |
| 349 | |
| 350 | return pubKey, nil |
| 351 | } |
| 352 | |
| 353 | func addPEMHeadersToKey(pemBlock *pem.Block, pubKey PublicKey) { |
| 354 | for key, value := range pemBlock.Headers { |
| 355 | var safeVal interface{} |
| 356 | if key == "hosts" { |
| 357 | safeVal = strings.Split(value, ",") |
| 358 | } else { |
| 359 | safeVal = value |
| 360 | } |
| 361 | pubKey.AddExtendedField(key, safeVal) |
| 362 | } |
| 363 | } |