blob: 053641806116c0b4a39b77cd4c2e6fdd7580785f [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package connection
8
9import (
10 "bytes"
11 "crypto/tls"
12 "crypto/x509"
13 "encoding/asn1"
14 "encoding/hex"
15 "encoding/pem"
16 "errors"
17 "fmt"
18 "io/ioutil"
19 "strings"
20)
21
// TLSConfig contains options for configuring a TLS connection to the server.
//
// It embeds *tls.Config, so standard library settings can be read and written
// directly on a TLSConfig. Methods such as SetInsecure write through the
// embedded pointer, so construct values with NewTLSConfig to ensure it is
// non-nil.
type TLSConfig struct {
	*tls.Config

	// clientCertPass, when non-nil, supplies the password that
	// AddClientCertFromFile uses to decrypt an encrypted private key.
	clientCertPass func() string
}
27
28// NewTLSConfig creates a new TLSConfig.
29func NewTLSConfig() *TLSConfig {
30 cfg := &TLSConfig{}
31 cfg.Config = new(tls.Config)
32
33 return cfg
34}
35
36// SetClientCertDecryptPassword sets a function to retrieve the decryption password
37// necessary to read a certificate. This is a function instead of a string to
38// provide greater flexibility when deciding how to retrieve and store the password.
39func (c *TLSConfig) SetClientCertDecryptPassword(f func() string) {
40 c.clientCertPass = f
41}
42
43// SetInsecure sets whether the client should verify the server's certificate
44// chain and hostnames.
45func (c *TLSConfig) SetInsecure(allow bool) {
46 c.InsecureSkipVerify = allow
47}
48
49// AddCACertFromFile adds a root CA certificate to the configuration given a path
50// to the containing file.
51func (c *TLSConfig) AddCACertFromFile(file string) error {
52 data, err := ioutil.ReadFile(file)
53 if err != nil {
54 return err
55 }
56
57 certBytes, err := loadCert(data)
58 if err != nil {
59 return err
60 }
61
62 cert, err := x509.ParseCertificate(certBytes)
63 if err != nil {
64 return err
65 }
66
67 if c.RootCAs == nil {
68 c.RootCAs = x509.NewCertPool()
69 }
70
71 c.RootCAs.AddCert(cert)
72
73 return nil
74}
75
76// AddClientCertFromFile adds a client certificate to the configuration given a path to the
77// containing file and returns the certificate's subject name.
78func (c *TLSConfig) AddClientCertFromFile(clientFile string) (string, error) {
79 data, err := ioutil.ReadFile(clientFile)
80 if err != nil {
81 return "", err
82 }
83
84 var currentBlock *pem.Block
85 var certBlock, certDecodedBlock, keyBlock []byte
86
87 remaining := data
88 start := 0
89 for {
90 currentBlock, remaining = pem.Decode(remaining)
91 if currentBlock == nil {
92 break
93 }
94
95 if currentBlock.Type == "CERTIFICATE" {
96 certBlock = data[start : len(data)-len(remaining)]
97 certDecodedBlock = currentBlock.Bytes
98 start += len(certBlock)
99 } else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") {
100 if c.clientCertPass != nil && x509.IsEncryptedPEMBlock(currentBlock) {
101 var encoded bytes.Buffer
102 buf, err := x509.DecryptPEMBlock(currentBlock, []byte(c.clientCertPass()))
103 if err != nil {
104 return "", err
105 }
106
107 pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf})
108 keyBlock = encoded.Bytes()
109 start = len(data) - len(remaining)
110 } else {
111 keyBlock = data[start : len(data)-len(remaining)]
112 start += len(keyBlock)
113 }
114 }
115 }
116 if len(certBlock) == 0 {
117 return "", fmt.Errorf("failed to find CERTIFICATE")
118 }
119 if len(keyBlock) == 0 {
120 return "", fmt.Errorf("failed to find PRIVATE KEY")
121 }
122
123 cert, err := tls.X509KeyPair(certBlock, keyBlock)
124 if err != nil {
125 return "", err
126 }
127
128 c.Certificates = append(c.Certificates, cert)
129
130 // The documentation for the tls.X509KeyPair indicates that the Leaf certificate is not
131 // retained.
132 crt, err := x509.ParseCertificate(certDecodedBlock)
133 if err != nil {
134 return "", err
135 }
136
137 return x509CertSubject(crt), nil
138}
139
// loadCert returns the DER bytes of the first CERTIFICATE block in the PEM data,
// skipping any blocks of other types that precede it. Anything after that first
// certificate is ignored.
//
// It returns an error if the data is empty or is exhausted before a CERTIFICATE
// block is found, or if the data still remaining contains no further PEM block.
func loadCert(data []byte) ([]byte, error) {
	for {
		// len of a nil slice is 0, so this also covers data == nil.
		// NOTE(review): this message mentions a private key even though only a
		// certificate is required here; it is kept unchanged for compatibility.
		if len(data) == 0 {
			return nil, errors.New(".pem file must have both a CERTIFICATE and an RSA PRIVATE KEY section")
		}

		block, rest := pem.Decode(data)
		if block == nil {
			return nil, errors.New("invalid .pem file")
		}
		// Return on the first certificate. The old "multiple CERTIFICATE
		// sections" check could never fire, because the loop stopped as soon as
		// a certificate was found.
		if block.Type == "CERTIFICATE" {
			return block.Bytes, nil
		}

		data = rest
	}
}
166}
167
168// Because the functionality to convert a pkix.Name to a string wasn't added until Go 1.10, we
169// need to copy the implementation (along with the attributeTypeNames map below).
170func x509CertSubject(cert *x509.Certificate) string {
171 r := cert.Subject.ToRDNSequence()
172
173 s := ""
174 for i := 0; i < len(r); i++ {
175 rdn := r[len(r)-1-i]
176 if i > 0 {
177 s += ","
178 }
179 for j, tv := range rdn {
180 if j > 0 {
181 s += "+"
182 }
183
184 oidString := tv.Type.String()
185 typeName, ok := attributeTypeNames[oidString]
186 if !ok {
187 derBytes, err := asn1.Marshal(tv.Value)
188 if err == nil {
189 s += oidString + "=#" + hex.EncodeToString(derBytes)
190 continue // No value escaping necessary.
191 }
192
193 typeName = oidString
194 }
195
196 valueString := fmt.Sprint(tv.Value)
197 escaped := make([]rune, 0, len(valueString))
198
199 for k, c := range valueString {
200 escape := false
201
202 switch c {
203 case ',', '+', '"', '\\', '<', '>', ';':
204 escape = true
205
206 case ' ':
207 escape = k == 0 || k == len(valueString)-1
208
209 case '#':
210 escape = k == 0
211 }
212
213 if escape {
214 escaped = append(escaped, '\\', c)
215 } else {
216 escaped = append(escaped, c)
217 }
218 }
219
220 s += typeName + "=" + string(escaped)
221 }
222 }
223
224 return s
225}
226
// attributeTypeNames maps the dotted-decimal OIDs of the attribute types that
// x509CertSubject renders by short name; attribute types missing from this map
// are rendered by OID instead.
var attributeTypeNames = map[string]string{
	"2.5.4.3":  "CN",           // commonName
	"2.5.4.5":  "SERIALNUMBER", // serialNumber
	"2.5.4.6":  "C",            // countryName
	"2.5.4.7":  "L",            // localityName
	"2.5.4.8":  "ST",           // stateOrProvinceName
	"2.5.4.9":  "STREET",       // streetAddress
	"2.5.4.10": "O",            // organizationName
	"2.5.4.11": "OU",           // organizationalUnitName
	"2.5.4.17": "POSTALCODE",   // postalCode
}
237}