blob: 0fca05a008ae627f4f90695082c105489bdcac53 [file] [log] [blame]
khenaidoo59ce9dd2019-11-11 13:05:32 -05001package wsproxy
2
3import (
4 "bufio"
5 "io"
6 "net/http"
7 "strings"
8
9 "github.com/gorilla/websocket"
10 "github.com/sirupsen/logrus"
11 "golang.org/x/net/context"
12)
13
var (
	// MethodOverrideParam is the default name of the URL query parameter whose
	// value, when present, replaces the HTTP method of the proxied streaming
	// request. WebsocketProxy reads it once, when the Proxy is constructed.
	//
	// Deprecated: pass the WithMethodParamOverride option to WebsocketProxy instead.
	MethodOverrideParam = "method"

	// TokenCookieName is the default name of the cookie whose value is forwarded
	// upstream as an 'Authorization: Bearer <value>' header on the proxied
	// streaming request. WebsocketProxy reads it once, when the Proxy is
	// constructed.
	//
	// Deprecated: pass the WithTokenCookieName option to WebsocketProxy instead.
	TokenCookieName = "token"
)
23
// RequestMutatorFunc is a hook that may substitute the outgoing (proxied)
// request. It is given the original incoming websocket request and the
// prepared outgoing request, and returns the request that is actually served.
type RequestMutatorFunc func(incoming, outgoing *http.Request) *http.Request

// Proxy upgrades compatible requests to websockets and bridges them to a
// wrapped http.Handler. Construct one with WebsocketProxy.
type Proxy struct {
	h                   http.Handler       // serves plain requests directly and the proxied streams
	logger              Logger             // destination for warnings and debug output
	methodOverrideParam string             // query parameter that can override the proxied method
	tokenCookieName     string             // cookie copied into an 'Authorization: Bearer' header
	requestMutator      RequestMutatorFunc // optional hook for the outgoing request; nil disables it
}

// Logger is the minimal logging interface a Proxy writes to.
type Logger interface {
	Warnln(args ...interface{})
	Debugln(args ...interface{})
}
41
42func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
43 if !websocket.IsWebSocketUpgrade(r) {
44 p.h.ServeHTTP(w, r)
45 return
46 }
47 p.proxy(w, r)
48}
49
50// Option allows customization of the proxy.
51type Option func(*Proxy)
52
53// WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request.
54func WithMethodParamOverride(param string) Option {
55 return func(p *Proxy) {
56 p.methodOverrideParam = param
57 }
58}
59
60// WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header.
61func WithTokenCookieName(param string) Option {
62 return func(p *Proxy) {
63 p.tokenCookieName = param
64 }
65}
66
67// WithRequestMutator allows a custom RequestMutatorFunc to be supplied.
68func WithRequestMutator(fn RequestMutatorFunc) Option {
69 return func(p *Proxy) {
70 p.requestMutator = fn
71 }
72}
73
74// WithLogger allows a custom FieldLogger to be supplied
75func WithLogger(logger Logger) Option {
76 return func(p *Proxy) {
77 p.logger = logger
78 }
79}
80
81// WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited
82// JSON as the content encoding.
83//
84// The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie.
85// The cookie name is specified by the TokenCookieName value.
86//
87// example:
88// Sec-Websocket-Protocol: Bearer, foobar
89// is converted to:
90// Authorization: Bearer foobar
91//
92// Method can be overwritten with the MethodOverrideParam get parameter in the requested URL
93func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
94 p := &Proxy{
95 h: h,
96 logger: logrus.New(),
97 methodOverrideParam: MethodOverrideParam,
98 tokenCookieName: TokenCookieName,
99 }
100 for _, o := range opts {
101 o(p)
102 }
103 return p
104}
105
106// TODO(tmc): allow modification of upgrader settings?
107var upgrader = websocket.Upgrader{
108 ReadBufferSize: 1024,
109 WriteBufferSize: 1024,
110 CheckOrigin: func(r *http.Request) bool { return true },
111}
112
113func isClosedConnError(err error) bool {
114 str := err.Error()
115 if strings.Contains(str, "use of closed network connection") {
116 return true
117 }
118 return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
119}
120
121func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
122 var responseHeader http.Header
123 // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind.
124 // TODO(tmc): consider customizability/extension point here.
125 if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
126 responseHeader = http.Header{
127 "Sec-WebSocket-Protocol": []string{"Bearer"},
128 }
129 }
130 conn, err := upgrader.Upgrade(w, r, responseHeader)
131 if err != nil {
132 p.logger.Warnln("error upgrading websocket:", err)
133 return
134 }
135 defer conn.Close()
136
137 ctx, cancelFn := context.WithCancel(context.Background())
138 defer cancelFn()
139
140 requestBodyR, requestBodyW := io.Pipe()
141 request, err := http.NewRequest(r.Method, r.URL.String(), requestBodyR)
142 if err != nil {
143 p.logger.Warnln("error preparing request:", err)
144 return
145 }
146 if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
147 request.Header.Set("Authorization", strings.Replace(swsp, "Bearer, ", "Bearer ", 1))
148 }
149 // If token cookie is present, populate Authorization header from the cookie instead.
150 if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
151 request.Header.Set("Authorization", "Bearer "+cookie.Value)
152 }
153 if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
154 request.Method = m
155 }
156
157 if p.requestMutator != nil {
158 request = p.requestMutator(r, request)
159 }
160
161 responseBodyR, responseBodyW := io.Pipe()
162 response := newInMemoryResponseWriter(responseBodyW)
163 go func() {
164 <-ctx.Done()
165 p.logger.Debugln("closing pipes")
166 requestBodyW.CloseWithError(io.EOF)
167 responseBodyW.CloseWithError(io.EOF)
168 response.closed <- true
169 }()
170
171 go func() {
172 defer cancelFn()
173 p.h.ServeHTTP(response, request)
174 }()
175
176 // read loop -- take messages from websocket and write to http request
177 go func() {
178 defer func() {
179 cancelFn()
180 }()
181 for {
182 select {
183 case <-ctx.Done():
184 p.logger.Debugln("read loop done")
185 return
186 default:
187 }
188 p.logger.Debugln("[read] reading from socket.")
189 _, payload, err := conn.ReadMessage()
190 if err != nil {
191 if isClosedConnError(err) {
192 p.logger.Debugln("[read] websocket closed:", err)
193 return
194 }
195 p.logger.Warnln("error reading websocket message:", err)
196 return
197 }
198 p.logger.Debugln("[read] read payload:", string(payload))
199 p.logger.Debugln("[read] writing to requestBody:")
200 n, err := requestBodyW.Write(payload)
201 requestBodyW.Write([]byte("\n"))
202 p.logger.Debugln("[read] wrote to requestBody", n)
203 if err != nil {
204 p.logger.Warnln("[read] error writing message to upstream http server:", err)
205 return
206 }
207 }
208 }()
209 // write loop -- take messages from response and write to websocket
210 scanner := bufio.NewScanner(responseBodyR)
211 for scanner.Scan() {
212 if len(scanner.Bytes()) == 0 {
213 p.logger.Warnln("[write] empty scan", scanner.Err())
214 continue
215 }
216 p.logger.Debugln("[write] scanned", scanner.Text())
217 if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
218 p.logger.Warnln("[write] error writing websocket message:", err)
219 return
220 }
221 }
222 if err := scanner.Err(); err != nil {
223 p.logger.Warnln("scanner err:", err)
224 }
225}
226
// inMemoryResponseWriter is the http.ResponseWriter handed to the wrapped
// handler by the websocket proxy. Body bytes go straight to the embedded
// io.Writer (Write is promoted from it), headers are kept in memory only, and
// the status code is recorded but not otherwise used.
type inMemoryResponseWriter struct {
	io.Writer // destination for the response body

	header http.Header // headers set by the handler; not forwarded by the proxy
	code   int         // last status code passed to WriteHeader
	closed chan bool   // capacity 1; handed out by CloseNotify
}

// Compile-time checks for the interfaces handlers commonly probe for.
var (
	_ http.ResponseWriter = (*inMemoryResponseWriter)(nil)
	_ http.Flusher        = (*inMemoryResponseWriter)(nil)
)

// newInMemoryResponseWriter returns a writer that streams body bytes to w,
// with an empty header map and a one-slot close-notification channel.
func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
	rw := &inMemoryResponseWriter{Writer: w}
	rw.header = make(http.Header)
	rw.closed = make(chan bool, 1)
	return rw
}

// Header returns the header map the handler may modify.
func (w *inMemoryResponseWriter) Header() http.Header { return w.header }

// WriteHeader records code; nothing is transmitted.
func (w *inMemoryResponseWriter) WriteHeader(code int) { w.code = code }

// CloseNotify returns the channel that receives a value when the bridge is torn down.
func (w *inMemoryResponseWriter) CloseNotify() <-chan bool { return w.closed }

// Flush is a no-op: this type does no buffering of its own, so every Write
// already reaches the underlying writer.
func (w *inMemoryResponseWriter) Flush() {}