blob: 2dc2012924e3d2e59340b9ec69a088addc4fb704 [file] [log] [blame]
khenaidood948f772021-08-11 17:49:24 -04001// Copyright 2019 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// Package credentials implements gRPC credential interfaces with etcd-specific
// logic, e.g., a client handshake with a custom authority parameter.
17package credentials
18
19import (
20 "context"
21 "crypto/tls"
22 "net"
23 "sync"
24
25 "github.com/coreos/etcd/clientv3/balancer/resolver/endpoint"
26 "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
27 grpccredentials "google.golang.org/grpc/credentials"
28)
29
// Config carries the inputs NewBundle needs to assemble a credential Bundle.
type Config struct {
	// TLSConfig is wrapped with grpccredentials.NewTLS to form the bundle's
	// transport credentials.
	TLSConfig *tls.Config
}
34
35// Bundle defines gRPC credential interface.
36type Bundle interface {
37 grpccredentials.Bundle
38 UpdateAuthToken(token string)
39}
40
41// NewBundle constructs a new gRPC credential bundle.
42func NewBundle(cfg Config) Bundle {
43 return &bundle{
44 tc: newTransportCredential(cfg.TLSConfig),
45 rc: newPerRPCCredential(),
46 }
47}
48
49// bundle implements "grpccredentials.Bundle" interface.
50type bundle struct {
51 tc *transportCredential
52 rc *perRPCCredential
53}
54
55func (b *bundle) TransportCredentials() grpccredentials.TransportCredentials {
56 return b.tc
57}
58
59func (b *bundle) PerRPCCredentials() grpccredentials.PerRPCCredentials {
60 return b.rc
61}
62
63func (b *bundle) NewWithMode(mode string) (grpccredentials.Bundle, error) {
64 // no-op
65 return nil, nil
66}
67
68// transportCredential implements "grpccredentials.TransportCredentials" interface.
69// transportCredential wraps TransportCredentials to track which
70// addresses are dialed for which endpoints, and then sets the authority when checking the endpoint's cert to the
71// hostname or IP of the dialed endpoint.
72// This is a workaround of a gRPC load balancer issue. gRPC uses the dialed target's service name as the authority when
73// checking all endpoint certs, which does not work for etcd servers using their hostname or IP as the Subject Alternative Name
74// in their TLS certs.
75// To enable, include both WithTransportCredentials(creds) and WithContextDialer(creds.Dialer)
76// when dialing.
77type transportCredential struct {
78 gtc grpccredentials.TransportCredentials
79 mu sync.Mutex
80 // addrToEndpoint maps from the connection addresses that are dialed to the hostname or IP of the
81 // endpoint provided to the dialer when dialing
82 addrToEndpoint map[string]string
83}
84
85func newTransportCredential(cfg *tls.Config) *transportCredential {
86 return &transportCredential{
87 gtc: grpccredentials.NewTLS(cfg),
88 addrToEndpoint: map[string]string{},
89 }
90}
91
92func (tc *transportCredential) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
93 // Set the authority when checking the endpoint's cert to the hostname or IP of the dialed endpoint
94 tc.mu.Lock()
95 dialEp, ok := tc.addrToEndpoint[rawConn.RemoteAddr().String()]
96 tc.mu.Unlock()
97 if ok {
98 _, host, _ := endpoint.ParseEndpoint(dialEp)
99 authority = host
100 }
101 return tc.gtc.ClientHandshake(ctx, authority, rawConn)
102}
103
// isIP reports whether ep is a literal IPv4 or IPv6 address, as accepted by
// net.ParseIP. Host names and host:port strings are not IPs.
func isIP(ep string) bool {
	if ip := net.ParseIP(ep); ip == nil {
		return false
	}
	return true
}
108
109func (tc *transportCredential) ServerHandshake(rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
110 return tc.gtc.ServerHandshake(rawConn)
111}
112
113func (tc *transportCredential) Info() grpccredentials.ProtocolInfo {
114 return tc.gtc.Info()
115}
116
117func (tc *transportCredential) Clone() grpccredentials.TransportCredentials {
118 copy := map[string]string{}
119 tc.mu.Lock()
120 for k, v := range tc.addrToEndpoint {
121 copy[k] = v
122 }
123 tc.mu.Unlock()
124 return &transportCredential{
125 gtc: tc.gtc.Clone(),
126 addrToEndpoint: copy,
127 }
128}
129
130func (tc *transportCredential) OverrideServerName(serverNameOverride string) error {
131 return tc.gtc.OverrideServerName(serverNameOverride)
132}
133
134func (tc *transportCredential) Dialer(ctx context.Context, dialEp string) (net.Conn, error) {
135 // Keep track of which addresses are dialed for which endpoints
136 conn, err := endpoint.Dialer(ctx, dialEp)
137 if conn != nil {
138 tc.mu.Lock()
139 tc.addrToEndpoint[conn.RemoteAddr().String()] = dialEp
140 tc.mu.Unlock()
141 }
142 return conn, err
143}
144
// perRPCCredential implements grpccredentials.PerRPCCredentials by attaching
// the most recently stored auth token to outgoing RPCs.
type perRPCCredential struct {
	authTokenMu sync.RWMutex // guards authToken
	authToken   string
}

// newPerRPCCredential returns a credential with no token stored yet.
func newPerRPCCredential() *perRPCCredential {
	return new(perRPCCredential)
}

// RequireTransportSecurity always reports false, so gRPC allows these
// credentials on connections without transport security.
func (rc *perRPCCredential) RequireTransportSecurity() bool {
	return false
}
154
155func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
156 rc.authTokenMu.RLock()
157 authToken := rc.authToken
158 rc.authTokenMu.RUnlock()
159 return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil
160}
161
162func (b *bundle) UpdateAuthToken(token string) {
163 if b.rc == nil {
164 return
165 }
166 b.rc.UpdateAuthToken(token)
167}
168
169func (rc *perRPCCredential) UpdateAuthToken(token string) {
170 rc.authTokenMu.Lock()
171 rc.authToken = token
172 rc.authTokenMu.Unlock()
173}