sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package websocket |
| 6 | |
| 7 | import ( |
| 8 | "bufio" |
| 9 | "errors" |
| 10 | "io" |
| 11 | "net/http" |
| 12 | "net/url" |
| 13 | "strings" |
| 14 | "time" |
| 15 | ) |
| 16 | |
// HandshakeError reports a problem found in the handshake sent by the peer.
type HandshakeError struct {
	message string
}

// Error implements the error interface by returning the stored message.
func (he HandshakeError) Error() string {
	return he.message
}
| 23 | |
| 24 | // Upgrader specifies parameters for upgrading an HTTP connection to a |
| 25 | // WebSocket connection. |
| 26 | type Upgrader struct { |
| 27 | // HandshakeTimeout specifies the duration for the handshake to complete. |
| 28 | HandshakeTimeout time.Duration |
| 29 | |
| 30 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer |
| 31 | // size is zero, then buffers allocated by the HTTP server are used. The |
| 32 | // I/O buffer sizes do not limit the size of the messages that can be sent |
| 33 | // or received. |
| 34 | ReadBufferSize, WriteBufferSize int |
| 35 | |
| 36 | // WriteBufferPool is a pool of buffers for write operations. If the value |
| 37 | // is not set, then write buffers are allocated to the connection for the |
| 38 | // lifetime of the connection. |
| 39 | // |
| 40 | // A pool is most useful when the application has a modest volume of writes |
| 41 | // across a large number of connections. |
| 42 | // |
| 43 | // Applications should use a single pool for each unique value of |
| 44 | // WriteBufferSize. |
| 45 | WriteBufferPool BufferPool |
| 46 | |
| 47 | // Subprotocols specifies the server's supported protocols in order of |
| 48 | // preference. If this field is not nil, then the Upgrade method negotiates a |
| 49 | // subprotocol by selecting the first match in this list with a protocol |
| 50 | // requested by the client. If there's no match, then no protocol is |
| 51 | // negotiated (the Sec-Websocket-Protocol header is not included in the |
| 52 | // handshake response). |
| 53 | Subprotocols []string |
| 54 | |
| 55 | // Error specifies the function for generating HTTP error responses. If Error |
| 56 | // is nil, then http.Error is used to generate the HTTP response. |
| 57 | Error func(w http.ResponseWriter, r *http.Request, status int, reason error) |
| 58 | |
| 59 | // CheckOrigin returns true if the request Origin header is acceptable. If |
| 60 | // CheckOrigin is nil, then a safe default is used: return false if the |
| 61 | // Origin request header is present and the origin host is not equal to |
| 62 | // request Host header. |
| 63 | // |
| 64 | // A CheckOrigin function should carefully validate the request origin to |
| 65 | // prevent cross-site request forgery. |
| 66 | CheckOrigin func(r *http.Request) bool |
| 67 | |
| 68 | // EnableCompression specify if the server should attempt to negotiate per |
| 69 | // message compression (RFC 7692). Setting this value to true does not |
| 70 | // guarantee that compression will be supported. Currently only "no context |
| 71 | // takeover" modes are supported. |
| 72 | EnableCompression bool |
| 73 | } |
| 74 | |
| 75 | func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { |
| 76 | err := HandshakeError{reason} |
| 77 | if u.Error != nil { |
| 78 | u.Error(w, r, status, err) |
| 79 | } else { |
| 80 | w.Header().Set("Sec-Websocket-Version", "13") |
| 81 | http.Error(w, http.StatusText(status), status) |
| 82 | } |
| 83 | return nil, err |
| 84 | } |
| 85 | |
| 86 | // checkSameOrigin returns true if the origin is not set or is equal to the request host. |
| 87 | func checkSameOrigin(r *http.Request) bool { |
| 88 | origin := r.Header["Origin"] |
| 89 | if len(origin) == 0 { |
| 90 | return true |
| 91 | } |
| 92 | u, err := url.Parse(origin[0]) |
| 93 | if err != nil { |
| 94 | return false |
| 95 | } |
| 96 | return equalASCIIFold(u.Host, r.Host) |
| 97 | } |
| 98 | |
| 99 | func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { |
| 100 | if u.Subprotocols != nil { |
| 101 | clientProtocols := Subprotocols(r) |
| 102 | for _, serverProtocol := range u.Subprotocols { |
| 103 | for _, clientProtocol := range clientProtocols { |
| 104 | if clientProtocol == serverProtocol { |
| 105 | return clientProtocol |
| 106 | } |
| 107 | } |
| 108 | } |
| 109 | } else if responseHeader != nil { |
| 110 | return responseHeader.Get("Sec-Websocket-Protocol") |
| 111 | } |
| 112 | return "" |
| 113 | } |
| 114 | |
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
//
// On success the returned *Conn owns the hijacked network connection. On
// failure the returned *Conn is nil and the error describes the problem.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	// Validate the request before touching the connection. Every failure in
	// this phase is reported through returnError.
	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}

	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}

	if r.Method != "GET" {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	// The extensions header is written by this method (see the compress
	// branch below), so an application-supplied value is rejected.
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}

	// Fall back to the same-origin check when the application did not
	// provide its own CheckOrigin.
	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if challengeKey == "" {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE
	// Compression is turned on only when it is enabled on the Upgrader and
	// the client offered the permessage-deflate extension.
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}

	// Take over the underlying connection from the HTTP server.
	h, ok := w.(http.Hijacker)
	if !ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
	}
	// brw is declared here; the := below assigns to it while declaring
	// netConn and err.
	var brw *bufio.ReadWriter
	netConn, brw, err := h.Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
	}

	// Bytes already buffered by the server's reader mean the client sent data
	// before the handshake finished. The connection is closed and a plain
	// error is returned without writing an HTTP error response.
	if brw.Reader.Buffered() > 0 {
		netConn.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}

	// br stays nil unless no read buffer size was requested and the hijacked
	// reader's buffer is larger than 256 bytes.
	var br *bufio.Reader
	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
		// Reuse hijacked buffered reader as connection reader.
		br = brw.Reader
	}

	buf := bufioWriterBuffer(netConn, brw.Writer)

	// writeBuf stays nil unless there is no pool, no write buffer size was
	// requested, and the hijacked buffer can hold a frame header plus 256
	// bytes.
	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}

	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	// Keep the backing array and build the response from length zero.
	p = p[:0]

	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	for k, vs := range responseHeader {
		// The negotiated subprotocol was already written above, so the
		// application's copy of the header is skipped.
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	// A blank line terminates the response header.
	p = append(p, "\r\n"...)

	// Clear deadlines set by HTTP server.
	netConn.SetDeadline(time.Time{})

	// The handshake timeout is applied to the response write only and is
	// cleared again once the write succeeds.
	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}
	if _, err = netConn.Write(p); err != nil {
		netConn.Close()
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Time{})
	}

	return c, nil
}
| 264 | |
| 265 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. |
| 266 | // |
| 267 | // Deprecated: Use websocket.Upgrader instead. |
| 268 | // |
| 269 | // Upgrade does not perform origin checking. The application is responsible for |
| 270 | // checking the Origin header before calling Upgrade. An example implementation |
| 271 | // of the same origin policy check is: |
| 272 | // |
| 273 | // if req.Header.Get("Origin") != "http://"+req.Host { |
| 274 | // http.Error(w, "Origin not allowed", http.StatusForbidden) |
| 275 | // return |
| 276 | // } |
| 277 | // |
| 278 | // If the endpoint supports subprotocols, then the application is responsible |
| 279 | // for negotiating the protocol used on the connection. Use the Subprotocols() |
| 280 | // function to get the subprotocols requested by the client. Use the |
| 281 | // Sec-Websocket-Protocol response header to specify the subprotocol selected |
| 282 | // by the application. |
| 283 | // |
| 284 | // The responseHeader is included in the response to the client's upgrade |
| 285 | // request. Use the responseHeader to specify cookies (Set-Cookie) and the |
| 286 | // negotiated subprotocol (Sec-Websocket-Protocol). |
| 287 | // |
| 288 | // The connection buffers IO to the underlying network connection. The |
| 289 | // readBufSize and writeBufSize parameters specify the size of the buffers to |
| 290 | // use. Messages can be larger than the buffers. |
| 291 | // |
| 292 | // If the request is not a valid WebSocket handshake, then Upgrade returns an |
| 293 | // error of type HandshakeError. Applications should handle this error by |
| 294 | // replying to the client with an HTTP error response. |
| 295 | func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { |
| 296 | u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} |
| 297 | u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { |
| 298 | // don't return errors to maintain backwards compatibility |
| 299 | } |
| 300 | u.CheckOrigin = func(r *http.Request) bool { |
| 301 | // allow all connections by default |
| 302 | return true |
| 303 | } |
| 304 | return u.Upgrade(w, r, responseHeader) |
| 305 | } |
| 306 | |
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
//
// The header may appear more than once in a request; every occurrence is
// read, and each value is treated as a comma-separated list. Surrounding
// whitespace is trimmed from each entry and empty entries are skipped. The
// result is nil when the client requested no subprotocols.
func Subprotocols(r *http.Request) []string {
	var protocols []string
	// Index the map directly so that all header lines are seen; Header.Get
	// would return only the first one.
	for _, value := range r.Header["Sec-Websocket-Protocol"] {
		for _, token := range strings.Split(value, ",") {
			if token = strings.TrimSpace(token); token != "" {
				protocols = append(protocols, token)
			}
		}
	}
	return protocols
}
| 320 | |
| 321 | // IsWebSocketUpgrade returns true if the client requested upgrade to the |
| 322 | // WebSocket protocol. |
| 323 | func IsWebSocketUpgrade(r *http.Request) bool { |
| 324 | return tokenListContainsValue(r.Header, "Connection", "upgrade") && |
| 325 | tokenListContainsValue(r.Header, "Upgrade", "websocket") |
| 326 | } |
| 327 | |
// bufioReaderSize returns the size of br's internal buffer, or 0 when the
// size cannot be determined. As a side effect br is reset to read from
// originalReader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
	// This relies on Peek(0) on a freshly reset reader returning
	// bufio.Reader.buf[:0], whose capacity is the buffer size.
	// TODO: Use bufio.Reader.Size() after Go 1.10
	br.Reset(originalReader)
	peeked, err := br.Peek(0)
	if err != nil {
		return 0
	}
	return cap(peeked)
}
| 339 | |
// writeHook is an io.Writer that records the last slice passed to it via
// io.Writer.Write.
type writeHook struct {
	p []byte
}

// Write remembers buf itself (no copy is made) and reports the whole slice
// as written.
func (hook *writeHook) Write(buf []byte) (int, error) {
	hook.p = buf
	return len(buf), nil
}
| 350 | |
| 351 | // bufioWriterBuffer grabs the buffer from a bufio.Writer. |
| 352 | func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { |
| 353 | // This code assumes that bufio.Writer.buf[:1] is passed to the |
| 354 | // bufio.Writer's underlying writer. |
| 355 | var wh writeHook |
| 356 | bw.Reset(&wh) |
| 357 | bw.WriteByte(0) |
| 358 | bw.Flush() |
| 359 | |
| 360 | bw.Reset(originalWriter) |
| 361 | |
| 362 | return wh.p[:cap(wh.p)] |
| 363 | } |