blob: 48655063f6f2f1f6a21139137d6dad441d471ac0 [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package transport
16
17import (
18 "crypto/ecdsa"
19 "crypto/elliptic"
20 "crypto/rand"
21 "crypto/tls"
22 "crypto/x509"
23 "crypto/x509/pkix"
24 "encoding/pem"
25 "errors"
26 "fmt"
27 "math/big"
28 "net"
29 "os"
30 "path/filepath"
31 "strings"
32 "time"
33
34 "github.com/coreos/etcd/pkg/tlsutil"
35)
36
37func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
38 if l, err = newListener(addr, scheme); err != nil {
39 return nil, err
40 }
41 return wrapTLS(addr, scheme, tlsinfo, l)
42}
43
44func newListener(addr string, scheme string) (net.Listener, error) {
45 if scheme == "unix" || scheme == "unixs" {
46 // unix sockets via unix://laddr
47 return NewUnixListener(addr)
48 }
49 return net.Listen("tcp", addr)
50}
51
52func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
53 if scheme != "https" && scheme != "unixs" {
54 return l, nil
55 }
56 return newTLSListener(l, tlsinfo, checkSAN)
57}
58
// TLSInfo collects the files and options used to build TLS configurations
// (ServerConfig, ClientConfig) and TLS listeners.
type TLSInfo struct {
	// CertFile and KeyFile are the paths of the certificate and private key
	// presented to peers. By default they are parsed with tls.X509KeyPair.
	CertFile, KeyFile string

	// CAFile and TrustedCAFile are paths of CA certificate files used to
	// verify peer certificates. Setting CAFile also makes ServerConfig
	// require client certificates.
	// TODO: deprecate CAFile in v4
	CAFile, TrustedCAFile string

	// ClientCertAuth makes ServerConfig require and verify client certificates.
	ClientCertAuth bool

	// CRLFile is the path of a certificate revocation list. It is not read
	// by the config builders; presumably the TLS listener consumes it — verify.
	CRLFile string

	// InsecureSkipVerify is copied into client configurations.
	InsecureSkipVerify bool

	// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
	ServerName string

	// HandshakeFailure is optionally called when a connection fails to handshake. The
	// connection will be closed immediately afterwards.
	HandshakeFailure func(*tls.Conn, error)

	// CipherSuites is a list of supported cipher suites.
	// If empty, Go auto-populates it by default.
	// Note that cipher suites are prioritized in the given order.
	CipherSuites []uint16

	// selfCert marks a pair generated by SelfCert; client configs built from
	// it skip peer verification.
	selfCert bool

	// parseFunc exists to simplify testing. Typically, parseFunc
	// should be left nil. In that case, tls.X509KeyPair will be used.
	parseFunc func([]byte, []byte) (tls.Certificate, error)

	// AllowedCN is a CN which must be provided by a client.
	AllowedCN string
}
88}
89
90func (info TLSInfo) String() string {
91 return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
92}
93
94func (info TLSInfo) Empty() bool {
95 return info.CertFile == "" && info.KeyFile == ""
96}
97
98func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) {
99 if err = os.MkdirAll(dirpath, 0700); err != nil {
100 return
101 }
102
103 certPath := filepath.Join(dirpath, "cert.pem")
104 keyPath := filepath.Join(dirpath, "key.pem")
105 _, errcert := os.Stat(certPath)
106 _, errkey := os.Stat(keyPath)
107 if errcert == nil && errkey == nil {
108 info.CertFile = certPath
109 info.KeyFile = keyPath
110 info.selfCert = true
111 return
112 }
113
114 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
115 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
116 if err != nil {
117 return
118 }
119
120 tmpl := x509.Certificate{
121 SerialNumber: serialNumber,
122 Subject: pkix.Name{Organization: []string{"etcd"}},
123 NotBefore: time.Now(),
124 NotAfter: time.Now().Add(365 * (24 * time.Hour)),
125
126 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
127 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
128 BasicConstraintsValid: true,
129 }
130
131 for _, host := range hosts {
132 h, _, _ := net.SplitHostPort(host)
133 if ip := net.ParseIP(h); ip != nil {
134 tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
135 } else {
136 tmpl.DNSNames = append(tmpl.DNSNames, h)
137 }
138 }
139
140 priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
141 if err != nil {
142 return
143 }
144
145 derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
146 if err != nil {
147 return
148 }
149
150 certOut, err := os.Create(certPath)
151 if err != nil {
152 return
153 }
154 pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
155 certOut.Close()
156
157 b, err := x509.MarshalECPrivateKey(priv)
158 if err != nil {
159 return
160 }
161 keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
162 if err != nil {
163 return
164 }
165 pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
166 keyOut.Close()
167
168 return SelfCert(dirpath, hosts)
169}
170
171func (info TLSInfo) baseConfig() (*tls.Config, error) {
172 if info.KeyFile == "" || info.CertFile == "" {
173 return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
174 }
175
176 _, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
177 if err != nil {
178 return nil, err
179 }
180
181 cfg := &tls.Config{
182 MinVersion: tls.VersionTLS12,
183 ServerName: info.ServerName,
184 }
185
186 if len(info.CipherSuites) > 0 {
187 cfg.CipherSuites = info.CipherSuites
188 }
189
190 if info.AllowedCN != "" {
191 cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
192 for _, chains := range verifiedChains {
193 if len(chains) != 0 {
194 if info.AllowedCN == chains[0].Subject.CommonName {
195 return nil
196 }
197 }
198 }
199 return errors.New("CommonName authentication failed")
200 }
201 }
202
203 // this only reloads certs when there's a client request
204 // TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
205 cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
206 return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
207 }
208 cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (*tls.Certificate, error) {
209 return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
210 }
211 return cfg, nil
212}
213
214// cafiles returns a list of CA file paths.
215func (info TLSInfo) cafiles() []string {
216 cs := make([]string, 0)
217 if info.CAFile != "" {
218 cs = append(cs, info.CAFile)
219 }
220 if info.TrustedCAFile != "" {
221 cs = append(cs, info.TrustedCAFile)
222 }
223 return cs
224}
225
226// ServerConfig generates a tls.Config object for use by an HTTP server.
227func (info TLSInfo) ServerConfig() (*tls.Config, error) {
228 cfg, err := info.baseConfig()
229 if err != nil {
230 return nil, err
231 }
232
233 cfg.ClientAuth = tls.NoClientCert
234 if info.CAFile != "" || info.ClientCertAuth {
235 cfg.ClientAuth = tls.RequireAndVerifyClientCert
236 }
237
238 CAFiles := info.cafiles()
239 if len(CAFiles) > 0 {
240 cp, err := tlsutil.NewCertPool(CAFiles)
241 if err != nil {
242 return nil, err
243 }
244 cfg.ClientCAs = cp
245 }
246
247 // "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server
248 cfg.NextProtos = []string{"h2"}
249
250 return cfg, nil
251}
252
253// ClientConfig generates a tls.Config object for use by an HTTP client.
254func (info TLSInfo) ClientConfig() (*tls.Config, error) {
255 var cfg *tls.Config
256 var err error
257
258 if !info.Empty() {
259 cfg, err = info.baseConfig()
260 if err != nil {
261 return nil, err
262 }
263 } else {
264 cfg = &tls.Config{ServerName: info.ServerName}
265 }
266 cfg.InsecureSkipVerify = info.InsecureSkipVerify
267
268 CAFiles := info.cafiles()
269 if len(CAFiles) > 0 {
270 cfg.RootCAs, err = tlsutil.NewCertPool(CAFiles)
271 if err != nil {
272 return nil, err
273 }
274 }
275
276 if info.selfCert {
277 cfg.InsecureSkipVerify = true
278 }
279 return cfg, nil
280}
281
// IsClosedConnError reports whether err looks like the error produced by
// using a closed listener or connection, including cmux's listener.
// Adapted from golang.org/x/net/http2/http2.go.
//
// It is a case-sensitive substring check for "closed" in err.Error(), which
// covers:
//   - "use of closed network connection" (Go <=1.8)
//   - "use of closed file or network connection" (Go >1.8, internal/poll.ErrClosing)
//   - "mux: listener closed" (cmux.ErrListenerClosed)
//
// Any other message containing "closed" matches as well. A nil error never
// matches.
func IsClosedConnError(err error) bool {
	if err == nil {
		return false
	}
	const marker = "closed"
	return strings.Contains(err.Error(), marker)
}
289}