blob: 5ba921e72dc06ebc3b3a1bbba231b1eb27f396ca [file] [log] [blame]
// Copyright 2016 The CMux Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
14
15package cmux
16
17import (
khenaidood948f772021-08-11 17:49:24 -040018 "errors"
khenaidooab1f7bd2019-11-14 14:00:27 -050019 "fmt"
20 "io"
21 "net"
22 "sync"
23 "time"
24)
25
type (
	// Matcher decides whether a connection belongs to a listener, judging
	// only by the data it reads from the connection.
	Matcher func(io.Reader) bool

	// MatchWriter is like Matcher but may also write to the connection
	// while deciding (for example, to complete a handshake).
	MatchWriter func(io.Writer, io.Reader) bool

	// ErrorHandler is given each error the mux encounters and reports
	// whether the mux should continue serving the listener.
	ErrorHandler func(error) bool
)
35
// Compile-time check that ErrNotMatched satisfies net.Error.
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer.
type ErrNotMatched struct {
	c net.Conn
}

// Error implements the error interface. It includes the remote address of
// the unmatched connection. A zero-value ErrNotMatched (no connection
// recorded) yields a message without an address instead of panicking on a
// nil net.Conn.
func (e ErrNotMatched) Error() string {
	if e.c == nil {
		return "mux: connection not matched by any matcher"
	}
	return fmt.Sprintf("mux: connection %v not matched by any matcher",
		e.c.RemoteAddr())
}

// Temporary implements the net.Error interface. It always reports true, so
// an unmatched connection alone does not make handleErr stop the mux (the
// registered ErrorHandler still has the first say).
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout implements the net.Error interface. It always reports false.
func (e ErrNotMatched) Timeout() bool { return false }
54
// errListenerClosed is a string-backed net.Error that is never temporary
// and never a timeout.
type errListenerClosed string

// Error returns the message the value was constructed from.
func (e errListenerClosed) Error() string {
	msg := string(e)
	return msg
}

// Temporary always reports false: a closed listener stays closed.
func (errListenerClosed) Temporary() bool { return false }

// Timeout always reports false.
func (errListenerClosed) Timeout() bool { return false }

// ErrListenerClosed is returned from muxListener.Accept once the listener's
// connection queue has been closed, which Serve does when it returns.
var ErrListenerClosed = errListenerClosed("mux: listener closed")
64
var (
	// ErrServerClosed is returned from muxListener.Accept after the mux has
	// been shut down, either via Close or because Serve returned.
	ErrServerClosed = errors.New("mux: server closed")
)
67
khenaidooab1f7bd2019-11-14 14:00:27 -050068// for readability of readTimeout
69var noTimeout time.Duration
70
71// New instantiates a new connection multiplexer.
72func New(l net.Listener) CMux {
73 return &cMux{
74 root: l,
75 bufLen: 1024,
76 errh: func(_ error) bool { return true },
77 donec: make(chan struct{}),
78 readTimeout: noTimeout,
79 }
80}
81
82// CMux is a multiplexer for network connections.
83type CMux interface {
84 // Match returns a net.Listener that sees (i.e., accepts) only
85 // the connections matched by at least one of the matcher.
86 //
87 // The order used to call Match determines the priority of matchers.
88 Match(...Matcher) net.Listener
89 // MatchWithWriters returns a net.Listener that accepts only the
90 // connections that matched by at least of the matcher writers.
91 //
92 // Prefer Matchers over MatchWriters, since the latter can write on the
93 // connection before the actual handler.
94 //
95 // The order used to call Match determines the priority of matchers.
96 MatchWithWriters(...MatchWriter) net.Listener
97 // Serve starts multiplexing the listener. Serve blocks and perhaps
98 // should be invoked concurrently within a go routine.
99 Serve() error
khenaidood948f772021-08-11 17:49:24 -0400100 // Closes cmux server and stops accepting any connections on listener
101 Close()
khenaidooab1f7bd2019-11-14 14:00:27 -0500102 // HandleError registers an error handler that handles listener errors.
103 HandleError(ErrorHandler)
104 // sets a timeout for the read of matchers
105 SetReadTimeout(time.Duration)
106}
107
108type matchersListener struct {
109 ss []MatchWriter
110 l muxListener
111}
112
113type cMux struct {
114 root net.Listener
115 bufLen int
116 errh ErrorHandler
khenaidooab1f7bd2019-11-14 14:00:27 -0500117 sls []matchersListener
118 readTimeout time.Duration
khenaidood948f772021-08-11 17:49:24 -0400119 donec chan struct{}
120 mu sync.Mutex
khenaidooab1f7bd2019-11-14 14:00:27 -0500121}
122
123func matchersToMatchWriters(matchers []Matcher) []MatchWriter {
124 mws := make([]MatchWriter, 0, len(matchers))
125 for _, m := range matchers {
126 cm := m
127 mws = append(mws, func(w io.Writer, r io.Reader) bool {
128 return cm(r)
129 })
130 }
131 return mws
132}
133
134func (m *cMux) Match(matchers ...Matcher) net.Listener {
135 mws := matchersToMatchWriters(matchers)
136 return m.MatchWithWriters(mws...)
137}
138
139func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener {
140 ml := muxListener{
141 Listener: m.root,
142 connc: make(chan net.Conn, m.bufLen),
khenaidood948f772021-08-11 17:49:24 -0400143 donec: make(chan struct{}),
khenaidooab1f7bd2019-11-14 14:00:27 -0500144 }
145 m.sls = append(m.sls, matchersListener{ss: matchers, l: ml})
146 return ml
147}
148
149func (m *cMux) SetReadTimeout(t time.Duration) {
150 m.readTimeout = t
151}
152
153func (m *cMux) Serve() error {
154 var wg sync.WaitGroup
155
156 defer func() {
khenaidood948f772021-08-11 17:49:24 -0400157 m.closeDoneChans()
khenaidooab1f7bd2019-11-14 14:00:27 -0500158 wg.Wait()
159
160 for _, sl := range m.sls {
161 close(sl.l.connc)
162 // Drain the connections enqueued for the listener.
163 for c := range sl.l.connc {
164 _ = c.Close()
165 }
166 }
167 }()
168
169 for {
170 c, err := m.root.Accept()
171 if err != nil {
172 if !m.handleErr(err) {
173 return err
174 }
175 continue
176 }
177
178 wg.Add(1)
179 go m.serve(c, m.donec, &wg)
180 }
181}
182
183func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
184 defer wg.Done()
185
186 muc := newMuxConn(c)
187 if m.readTimeout > noTimeout {
188 _ = c.SetReadDeadline(time.Now().Add(m.readTimeout))
189 }
190 for _, sl := range m.sls {
191 for _, s := range sl.ss {
192 matched := s(muc.Conn, muc.startSniffing())
193 if matched {
194 muc.doneSniffing()
195 if m.readTimeout > noTimeout {
196 _ = c.SetReadDeadline(time.Time{})
197 }
198 select {
199 case sl.l.connc <- muc:
200 case <-donec:
201 _ = c.Close()
202 }
203 return
204 }
205 }
206 }
207
208 _ = c.Close()
209 err := ErrNotMatched{c: c}
210 if !m.handleErr(err) {
211 _ = m.root.Close()
212 }
213}
214
khenaidood948f772021-08-11 17:49:24 -0400215func (m *cMux) Close() {
216 m.closeDoneChans()
217}
218
219func (m *cMux) closeDoneChans() {
220 m.mu.Lock()
221 defer m.mu.Unlock()
222
223 select {
224 case <-m.donec:
225 // Already closed. Don't close again
226 default:
227 close(m.donec)
228 }
229 for _, sl := range m.sls {
230 select {
231 case <-sl.l.donec:
232 // Already closed. Don't close again
233 default:
234 close(sl.l.donec)
235 }
236 }
237}
238
khenaidooab1f7bd2019-11-14 14:00:27 -0500239func (m *cMux) HandleError(h ErrorHandler) {
240 m.errh = h
241}
242
243func (m *cMux) handleErr(err error) bool {
244 if !m.errh(err) {
245 return false
246 }
247
248 if ne, ok := err.(net.Error); ok {
249 return ne.Temporary()
250 }
251
252 return false
253}
254
255type muxListener struct {
256 net.Listener
257 connc chan net.Conn
khenaidood948f772021-08-11 17:49:24 -0400258 donec chan struct{}
khenaidooab1f7bd2019-11-14 14:00:27 -0500259}
260
261func (l muxListener) Accept() (net.Conn, error) {
khenaidood948f772021-08-11 17:49:24 -0400262 select {
263 case c, ok := <-l.connc:
264 if !ok {
265 return nil, ErrListenerClosed
266 }
267 return c, nil
268 case <-l.donec:
269 return nil, ErrServerClosed
khenaidooab1f7bd2019-11-14 14:00:27 -0500270 }
khenaidooab1f7bd2019-11-14 14:00:27 -0500271}
272
273// MuxConn wraps a net.Conn and provides transparent sniffing of connection data.
274type MuxConn struct {
275 net.Conn
276 buf bufferedReader
277}
278
279func newMuxConn(c net.Conn) *MuxConn {
280 return &MuxConn{
281 Conn: c,
282 buf: bufferedReader{source: c},
283 }
284}
285
286// From the io.Reader documentation:
287//
288// When Read encounters an error or end-of-file condition after
289// successfully reading n > 0 bytes, it returns the number of
290// bytes read. It may return the (non-nil) error from the same call
291// or return the error (and n == 0) from a subsequent call.
292// An instance of this general case is that a Reader returning
293// a non-zero number of bytes at the end of the input stream may
294// return either err == EOF or err == nil. The next Read should
295// return 0, EOF.
296func (m *MuxConn) Read(p []byte) (int, error) {
297 return m.buf.Read(p)
298}
299
300func (m *MuxConn) startSniffing() io.Reader {
301 m.buf.reset(true)
302 return &m.buf
303}
304
305func (m *MuxConn) doneSniffing() {
306 m.buf.reset(false)
307}