// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15package tcpproxy
16
17import (
18 "fmt"
19 "io"
20 "math/rand"
21 "net"
22 "sync"
23 "time"
24
25 "github.com/coreos/pkg/capnslog"
26)
27
28var (
29 plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "proxy/tcpproxy")
30)
31
// remote is a single backend endpoint plus the flag recording whether the
// proxy currently considers it usable.
type remote struct {
	// srv is the SRV record for the endpoint; pick reads its Priority and
	// Weight.
	srv *net.SRV

	// addr is the address handed to net.Dial for this endpoint.
	addr string

	// mu guards inactive.
	mu sync.Mutex

	// inactive is set by inactivate and cleared by a successful
	// tryReactivate. The zero value means the remote is active.
	inactive bool
}
38
39func (r *remote) inactivate() {
40 r.mu.Lock()
41 defer r.mu.Unlock()
42 r.inactive = true
43}
44
45func (r *remote) tryReactivate() error {
46 conn, err := net.Dial("tcp", r.addr)
47 if err != nil {
48 return err
49 }
50 conn.Close()
51 r.mu.Lock()
52 defer r.mu.Unlock()
53 r.inactive = false
54 return nil
55}
56
57func (r *remote) isActive() bool {
58 r.mu.Lock()
59 defer r.mu.Unlock()
60 return !r.inactive
61}
62
// TCPProxy forwards TCP connections accepted on Listener to one of the
// backends described by Endpoints. Run starts the proxy and Stop shuts it
// down.
type TCPProxy struct {
	// Listener is where client connections come from: Run accepts from it
	// and Stop closes it.
	Listener net.Listener

	// Endpoints lists the backends as SRV records. Run creates one remote
	// per entry.
	Endpoints []*net.SRV

	// MonitorInterval is the delay between runMonitor's attempts to
	// reactivate inactive remotes. Run replaces a zero value with 5 minutes.
	MonitorInterval time.Duration

	// donec is created by Run and closed by Stop; closing it ends runMonitor.
	donec chan struct{}

	mu        sync.Mutex // guards the following fields
	remotes   []*remote
	pickCount int // for round robin
}
74
75func (tp *TCPProxy) Run() error {
76 tp.donec = make(chan struct{})
77 if tp.MonitorInterval == 0 {
78 tp.MonitorInterval = 5 * time.Minute
79 }
80 for _, srv := range tp.Endpoints {
81 addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
82 tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr})
83 }
84
85 eps := []string{}
86 for _, ep := range tp.Endpoints {
87 eps = append(eps, fmt.Sprintf("%s:%d", ep.Target, ep.Port))
88 }
89 plog.Printf("ready to proxy client requests to %+v", eps)
90
91 go tp.runMonitor()
92 for {
93 in, err := tp.Listener.Accept()
94 if err != nil {
95 return err
96 }
97
98 go tp.serve(in)
99 }
100}
101
102func (tp *TCPProxy) pick() *remote {
103 var weighted []*remote
104 var unweighted []*remote
105
106 bestPr := uint16(65535)
107 w := 0
108 // find best priority class
109 for _, r := range tp.remotes {
110 switch {
111 case !r.isActive():
112 case r.srv.Priority < bestPr:
113 bestPr = r.srv.Priority
114 w = 0
115 weighted = nil
116 unweighted = []*remote{r}
117 fallthrough
118 case r.srv.Priority == bestPr:
119 if r.srv.Weight > 0 {
120 weighted = append(weighted, r)
121 w += int(r.srv.Weight)
122 } else {
123 unweighted = append(unweighted, r)
124 }
125 }
126 }
127 if weighted != nil {
128 if len(unweighted) > 0 && rand.Intn(100) == 1 {
129 // In the presence of records containing weights greater
130 // than 0, records with weight 0 should have a very small
131 // chance of being selected.
132 r := unweighted[tp.pickCount%len(unweighted)]
133 tp.pickCount++
134 return r
135 }
136 // choose a uniform random number between 0 and the sum computed
137 // (inclusive), and select the RR whose running sum value is the
138 // first in the selected order
139 choose := rand.Intn(w)
140 for i := 0; i < len(weighted); i++ {
141 choose -= int(weighted[i].srv.Weight)
142 if choose <= 0 {
143 return weighted[i]
144 }
145 }
146 }
147 if unweighted != nil {
148 for i := 0; i < len(tp.remotes); i++ {
149 picked := tp.remotes[tp.pickCount%len(tp.remotes)]
150 tp.pickCount++
151 if picked.isActive() {
152 return picked
153 }
154 }
155 }
156 return nil
157}
158
159func (tp *TCPProxy) serve(in net.Conn) {
160 var (
161 err error
162 out net.Conn
163 )
164
165 for {
166 tp.mu.Lock()
167 remote := tp.pick()
168 tp.mu.Unlock()
169 if remote == nil {
170 break
171 }
172 // TODO: add timeout
173 out, err = net.Dial("tcp", remote.addr)
174 if err == nil {
175 break
176 }
177 remote.inactivate()
178 plog.Warningf("deactivated endpoint [%s] due to %v for %v", remote.addr, err, tp.MonitorInterval)
179 }
180
181 if out == nil {
182 in.Close()
183 return
184 }
185
186 go func() {
187 io.Copy(in, out)
188 in.Close()
189 out.Close()
190 }()
191
192 io.Copy(out, in)
193 out.Close()
194 in.Close()
195}
196
197func (tp *TCPProxy) runMonitor() {
198 for {
199 select {
200 case <-time.After(tp.MonitorInterval):
201 tp.mu.Lock()
202 for _, rem := range tp.remotes {
203 if rem.isActive() {
204 continue
205 }
206 go func(r *remote) {
207 if err := r.tryReactivate(); err != nil {
208 plog.Warningf("failed to activate endpoint [%s] due to %v (stay inactive for another %v)", r.addr, err, tp.MonitorInterval)
209 } else {
210 plog.Printf("activated %s", r.addr)
211 }
212 }(rem)
213 }
214 tp.mu.Unlock()
215 case <-tp.donec:
216 return
217 }
218 }
219}
220
221func (tp *TCPProxy) Stop() {
222 // graceful shutdown?
223 // shutdown current connections?
224 tp.Listener.Close()
225 close(tp.donec)
226}