blob: ada4c4ec5027eb03f442499469cc7753d1254f7f [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001package wsproxy
2
3import (
4 "bufio"
5 "fmt"
6 "io"
7 "net/http"
8 "strings"
9
10 "github.com/gorilla/websocket"
11 "github.com/sirupsen/logrus"
12 "golang.org/x/net/context"
13)
14
// Package-level defaults that WebsocketProxy copies into every new Proxy
// unless an Option overrides them.
var (
	// MethodOverrideParam names the URL query parameter whose value, when
	// present, becomes the method of the proxied streaming http request.
	//
	// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
	MethodOverrideParam = "method"

	// TokenCookieName names the cookie whose value is sent upstream as an
	// 'Authorization: Bearer' header on the proxied streaming http request.
	//
	// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
	TokenCookieName = "token"
)
24
type (
	// RequestMutatorFunc can supply an alternate outgoing request. It is given
	// the incoming websocket request and the outgoing request prepared for the
	// wrapped handler; whatever it returns is the request that gets served.
	RequestMutatorFunc func(incoming *http.Request, outgoing *http.Request) *http.Request

	// Proxy provides websocket transport upgrade to compatible endpoints.
	// Construct it with WebsocketProxy; the zero value has neither a handler
	// nor a logger and is not usable.
	Proxy struct {
		// h serves plain requests and the proxied streaming request.
		h http.Handler
		// logger receives diagnostic output.
		logger Logger
		// methodOverrideParam is the query parameter that may override the proxied method.
		methodOverrideParam string
		// tokenCookieName is the cookie copied into the upstream Authorization header.
		tokenCookieName string
		// requestMutator, when non-nil, may replace the outgoing request.
		requestMutator RequestMutatorFunc
		// headerForwarder decides which incoming header names are copied upstream.
		headerForwarder func(header string) bool
	}

	// Logger collects log messages. The proxy reports problems through
	// Warnln and step-by-step tracing through Debugln.
	Logger interface {
		Warnln(...interface{})
		Debugln(...interface{})
	}
)
43
44func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
45 if !websocket.IsWebSocketUpgrade(r) {
46 p.h.ServeHTTP(w, r)
47 return
48 }
49 p.proxy(w, r)
50}
51
52// Option allows customization of the proxy.
53type Option func(*Proxy)
54
55// WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request.
56func WithMethodParamOverride(param string) Option {
57 return func(p *Proxy) {
58 p.methodOverrideParam = param
59 }
60}
61
62// WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
63func WithTokenCookieName(param string) Option {
64 return func(p *Proxy) {
65 p.tokenCookieName = param
66 }
67}
68
69// WithRequestMutator allows a custom RequestMutatorFunc to be supplied.
70func WithRequestMutator(fn RequestMutatorFunc) Option {
71 return func(p *Proxy) {
72 p.requestMutator = fn
73 }
74}
75
76// WithForwardedHeaders allows controlling which headers are forwarded.
77func WithForwardedHeaders(fn func(header string) bool) Option {
78 return func(p *Proxy) {
79 p.headerForwarder = fn
80 }
81}
82
83// WithLogger allows a custom FieldLogger to be supplied
84func WithLogger(logger Logger) Option {
85 return func(p *Proxy) {
86 p.logger = logger
87 }
88}
89
// defaultHeadersToForward lists the incoming header names that are copied onto
// the proxied request when no WithForwardedHeaders option is given. Each name
// appears in both its canonical and its lower-case spelling.
var defaultHeadersToForward = map[string]bool{
	"Origin":  true,
	"origin":  true,
	"Referer": true,
	"referer": true,
}

// defaultHeaderForwarder reports whether header is listed in
// defaultHeadersToForward with a true value. Matching is exact and
// case-sensitive.
func defaultHeaderForwarder(header string) bool {
	forward, listed := defaultHeadersToForward[header]
	return listed && forward
}
100
101// WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited
102// JSON as the content encoding.
103//
104// The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie.
105// The cookie name is specified by the TokenCookieName value.
106//
107// example:
108// Sec-Websocket-Protocol: Bearer, foobar
109// is converted to:
110// Authorization: Bearer foobar
111//
112// Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
113func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
114 p := &Proxy{
115 h: h,
116 logger: logrus.New(),
117 methodOverrideParam: MethodOverrideParam,
118 tokenCookieName: TokenCookieName,
119 headerForwarder: defaultHeaderForwarder,
120 }
121 for _, o := range opts {
122 o(p)
123 }
124 return p
125}
126
127// TODO(tmc): allow modification of upgrader settings?
128var upgrader = websocket.Upgrader{
129 ReadBufferSize: 1024,
130 WriteBufferSize: 1024,
131 CheckOrigin: func(r *http.Request) bool { return true },
132}
133
134func isClosedConnError(err error) bool {
135 str := err.Error()
136 if strings.Contains(str, "use of closed network connection") {
137 return true
138 }
139 return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
140}
141
142func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
143 var responseHeader http.Header
144 // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
145 // TODO(tmc): consider customizability/extension point here.
146 if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
147 responseHeader = http.Header{
148 "Sec-WebSocket-Protocol": []string{"Bearer"},
149 }
150 }
151 conn, err := upgrader.Upgrade(w, r, responseHeader)
152 if err != nil {
153 p.logger.Warnln("error upgrading websocket:", err)
154 return
155 }
156 defer conn.Close()
157
158 ctx, cancelFn := context.WithCancel(context.Background())
159 defer cancelFn()
160
161 requestBodyR, requestBodyW := io.Pipe()
162 request, err := http.NewRequest(r.Method, r.URL.String(), requestBodyR)
163 if err != nil {
164 p.logger.Warnln("error preparing request:", err)
165 return
166 }
167 if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
168 request.Header.Set("Authorization", transformSubProtocolHeader(swsp))
169 }
170 for header := range r.Header {
171 if p.headerForwarder(header) {
172 request.Header.Set(header, r.Header.Get(header))
173 }
174 }
175 // If token cookie is present, populate Authorization header from the cookie instead.
176 if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
177 request.Header.Set("Authorization", "Bearer "+cookie.Value)
178 }
179 if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
180 request.Method = m
181 }
182
183 if p.requestMutator != nil {
184 request = p.requestMutator(r, request)
185 }
186
187 responseBodyR, responseBodyW := io.Pipe()
188 response := newInMemoryResponseWriter(responseBodyW)
189 go func() {
190 <-ctx.Done()
191 p.logger.Debugln("closing pipes")
192 requestBodyW.CloseWithError(io.EOF)
193 responseBodyW.CloseWithError(io.EOF)
194 response.closed <- true
195 }()
196
197 go func() {
198 defer cancelFn()
199 p.h.ServeHTTP(response, request)
200 }()
201
202 // read loop -- take messages from websocket and write to http request
203 go func() {
204 defer func() {
205 cancelFn()
206 }()
207 for {
208 select {
209 case <-ctx.Done():
210 p.logger.Debugln("read loop done")
211 return
212 default:
213 }
214 p.logger.Debugln("[read] reading from socket.")
215 _, payload, err := conn.ReadMessage()
216 if err != nil {
217 if isClosedConnError(err) {
218 p.logger.Debugln("[read] websocket closed:", err)
219 return
220 }
221 p.logger.Warnln("error reading websocket message:", err)
222 return
223 }
224 p.logger.Debugln("[read] read payload:", string(payload))
225 p.logger.Debugln("[read] writing to requestBody:")
226 n, err := requestBodyW.Write(payload)
227 requestBodyW.Write([]byte("\n"))
228 p.logger.Debugln("[read] wrote to requestBody", n)
229 if err != nil {
230 p.logger.Warnln("[read] error writing message to upstream http server:", err)
231 return
232 }
233 }
234 }()
235 // write loop -- take messages from response and write to websocket
236 scanner := bufio.NewScanner(responseBodyR)
237 for scanner.Scan() {
238 if len(scanner.Bytes()) == 0 {
239 p.logger.Warnln("[write] empty scan", scanner.Err())
240 continue
241 }
242 p.logger.Debugln("[write] scanned", scanner.Text())
243 if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
244 p.logger.Warnln("[write] error writing websocket message:", err)
245 return
246 }
247 }
248 if err := scanner.Err(); err != nil {
249 p.logger.Warnln("scanner err:", err)
250 }
251}
252
// inMemoryResponseWriter is the http.ResponseWriter handed to the wrapped
// handler. Body bytes go straight to the embedded writer; the status code and
// headers are only recorded.
type inMemoryResponseWriter struct {
	// Writer receives the response body.
	io.Writer
	// header is returned by Header.
	header http.Header
	// code holds the last status passed to WriteHeader (0 if never called).
	code int
	// closed is the channel returned by CloseNotify; it has a buffer of one.
	closed chan bool
}

// newInMemoryResponseWriter returns a writer that streams the body to w, with
// an empty header map and a one-slot closed channel.
func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
	rw := new(inMemoryResponseWriter)
	rw.Writer = w
	rw.header = make(http.Header)
	rw.closed = make(chan bool, 1)
	return rw
}
267
// transformSubProtocolHeader turns a Sec-WebSocket-Protocol value such as
// "Bearer, foobar" into an Authorization value "Bearer foobar".
//
// Everything after the first "Bearer," is taken as the token, with
// surrounding spaces removed. This also copes with IE and Edge, which do not
// delimit Sec-WebSocket-Protocol strings with spaces. It returns "" when the
// value contains no "Bearer,".
func transformSubProtocolHeader(header string) string {
	const marker = "Bearer,"
	idx := strings.Index(header, marker)
	if idx == -1 {
		return ""
	}
	token := strings.Trim(header[idx+len(marker):], " ")
	return fmt.Sprintf("Bearer %v", token)
}
278
279func (w *inMemoryResponseWriter) Write(b []byte) (int, error) {
280 return w.Writer.Write(b)
281}
282func (w *inMemoryResponseWriter) Header() http.Header {
283 return w.header
284}
285func (w *inMemoryResponseWriter) WriteHeader(code int) {
286 w.code = code
287}
288func (w *inMemoryResponseWriter) CloseNotify() <-chan bool {
289 return w.closed
290}
291func (w *inMemoryResponseWriter) Flush() {}