blob: 6f1600945cc6f4de3c548da35f0403d465e89ff5 [file] [log] [blame]
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package transport
16
17import (
18 "context"
19 "crypto/tls"
20 "crypto/x509"
21 "fmt"
22 "io/ioutil"
23 "net"
24 "strings"
25 "sync"
26)
27
28// tlsListener overrides a TLS listener so it will reject client
29// certificates with insufficient SAN credentials or CRL revoked
30// certificates.
31type tlsListener struct {
32 net.Listener
33 connc chan net.Conn
34 donec chan struct{}
35 err error
36 handshakeFailure func(*tls.Conn, error)
37 check tlsCheckFunc
38}
39
// tlsCheckFunc is an additional verification step applied to a connection
// once its TLS handshake has completed. Returning a non-nil error causes
// the connection to be rejected.
type tlsCheckFunc func(ctx context.Context, conn *tls.Conn) error
41
42// NewTLSListener handshakes TLS connections and performs optional CRL checking.
43func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
44 check := func(context.Context, *tls.Conn) error { return nil }
45 return newTLSListener(l, tlsinfo, check)
46}
47
48func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
49 if tlsinfo == nil || tlsinfo.Empty() {
50 l.Close()
51 return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
52 }
53 tlscfg, err := tlsinfo.ServerConfig()
54 if err != nil {
55 return nil, err
56 }
57
58 hf := tlsinfo.HandshakeFailure
59 if hf == nil {
60 hf = func(*tls.Conn, error) {}
61 }
62
63 if len(tlsinfo.CRLFile) > 0 {
64 prevCheck := check
65 check = func(ctx context.Context, tlsConn *tls.Conn) error {
66 if err := prevCheck(ctx, tlsConn); err != nil {
67 return err
68 }
69 st := tlsConn.ConnectionState()
70 if certs := st.PeerCertificates; len(certs) > 0 {
71 return checkCRL(tlsinfo.CRLFile, certs)
72 }
73 return nil
74 }
75 }
76
77 tlsl := &tlsListener{
78 Listener: tls.NewListener(l, tlscfg),
79 connc: make(chan net.Conn),
80 donec: make(chan struct{}),
81 handshakeFailure: hf,
82 check: check,
83 }
84 go tlsl.acceptLoop()
85 return tlsl, nil
86}
87
88func (l *tlsListener) Accept() (net.Conn, error) {
89 select {
90 case conn := <-l.connc:
91 return conn, nil
92 case <-l.donec:
93 return nil, l.err
94 }
95}
96
97func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
98 st := tlsConn.ConnectionState()
99 if certs := st.PeerCertificates; len(certs) > 0 {
100 addr := tlsConn.RemoteAddr().String()
101 return checkCertSAN(ctx, certs[0], addr)
102 }
103 return nil
104}
105
106// acceptLoop launches each TLS handshake in a separate goroutine
107// to prevent a hanging TLS connection from blocking other connections.
108func (l *tlsListener) acceptLoop() {
109 var wg sync.WaitGroup
110 var pendingMu sync.Mutex
111
112 pending := make(map[net.Conn]struct{})
113 ctx, cancel := context.WithCancel(context.Background())
114 defer func() {
115 cancel()
116 pendingMu.Lock()
117 for c := range pending {
118 c.Close()
119 }
120 pendingMu.Unlock()
121 wg.Wait()
122 close(l.donec)
123 }()
124
125 for {
126 conn, err := l.Listener.Accept()
127 if err != nil {
128 l.err = err
129 return
130 }
131
132 pendingMu.Lock()
133 pending[conn] = struct{}{}
134 pendingMu.Unlock()
135
136 wg.Add(1)
137 go func() {
138 defer func() {
139 if conn != nil {
140 conn.Close()
141 }
142 wg.Done()
143 }()
144
145 tlsConn := conn.(*tls.Conn)
146 herr := tlsConn.Handshake()
147 pendingMu.Lock()
148 delete(pending, conn)
149 pendingMu.Unlock()
150
151 if herr != nil {
152 l.handshakeFailure(tlsConn, herr)
153 return
154 }
155 if err := l.check(ctx, tlsConn); err != nil {
156 l.handshakeFailure(tlsConn, err)
157 return
158 }
159
160 select {
161 case l.connc <- tlsConn:
162 conn = nil
163 case <-ctx.Done():
164 }
165 }()
166 }
167}
168
// checkCRL reads the certificate revocation list stored at crlPath and
// returns an error if the serial number of any certificate in cert appears
// in it. Errors from reading or parsing the file are returned unchanged.
// The file is re-read on every call.
func checkCRL(crlPath string, cert []*x509.Certificate) error {
	// TODO: cache
	raw, err := ioutil.ReadFile(crlPath)
	if err != nil {
		return err
	}
	crl, err := x509.ParseCRL(raw)
	if err != nil {
		return err
	}

	// Index revoked serials by their big-endian byte representation.
	entries := crl.TBSCertList.RevokedCertificates
	revoked := make(map[string]struct{}, len(entries))
	for i := range entries {
		revoked[string(entries[i].SerialNumber.Bytes())] = struct{}{}
	}

	for _, c := range cert {
		key := string(c.SerialNumber.Bytes())
		if _, hit := revoked[key]; hit {
			return fmt.Errorf("transport: certificate serial %x revoked", key)
		}
	}
	return nil
}
191
192func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
193 if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
194 return nil
195 }
196 h, _, herr := net.SplitHostPort(remoteAddr)
197 if herr != nil {
198 return herr
199 }
200 if len(cert.IPAddresses) > 0 {
201 cerr := cert.VerifyHostname(h)
202 if cerr == nil {
203 return nil
204 }
205 if len(cert.DNSNames) == 0 {
206 return cerr
207 }
208 }
209 if len(cert.DNSNames) > 0 {
210 ok, err := isHostInDNS(ctx, h, cert.DNSNames)
211 if ok {
212 return nil
213 }
214 errStr := ""
215 if err != nil {
216 errStr = " (" + err.Error() + ")"
217 }
218 return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames)
219 }
220 return nil
221}
222
// isHostInDNS reports whether host is associated with any of dnsNames.
//
// Entries starting with "*." are treated as wildcards and reduced to their
// suffix (for example "*.example.com" becomes ".example.com"); all other
// entries are exact names. Two strategies are tried in order:
//
//  1. Reverse lookup: every name returned for host is compared, without a
//     trailing dot, against the wildcard suffixes and the exact names.
//  2. Forward lookup: every exact name is resolved and its addresses are
//     compared with host as strings.
//
// When nothing matches, ok is false and err is the most recent lookup
// error, if any lookup failed.
func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) {
	var wildcards, names []string
	for _, dns := range dnsNames {
		if strings.HasPrefix(dns, "*.") {
			wildcards = append(wildcards, dns[1:])
		} else {
			names = append(names, dns)
		}
	}

	// reverse lookup
	lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host)
	for _, name := range lnames {
		// Strip the trailing '.' from the PTR record. TrimSuffix is also
		// safe for an empty name, which manual indexing is not.
		name = strings.TrimSuffix(name, ".")
		for _, wc := range wildcards {
			if strings.HasSuffix(name, wc) {
				return true, nil
			}
		}
		for _, n := range names {
			if n == name {
				return true, nil
			}
		}
	}
	err = lerr

	// forward lookup
	for _, dns := range names {
		addrs, ferr := net.DefaultResolver.LookupHost(ctx, dns)
		if ferr != nil {
			err = ferr
			continue
		}
		for _, addr := range addrs {
			if addr == host {
				return true, nil
			}
		}
	}
	return false, err
}
267
268func (l *tlsListener) Close() error {
269 err := l.Listener.Close()
270 <-l.donec
271 return err
272}