blob: 70921627af71c6024ae2935e05921780f4d35014 [file] [log] [blame]
khenaidooab1f7bd2019-11-14 14:00:27 -05001package wsproxy
2
3import (
4 "bufio"
khenaidood948f772021-08-11 17:49:24 -04005 "fmt"
khenaidooab1f7bd2019-11-14 14:00:27 -05006 "io"
7 "net/http"
8 "strings"
khenaidood948f772021-08-11 17:49:24 -04009 "time"
khenaidooab1f7bd2019-11-14 14:00:27 -050010
11 "github.com/gorilla/websocket"
12 "github.com/sirupsen/logrus"
13 "golang.org/x/net/context"
14)
15
// MethodOverrideParam is the default name of the URL query parameter whose
// value, when non-empty, replaces the HTTP method of the proxied streaming
// request.
//
// Deprecated: prefer passing WithMethodParamOverride to WebsocketProxy instead
// of mutating this package-level variable.
var MethodOverrideParam = "method"
20
// TokenCookieName is the default name of the cookie whose value is forwarded
// as an 'Authorization: Bearer <value>' header on the proxied streaming
// request.
//
// Deprecated: prefer passing WithTokenCookieName to WebsocketProxy instead of
// mutating this package-level variable.
var TokenCookieName = "token"
25
// RequestMutatorFunc lets callers replace or adjust the request that is sent
// to the wrapped handler. It receives the original websocket upgrade request
// (incoming) and the request the proxy has prepared (outgoing), and returns
// the request that should actually be served.
type RequestMutatorFunc func(incoming *http.Request, outgoing *http.Request) *http.Request
28
29// Proxy provides websocket transport upgrade to compatible endpoints.
30type Proxy struct {
khenaidood948f772021-08-11 17:49:24 -040031 h http.Handler
32 logger Logger
33 maxRespBodyBufferBytes int
34 methodOverrideParam string
35 tokenCookieName string
36 requestMutator RequestMutatorFunc
37 headerForwarder func(header string) bool
38 pingInterval time.Duration
39 pingWait time.Duration
40 pongWait time.Duration
khenaidooab1f7bd2019-11-14 14:00:27 -050041}
42
// Logger is the minimal logging surface the proxy needs: one sink for
// debug-level messages and one for warnings.
type Logger interface {
	// Debugln records a debug-level message built from its arguments.
	Debugln(...interface{})
	// Warnln records a warning-level message built from its arguments.
	Warnln(...interface{})
}
48
49func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
50 if !websocket.IsWebSocketUpgrade(r) {
51 p.h.ServeHTTP(w, r)
52 return
53 }
54 p.proxy(w, r)
55}
56
57// Option allows customization of the proxy.
58type Option func(*Proxy)
59
khenaidood948f772021-08-11 17:49:24 -040060// WithMaxRespBodyBufferSize allows specification of a custom size for the
61// buffer used while reading the response body. By default, the bufio.Scanner
62// used to read the response body sets the maximum token size to MaxScanTokenSize.
63func WithMaxRespBodyBufferSize(nBytes int) Option {
64 return func(p *Proxy) {
65 p.maxRespBodyBufferBytes = nBytes
66 }
67}
68
khenaidooab1f7bd2019-11-14 14:00:27 -050069// WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request.
70func WithMethodParamOverride(param string) Option {
71 return func(p *Proxy) {
72 p.methodOverrideParam = param
73 }
74}
75
76// WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
77func WithTokenCookieName(param string) Option {
78 return func(p *Proxy) {
79 p.tokenCookieName = param
80 }
81}
82
83// WithRequestMutator allows a custom RequestMutatorFunc to be supplied.
84func WithRequestMutator(fn RequestMutatorFunc) Option {
85 return func(p *Proxy) {
86 p.requestMutator = fn
87 }
88}
89
khenaidood948f772021-08-11 17:49:24 -040090// WithForwardedHeaders allows controlling which headers are forwarded.
91func WithForwardedHeaders(fn func(header string) bool) Option {
92 return func(p *Proxy) {
93 p.headerForwarder = fn
94 }
95}
96
khenaidooab1f7bd2019-11-14 14:00:27 -050097// WithLogger allows a custom FieldLogger to be supplied
98func WithLogger(logger Logger) Option {
99 return func(p *Proxy) {
100 p.logger = logger
101 }
102}
103
khenaidood948f772021-08-11 17:49:24 -0400104// WithPingControl allows specification of ping pong control. The interval
105// parameter specifies the pingInterval between pings. The allowed wait time
106// for a pong response is (pingInterval * 10) / 9.
107func WithPingControl(interval time.Duration) Option {
108 return func(proxy *Proxy) {
109 proxy.pingInterval = interval
110 proxy.pongWait = (interval * 10) / 9
111 proxy.pingWait = proxy.pongWait / 6
112 }
113}
114
// defaultHeadersToForward lists the header names that are forwarded when no
// custom forwarder is configured. Lookups are exact-match, so both spellings
// of each name are listed.
var defaultHeadersToForward = map[string]bool{
	"Origin":  true,
	"origin":  true,
	"Referer": true,
	"referer": true,
}

// defaultHeaderForwarder reports whether header is one of the names in
// defaultHeadersToForward.
func defaultHeaderForwarder(header string) bool {
	forward, listed := defaultHeadersToForward[header]
	return listed && forward
}
125
khenaidooab1f7bd2019-11-14 14:00:27 -0500126// WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited
127// JSON as the content encoding.
128//
129// The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie.
130// The cookie name is specified by the TokenCookieName value.
131//
132// example:
133// Sec-Websocket-Protocol: Bearer, foobar
134// is converted to:
135// Authorization: Bearer foobar
136//
137// Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
138func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
139 p := &Proxy{
140 h: h,
141 logger: logrus.New(),
142 methodOverrideParam: MethodOverrideParam,
143 tokenCookieName: TokenCookieName,
khenaidood948f772021-08-11 17:49:24 -0400144 headerForwarder: defaultHeaderForwarder,
khenaidooab1f7bd2019-11-14 14:00:27 -0500145 }
146 for _, o := range opts {
147 o(p)
148 }
149 return p
150}
151
152// TODO(tmc): allow modification of upgrader settings?
153var upgrader = websocket.Upgrader{
154 ReadBufferSize: 1024,
155 WriteBufferSize: 1024,
156 CheckOrigin: func(r *http.Request) bool { return true },
157}
158
159func isClosedConnError(err error) bool {
160 str := err.Error()
161 if strings.Contains(str, "use of closed network connection") {
162 return true
163 }
164 return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
165}
166
167func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
168 var responseHeader http.Header
169 // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
170 // TODO(tmc): consider customizability/extension point here.
171 if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
172 responseHeader = http.Header{
173 "Sec-WebSocket-Protocol": []string{"Bearer"},
174 }
175 }
176 conn, err := upgrader.Upgrade(w, r, responseHeader)
177 if err != nil {
178 p.logger.Warnln("error upgrading websocket:", err)
179 return
180 }
181 defer conn.Close()
182
183 ctx, cancelFn := context.WithCancel(context.Background())
184 defer cancelFn()
185
186 requestBodyR, requestBodyW := io.Pipe()
khenaidood948f772021-08-11 17:49:24 -0400187 request, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), requestBodyR)
khenaidooab1f7bd2019-11-14 14:00:27 -0500188 if err != nil {
189 p.logger.Warnln("error preparing request:", err)
190 return
191 }
192 if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
khenaidood948f772021-08-11 17:49:24 -0400193 request.Header.Set("Authorization", transformSubProtocolHeader(swsp))
194 }
195 for header := range r.Header {
196 if p.headerForwarder(header) {
197 request.Header.Set(header, r.Header.Get(header))
198 }
khenaidooab1f7bd2019-11-14 14:00:27 -0500199 }
200 // If token cookie is present, populate Authorization header from the cookie instead.
201 if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
202 request.Header.Set("Authorization", "Bearer "+cookie.Value)
203 }
204 if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
205 request.Method = m
206 }
207
208 if p.requestMutator != nil {
209 request = p.requestMutator(r, request)
210 }
211
212 responseBodyR, responseBodyW := io.Pipe()
213 response := newInMemoryResponseWriter(responseBodyW)
214 go func() {
215 <-ctx.Done()
216 p.logger.Debugln("closing pipes")
217 requestBodyW.CloseWithError(io.EOF)
218 responseBodyW.CloseWithError(io.EOF)
219 response.closed <- true
220 }()
221
222 go func() {
223 defer cancelFn()
224 p.h.ServeHTTP(response, request)
225 }()
226
227 // read loop -- take messages from websocket and write to http request
228 go func() {
khenaidood948f772021-08-11 17:49:24 -0400229 if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
230 conn.SetReadDeadline(time.Now().Add(p.pongWait))
231 conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
232 }
khenaidooab1f7bd2019-11-14 14:00:27 -0500233 defer func() {
234 cancelFn()
235 }()
236 for {
237 select {
238 case <-ctx.Done():
239 p.logger.Debugln("read loop done")
240 return
241 default:
242 }
243 p.logger.Debugln("[read] reading from socket.")
244 _, payload, err := conn.ReadMessage()
245 if err != nil {
246 if isClosedConnError(err) {
247 p.logger.Debugln("[read] websocket closed:", err)
248 return
249 }
250 p.logger.Warnln("error reading websocket message:", err)
251 return
252 }
253 p.logger.Debugln("[read] read payload:", string(payload))
254 p.logger.Debugln("[read] writing to requestBody:")
255 n, err := requestBodyW.Write(payload)
256 requestBodyW.Write([]byte("\n"))
257 p.logger.Debugln("[read] wrote to requestBody", n)
258 if err != nil {
259 p.logger.Warnln("[read] error writing message to upstream http server:", err)
260 return
261 }
262 }
263 }()
khenaidood948f772021-08-11 17:49:24 -0400264 // ping write loop
265 if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
266 go func() {
267 ticker := time.NewTicker(p.pingInterval)
268 defer func() {
269 ticker.Stop()
270 conn.Close()
271 }()
272 for {
273 select {
274 case <-ctx.Done():
275 p.logger.Debugln("ping loop done")
276 return
277 case <-ticker.C:
278 conn.SetWriteDeadline(time.Now().Add(p.pingWait))
279 if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
280 return
281 }
282 }
283 }
284 }()
285 }
khenaidooab1f7bd2019-11-14 14:00:27 -0500286 // write loop -- take messages from response and write to websocket
287 scanner := bufio.NewScanner(responseBodyR)
khenaidood948f772021-08-11 17:49:24 -0400288
289 // if maxRespBodyBufferSize has been specified, use custom buffer for scanner
290 var scannerBuf []byte
291 if p.maxRespBodyBufferBytes > 0 {
292 scannerBuf = make([]byte, 0, 64*1024)
293 scanner.Buffer(scannerBuf, p.maxRespBodyBufferBytes)
294 }
295
khenaidooab1f7bd2019-11-14 14:00:27 -0500296 for scanner.Scan() {
297 if len(scanner.Bytes()) == 0 {
298 p.logger.Warnln("[write] empty scan", scanner.Err())
299 continue
300 }
301 p.logger.Debugln("[write] scanned", scanner.Text())
302 if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
303 p.logger.Warnln("[write] error writing websocket message:", err)
304 return
305 }
306 }
307 if err := scanner.Err(); err != nil {
308 p.logger.Warnln("scanner err:", err)
309 }
310}
311
// inMemoryResponseWriter is the response writer handed to the wrapped handler
// for a websocket session. Body bytes go to the embedded io.Writer, while
// headers and the status code are only recorded in memory.
type inMemoryResponseWriter struct {
	// Writer receives every body byte written by the handler.
	io.Writer
	// header is the header map returned by Header.
	header http.Header
	// code is the status code last passed to WriteHeader.
	code int
	// closed is the channel returned by CloseNotify.
	closed chan bool
}
318
319func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
320 return &inMemoryResponseWriter{
321 Writer: w,
322 header: http.Header{},
323 closed: make(chan bool, 1),
324 }
325}
326
// transformSubProtocolHeader converts a Sec-WebSocket-Protocol value that
// contains "Bearer,<token>" into the matching "Bearer <token>" Authorization
// value. Everything after the first "Bearer," is used as the token, with
// surrounding spaces removed; if the marker is absent the result is empty.
//
// Spaces are trimmed rather than expected because IE and Edge do not delimit
// Sec-WebSocket-Protocol strings with spaces.
func transformSubProtocolHeader(header string) string {
	const marker = "Bearer,"

	at := strings.Index(header, marker)
	if at < 0 {
		return ""
	}

	token := strings.Trim(header[at+len(marker):], " ")
	return fmt.Sprintf("Bearer %v", token)
}
337
khenaidooab1f7bd2019-11-14 14:00:27 -0500338func (w *inMemoryResponseWriter) Write(b []byte) (int, error) {
339 return w.Writer.Write(b)
340}
341func (w *inMemoryResponseWriter) Header() http.Header {
342 return w.header
343}
344func (w *inMemoryResponseWriter) WriteHeader(code int) {
345 w.code = code
346}
347func (w *inMemoryResponseWriter) CloseNotify() <-chan bool {
348 return w.closed
349}
350func (w *inMemoryResponseWriter) Flush() {}