blob: 86e956bc8b77ad2c0e83839d22e8dcb6aed14ec9 [file] [log] [blame]
Scott Baker105df152020-04-13 15:55:14 -07001/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package credentials
20
21import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "fmt"
26 "io/ioutil"
27 "net"
28
29 "google.golang.org/grpc/credentials/internal"
30)
31
32// TLSInfo contains the auth information for a TLS authenticated connection.
33// It implements the AuthInfo interface.
34type TLSInfo struct {
35 State tls.ConnectionState
36 CommonAuthInfo
37}
38
39// AuthType returns the type of TLSInfo as a string.
40func (t TLSInfo) AuthType() string {
41 return "tls"
42}
43
44// GetSecurityValue returns security info requested by channelz.
45func (t TLSInfo) GetSecurityValue() ChannelzSecurityValue {
46 v := &TLSChannelzSecurityValue{
47 StandardName: cipherSuiteLookup[t.State.CipherSuite],
48 }
49 // Currently there's no way to get LocalCertificate info from tls package.
50 if len(t.State.PeerCertificates) > 0 {
51 v.RemoteCertificate = t.State.PeerCertificates[0].Raw
52 }
53 return v
54}
55
56// tlsCreds is the credentials required for authenticating a connection using TLS.
57type tlsCreds struct {
58 // TLS configuration
59 config *tls.Config
60}
61
62func (c tlsCreds) Info() ProtocolInfo {
63 return ProtocolInfo{
64 SecurityProtocol: "tls",
65 SecurityVersion: "1.2",
66 ServerName: c.config.ServerName,
67 }
68}
69
70func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
71 // use local cfg to avoid clobbering ServerName if using multiple endpoints
72 cfg := cloneTLSConfig(c.config)
73 if cfg.ServerName == "" {
74 serverName, _, err := net.SplitHostPort(authority)
75 if err != nil {
76 // If the authority had no host port or if the authority cannot be parsed, use it as-is.
77 serverName = authority
78 }
79 cfg.ServerName = serverName
80 }
81 conn := tls.Client(rawConn, cfg)
82 errChannel := make(chan error, 1)
83 go func() {
84 errChannel <- conn.Handshake()
85 close(errChannel)
86 }()
87 select {
88 case err := <-errChannel:
89 if err != nil {
90 conn.Close()
91 return nil, nil, err
92 }
93 case <-ctx.Done():
94 conn.Close()
95 return nil, nil, ctx.Err()
96 }
97 return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
98}
99
100func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
101 conn := tls.Server(rawConn, c.config)
102 if err := conn.Handshake(); err != nil {
103 conn.Close()
104 return nil, nil, err
105 }
106 return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
107}
108
109func (c *tlsCreds) Clone() TransportCredentials {
110 return NewTLS(c.config)
111}
112
113func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
114 c.config.ServerName = serverNameOverride
115 return nil
116}
117
// alpnProtoStrH2 is the ALPN protocol identifier advertised for HTTP/2.
const alpnProtoStrH2 = "h2"

// appendH2ToNextProtos returns ps with "h2" guaranteed to be present.
//
// If ps already contains "h2", ps itself is returned (same backing array).
// Otherwise a newly allocated slice holding ps followed by "h2" is returned,
// so the caller's slice is never written to.
func appendH2ToNextProtos(ps []string) []string {
	for i := range ps {
		if ps[i] == alpnProtoStrH2 {
			return ps
		}
	}
	out := make([]string, len(ps), len(ps)+1)
	copy(out, ps)
	return append(out, alpnProtoStrH2)
}
130
131// NewTLS uses c to construct a TransportCredentials based on TLS.
132func NewTLS(c *tls.Config) TransportCredentials {
133 tc := &tlsCreds{cloneTLSConfig(c)}
134 tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
135 return tc
136}
137
138// NewClientTLSFromCert constructs TLS credentials from the provided root
139// certificate authority certificate(s) to validate server connections. If
140// certificates to establish the identity of the client need to be included in
141// the credentials (eg: for mTLS), use NewTLS instead, where a complete
142// tls.Config can be specified.
143// serverNameOverride is for testing only. If set to a non empty string,
144// it will override the virtual host name of authority (e.g. :authority header
145// field) in requests.
146func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
147 return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
148}
149
150// NewClientTLSFromFile constructs TLS credentials from the provided root
151// certificate authority certificate file(s) to validate server connections. If
152// certificates to establish the identity of the client need to be included in
153// the credentials (eg: for mTLS), use NewTLS instead, where a complete
154// tls.Config can be specified.
155// serverNameOverride is for testing only. If set to a non empty string,
156// it will override the virtual host name of authority (e.g. :authority header
157// field) in requests.
158func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
159 b, err := ioutil.ReadFile(certFile)
160 if err != nil {
161 return nil, err
162 }
163 cp := x509.NewCertPool()
164 if !cp.AppendCertsFromPEM(b) {
165 return nil, fmt.Errorf("credentials: failed to append certificates")
166 }
167 return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
168}
169
170// NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
171func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
172 return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
173}
174
175// NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
176// file for server.
177func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
178 cert, err := tls.LoadX509KeyPair(certFile, keyFile)
179 if err != nil {
180 return nil, err
181 }
182 return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
183}
184
185// TLSChannelzSecurityValue defines the struct that TLS protocol should return
186// from GetSecurityValue(), containing security info like cipher and certificate used.
187//
188// This API is EXPERIMENTAL.
189type TLSChannelzSecurityValue struct {
190 ChannelzSecurityValue
191 StandardName string
192 LocalCertificate []byte
193 RemoteCertificate []byte
194}
195
196var cipherSuiteLookup = map[uint16]string{
197 tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
198 tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
199 tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA",
200 tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA",
201 tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "TLS_RSA_WITH_AES_128_GCM_SHA256",
202 tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384",
203 tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
204 tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
205 tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
206 tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
207 tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
208 tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
209 tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
210 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
211 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
212 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
213 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
214 tls.TLS_FALLBACK_SCSV: "TLS_FALLBACK_SCSV",
215 tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "TLS_RSA_WITH_AES_128_CBC_SHA256",
216 tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
217 tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
218 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305",
219 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
220}
221
222// cloneTLSConfig returns a shallow clone of the exported
223// fields of cfg, ignoring the unexported sync.Once, which
224// contains a mutex and must not be copied.
225//
226// If cfg is nil, a new zero tls.Config is returned.
227//
228// TODO: inline this function if possible.
229func cloneTLSConfig(cfg *tls.Config) *tls.Config {
230 if cfg == nil {
231 return &tls.Config{}
232 }
233
234 return cfg.Clone()
235}