blob: 80403423d8a67a0ece2315ccd2a1d4d80d6367da [file] [log] [blame]
// Copyright 2016 The CMux Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
14
15package cmux
16
17import (
18 "fmt"
19 "io"
20 "net"
21 "sync"
22 "time"
23)
24
type (
	// Matcher matches a connection based on its content.
	Matcher func(io.Reader) bool

	// MatchWriter is a matcher that can also write a response to the
	// connection (say, to complete a handshake) before deciding.
	MatchWriter func(io.Writer, io.Reader) bool

	// ErrorHandler handles an error and returns whether the mux should
	// continue serving the listener.
	ErrorHandler func(error) bool
)
34
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer.
//
// It implements net.Error and reports itself as temporary, so the mux's
// error handling treats it as recoverable unless the registered
// ErrorHandler says otherwise.
type ErrNotMatched struct {
	c net.Conn
}

// Error describes the unmatched connection by its remote address.
//
// A zero ErrNotMatched (no connection recorded) yields a message without an
// address instead of panicking on the nil net.Conn.
func (e ErrNotMatched) Error() string {
	if e.c == nil {
		return "mux: connection not matched by any matcher"
	}
	return fmt.Sprintf("mux: connection %v not matched by any matcher",
		e.c.RemoteAddr())
}

// Temporary implements the net.Error interface. An unmatched connection is
// always considered temporary.
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout implements the net.Error interface. An unmatched connection is
// never a timeout.
func (e ErrNotMatched) Timeout() bool { return false }
53
// errListenerClosed is the concrete type behind ErrListenerClosed. Its
// methods match the net.Error method set and always describe a permanent,
// non-timeout failure.
type errListenerClosed string

// Error returns the message the value was created with.
func (msg errListenerClosed) Error() string {
	return string(msg)
}

// Temporary reports false: a closed listener does not recover.
func (msg errListenerClosed) Temporary() bool {
	return false
}

// Timeout reports false: closing is not a deadline expiry.
func (msg errListenerClosed) Timeout() bool {
	return false
}

// ErrListenerClosed is returned from muxListener.Accept when the underlying
// listener is closed.
var ErrListenerClosed = errListenerClosed("mux: listener closed")
63
// noTimeout is the zero Duration. It exists for readability: a readTimeout
// that is not greater than noTimeout means matchers read without a deadline.
var noTimeout = time.Duration(0)
66
67// New instantiates a new connection multiplexer.
68func New(l net.Listener) CMux {
69 return &cMux{
70 root: l,
71 bufLen: 1024,
72 errh: func(_ error) bool { return true },
73 donec: make(chan struct{}),
74 readTimeout: noTimeout,
75 }
76}
77
78// CMux is a multiplexer for network connections.
79type CMux interface {
80 // Match returns a net.Listener that sees (i.e., accepts) only
81 // the connections matched by at least one of the matcher.
82 //
83 // The order used to call Match determines the priority of matchers.
84 Match(...Matcher) net.Listener
85 // MatchWithWriters returns a net.Listener that accepts only the
86 // connections that matched by at least of the matcher writers.
87 //
88 // Prefer Matchers over MatchWriters, since the latter can write on the
89 // connection before the actual handler.
90 //
91 // The order used to call Match determines the priority of matchers.
92 MatchWithWriters(...MatchWriter) net.Listener
93 // Serve starts multiplexing the listener. Serve blocks and perhaps
94 // should be invoked concurrently within a go routine.
95 Serve() error
96 // HandleError registers an error handler that handles listener errors.
97 HandleError(ErrorHandler)
98 // sets a timeout for the read of matchers
99 SetReadTimeout(time.Duration)
100}
101
102type matchersListener struct {
103 ss []MatchWriter
104 l muxListener
105}
106
107type cMux struct {
108 root net.Listener
109 bufLen int
110 errh ErrorHandler
111 donec chan struct{}
112 sls []matchersListener
113 readTimeout time.Duration
114}
115
116func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
117 mws := make([]MatchWriter, 0, len(matchers))
118 for _, m := range matchers {
119 cm := m
120 mws = append(mws, func(w io.Writer, r io.Reader) bool {
121 return cm(r)
122 })
123 }
124 return mws
125}
126
127func (m *cMux) Match(matchers ...Matcher) net.Listener {
128 mws := matchersToMatchWriters(matchers)
129 return m.MatchWithWriters(mws...)
130}
131
132func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
133 ml := muxListener{
134 Listener: m.root,
135 connc: make(chan net.Conn, m.bufLen),
136 }
137 m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
138 return ml
139}
140
141func (m *cMux) SetReadTimeout(t time.Duration) {
142 m.readTimeout = t
143}
144
145func (m *cMux) Serve() error {
146 var wg sync.WaitGroup
147
148 defer func() {
149 close(m.donec)
150 wg.Wait()
151
152 for _, sl := range m.sls {
153 close(sl.l.connc)
154 // Drain the connections enqueued for the listener.
155 for c := range sl.l.connc {
156 _ = c.Close()
157 }
158 }
159 }()
160
161 for {
162 c, err := m.root.Accept()
163 if err != nil {
164 if !m.handleErr(err) {
165 return err
166 }
167 continue
168 }
169
170 wg.Add(1)
171 go m.serve(c, m.donec, &wg)
172 }
173}
174
175func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
176 defer wg.Done()
177
178 muc := newMuxConn(c)
179 if m.readTimeout > noTimeout {
180 _ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
181 }
182 for _, sl := range m.sls {
183 for _, s := range sl.ss {
184 matched := s(muc.Conn, muc.startSniffing())
185 if matched {
186 muc.doneSniffing()
187 if m.readTimeout > noTimeout {
188 _ = c.SetReadDeadline(time.Time{})
189 }
190 select {
191 case sl.l.connc <- muc:
192 case <-donec:
193 _ = c.Close()
194 }
195 return
196 }
197 }
198 }
199
200 _ = c.Close()
201 err := ErrNotMatched{c: c}
202 if !m.handleErr(err) {
203 _ = m.root.Close()
204 }
205}
206
207func (m *cMux) HandleError(h ErrorHandler) {
208 m.errh = h
209}
210
211func (m *cMux) handleErr(err error) bool {
212 if !m.errh(err) {
213 return false
214 }
215
216 if ne, ok := err.(net.Error); ok {
217 return ne.Temporary()
218 }
219
220 return false
221}
222
223type muxListener struct {
224 net.Listener
225 connc chan net.Conn
226}
227
228func (l muxListener) Accept() (net.Conn, error) {
229 c, ok := <-l.connc
230 if !ok {
231 return nil, ErrListenerClosed
232 }
233 return c, nil
234}
235
236// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
237type MuxConn struct {
238 net.Conn
239 buf bufferedReader
240}
241
242func newMuxConn(c net.Conn) *MuxConn {
243 return &MuxConn{
244 Conn: c,
245 buf: bufferedReader{source: c},
246 }
247}
248
249// From the io.Reader documentation:
250//
251// When Read encounters an error or end-of-file condition after
252// successfully reading n > 0 bytes, it returns the number of
253// bytes read. It may return the (non-nil) error from the same call
254// or return the error (and n == 0) from a subsequent call.
255// An instance of this general case is that a Reader returning
256// a non-zero number of bytes at the end of the input stream may
257// return either err == EOF or err == nil. The next Read should
258// return 0, EOF.
259func (m *MuxConn) Read(p []byte) (int, error) {
260 return m.buf.Read(p)
261}
262
263func (m *MuxConn) startSniffing() io.Reader {
264 m.buf.reset(true)
265 return &m.buf
266}
267
268func (m *MuxConn) doneSniffing() {
269 m.buf.reset(false)
270}