| package wsproxy |
| |
| import ( |
| "bufio" |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
| "time" |
| |
| "github.com/gorilla/websocket" |
| "github.com/sirupsen/logrus" |
| "golang.org/x/net/context" |
| ) |
| |
// Package-level defaults that WebsocketProxy copies into every new Proxy.
var (
	// MethodOverrideParam is the name of the URL query parameter whose value,
	// when non-empty, replaces the HTTP method of the proxied streaming request.
	//
	// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
	MethodOverrideParam = "method"

	// TokenCookieName is the name of the cookie whose value is sent upstream as
	// an 'Authorization: Bearer' header on the proxied streaming request.
	//
	// Deprecated: it is preferable to use the Options parameters to WebSocketProxy to supply parameters.
	TokenCookieName = "token"
)
| |
// RequestMutatorFunc receives the original websocket upgrade request
// (incoming) and the request prepared for the wrapped handler (outgoing), and
// returns the request that is actually sent to the handler.
type RequestMutatorFunc func(incoming, outgoing *http.Request) *http.Request
| |
| // Proxy provides websocket transport upgrade to compatible endpoints. |
| type Proxy struct { |
| h http.Handler |
| logger Logger |
| maxRespBodyBufferBytes int |
| methodOverrideParam string |
| tokenCookieName string |
| requestMutator RequestMutatorFunc |
| headerForwarder func(header string) bool |
| pingInterval time.Duration |
| pingWait time.Duration |
| pongWait time.Duration |
| } |
| |
// Logger is the minimal logging surface the proxy needs: one sink for
// warnings and one for debug output.
type Logger interface {
	// Warnln records a warning built from args.
	Warnln(args ...interface{})
	// Debugln records a debug message built from args.
	Debugln(args ...interface{})
}
| |
| func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
| if !websocket.IsWebSocketUpgrade(r) { |
| p.h.ServeHTTP(w, r) |
| return |
| } |
| p.proxy(w, r) |
| } |
| |
| // Option allows customization of the proxy. |
| type Option func(*Proxy) |
| |
| // WithMaxRespBodyBufferSize allows specification of a custom size for the |
| // buffer used while reading the response body. By default, the bufio.Scanner |
| // used to read the response body sets the maximum token size to MaxScanTokenSize. |
| func WithMaxRespBodyBufferSize(nBytes int) Option { |
| return func(p *Proxy) { |
| p.maxRespBodyBufferBytes = nBytes |
| } |
| } |
| |
| // WithMethodParamOverride allows specification of the special http parameter that is used in the proxied streaming request. |
| func WithMethodParamOverride(param string) Option { |
| return func(p *Proxy) { |
| p.methodOverrideParam = param |
| } |
| } |
| |
| // WithTokenCookieName allows specification of the cookie that is supplied as an upstream 'Authorization: Bearer' http header. |
| func WithTokenCookieName(param string) Option { |
| return func(p *Proxy) { |
| p.tokenCookieName = param |
| } |
| } |
| |
| // WithRequestMutator allows a custom RequestMutatorFunc to be supplied. |
| func WithRequestMutator(fn RequestMutatorFunc) Option { |
| return func(p *Proxy) { |
| p.requestMutator = fn |
| } |
| } |
| |
| // WithForwardedHeaders allows controlling which headers are forwarded. |
| func WithForwardedHeaders(fn func(header string) bool) Option { |
| return func(p *Proxy) { |
| p.headerForwarder = fn |
| } |
| } |
| |
| // WithLogger allows a custom FieldLogger to be supplied |
| func WithLogger(logger Logger) Option { |
| return func(p *Proxy) { |
| p.logger = logger |
| } |
| } |
| |
| // WithPingControl allows specification of ping pong control. The interval |
| // parameter specifies the pingInterval between pings. The allowed wait time |
| // for a pong response is (pingInterval * 10) / 9. |
| func WithPingControl(interval time.Duration) Option { |
| return func(proxy *Proxy) { |
| proxy.pingInterval = interval |
| proxy.pongWait = (interval * 10) / 9 |
| proxy.pingWait = proxy.pongWait / 6 |
| } |
| } |
| |
// defaultHeadersToForward is the set of request headers copied to the proxied
// request when no custom forwarder is configured. The lower-case spellings are
// kept for backward compatibility with code that indexes the map directly.
var defaultHeadersToForward = map[string]bool{
	"Origin":  true,
	"origin":  true,
	"Referer": true,
	"referer": true,
}

// defaultHeaderForwarder reports whether header is forwarded by default.
//
// HTTP header names are case-insensitive, so the name is canonicalized before
// the lookup; any spelling of Origin or Referer (for example "ORIGIN") is
// accepted, not just the two spellings listed in the map.
func defaultHeaderForwarder(header string) bool {
	return defaultHeadersToForward[http.CanonicalHeaderKey(header)]
}
| |
| // WebsocketProxy attempts to expose the underlying handler as a bidi websocket stream with newline-delimited |
| // JSON as the content encoding. |
| // |
| // The HTTP Authorization header is either populated from the Sec-Websocket-Protocol field or by a cookie. |
| // The cookie name is specified by the TokenCookieName value. |
| // |
| // example: |
| // Sec-Websocket-Protocol: Bearer, foobar |
| // is converted to: |
| // Authorization: Bearer foobar |
| // |
| // Method can be overwritten with the MethodOverrideParam get parameter in the requested URL |
| func WebsocketProxy(h http.Handler, opts ...Option) http.Handler { |
| p := &Proxy{ |
| h: h, |
| logger: logrus.New(), |
| methodOverrideParam: MethodOverrideParam, |
| tokenCookieName: TokenCookieName, |
| headerForwarder: defaultHeaderForwarder, |
| } |
| for _, o := range opts { |
| o(p) |
| } |
| return p |
| } |
| |
| // TODO(tmc): allow modification of upgrader settings? |
| var upgrader = websocket.Upgrader{ |
| ReadBufferSize: 1024, |
| WriteBufferSize: 1024, |
| CheckOrigin: func(r *http.Request) bool { return true }, |
| } |
| |
| func isClosedConnError(err error) bool { |
| str := err.Error() |
| if strings.Contains(str, "use of closed network connection") { |
| return true |
| } |
| return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) |
| } |
| |
| func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) { |
| var responseHeader http.Header |
| // If Sec-WebSocket-Protocol starts with "Bearer", respond in kind. |
| // TODO(tmc): consider customizability/extension point here. |
| if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") { |
| responseHeader = http.Header{ |
| "Sec-WebSocket-Protocol": []string{"Bearer"}, |
| } |
| } |
| conn, err := upgrader.Upgrade(w, r, responseHeader) |
| if err != nil { |
| p.logger.Warnln("error upgrading websocket:", err) |
| return |
| } |
| defer conn.Close() |
| |
| ctx, cancelFn := context.WithCancel(context.Background()) |
| defer cancelFn() |
| |
| requestBodyR, requestBodyW := io.Pipe() |
| request, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), requestBodyR) |
| if err != nil { |
| p.logger.Warnln("error preparing request:", err) |
| return |
| } |
| if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" { |
| request.Header.Set("Authorization", transformSubProtocolHeader(swsp)) |
| } |
| for header := range r.Header { |
| if p.headerForwarder(header) { |
| request.Header.Set(header, r.Header.Get(header)) |
| } |
| } |
| // If token cookie is present, populate Authorization header from the cookie instead. |
| if cookie, err := r.Cookie(p.tokenCookieName); err == nil { |
| request.Header.Set("Authorization", "Bearer "+cookie.Value) |
| } |
| if m := r.URL.Query().Get(p.methodOverrideParam); m != "" { |
| request.Method = m |
| } |
| |
| if p.requestMutator != nil { |
| request = p.requestMutator(r, request) |
| } |
| |
| responseBodyR, responseBodyW := io.Pipe() |
| response := newInMemoryResponseWriter(responseBodyW) |
| go func() { |
| <-ctx.Done() |
| p.logger.Debugln("closing pipes") |
| requestBodyW.CloseWithError(io.EOF) |
| responseBodyW.CloseWithError(io.EOF) |
| response.closed <- true |
| }() |
| |
| go func() { |
| defer cancelFn() |
| p.h.ServeHTTP(response, request) |
| }() |
| |
| // read loop -- take messages from websocket and write to http request |
| go func() { |
| if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { |
| conn.SetReadDeadline(time.Now().Add(p.pongWait)) |
| conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil }) |
| } |
| defer func() { |
| cancelFn() |
| }() |
| for { |
| select { |
| case <-ctx.Done(): |
| p.logger.Debugln("read loop done") |
| return |
| default: |
| } |
| p.logger.Debugln("[read] reading from socket.") |
| _, payload, err := conn.ReadMessage() |
| if err != nil { |
| if isClosedConnError(err) { |
| p.logger.Debugln("[read] websocket closed:", err) |
| return |
| } |
| p.logger.Warnln("error reading websocket message:", err) |
| return |
| } |
| p.logger.Debugln("[read] read payload:", string(payload)) |
| p.logger.Debugln("[read] writing to requestBody:") |
| n, err := requestBodyW.Write(payload) |
| requestBodyW.Write([]byte("\n")) |
| p.logger.Debugln("[read] wrote to requestBody", n) |
| if err != nil { |
| p.logger.Warnln("[read] error writing message to upstream http server:", err) |
| return |
| } |
| } |
| }() |
| // ping write loop |
| if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 { |
| go func() { |
| ticker := time.NewTicker(p.pingInterval) |
| defer func() { |
| ticker.Stop() |
| conn.Close() |
| }() |
| for { |
| select { |
| case <-ctx.Done(): |
| p.logger.Debugln("ping loop done") |
| return |
| case <-ticker.C: |
| conn.SetWriteDeadline(time.Now().Add(p.pingWait)) |
| if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { |
| return |
| } |
| } |
| } |
| }() |
| } |
| // write loop -- take messages from response and write to websocket |
| scanner := bufio.NewScanner(responseBodyR) |
| |
| // if maxRespBodyBufferSize has been specified, use custom buffer for scanner |
| var scannerBuf []byte |
| if p.maxRespBodyBufferBytes > 0 { |
| scannerBuf = make([]byte, 0, 64*1024) |
| scanner.Buffer(scannerBuf, p.maxRespBodyBufferBytes) |
| } |
| |
| for scanner.Scan() { |
| if len(scanner.Bytes()) == 0 { |
| p.logger.Warnln("[write] empty scan", scanner.Err()) |
| continue |
| } |
| p.logger.Debugln("[write] scanned", scanner.Text()) |
| if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil { |
| p.logger.Warnln("[write] error writing websocket message:", err) |
| return |
| } |
| } |
| if err := scanner.Err(); err != nil { |
| p.logger.Warnln("scanner err:", err) |
| } |
| } |
| |
// inMemoryResponseWriter is a response writer whose body bytes go to the
// embedded io.Writer while the headers and status code are simply kept in
// memory.
type inMemoryResponseWriter struct {
	// Writer receives the response body.
	io.Writer
	// header is the header map returned by Header.
	header http.Header
	// code is the status code most recently passed to WriteHeader.
	code int
	// closed is the channel returned by CloseNotify.
	closed chan bool
}
| |
| func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter { |
| return &inMemoryResponseWriter{ |
| Writer: w, |
| header: http.Header{}, |
| closed: make(chan bool, 1), |
| } |
| } |
| |
// transformSubProtocolHeader converts a Sec-WebSocket-Protocol value of the
// form "Bearer, <token>" into an Authorization value "Bearer <token>".
// Everything after the first "Bearer," is taken as the token, with leading
// and trailing spaces removed. An empty string is returned when the header
// does not contain "Bearer,".
//
// IE and Edge do not delimit Sec-WebSocket-Protocol strings with spaces, so
// the space after the comma is optional.
func transformSubProtocolHeader(header string) string {
	const marker = "Bearer,"

	at := strings.Index(header, marker)
	if at < 0 {
		return ""
	}

	token := strings.Trim(header[at+len(marker):], " ")
	return fmt.Sprintf("Bearer %s", token)
}
| |
| func (w *inMemoryResponseWriter) Write(b []byte) (int, error) { |
| return w.Writer.Write(b) |
| } |
| func (w *inMemoryResponseWriter) Header() http.Header { |
| return w.header |
| } |
| func (w *inMemoryResponseWriter) WriteHeader(code int) { |
| w.code = code |
| } |
| func (w *inMemoryResponseWriter) CloseNotify() <-chan bool { |
| return w.closed |
| } |
| func (w *inMemoryResponseWriter) Flush() {} |