David K. Bainbridge | 215e024 | 2017-09-05 23:18:24 -0700 | [diff] [blame] | 1 | // Copyright 2009 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package websocket |
| 6 | |
| 7 | import ( |
| 8 | "bufio" |
| 9 | "fmt" |
| 10 | "io" |
| 11 | "net/http" |
| 12 | ) |
| 13 | |
| 14 | func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { |
| 15 | var hs serverHandshaker = &hybiServerHandshaker{Config: config} |
| 16 | code, err := hs.ReadHandshake(buf.Reader, req) |
| 17 | if err == ErrBadWebSocketVersion { |
| 18 | fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| 19 | fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) |
| 20 | buf.WriteString("\r\n") |
| 21 | buf.WriteString(err.Error()) |
| 22 | buf.Flush() |
| 23 | return |
| 24 | } |
| 25 | if err != nil { |
| 26 | fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| 27 | buf.WriteString("\r\n") |
| 28 | buf.WriteString(err.Error()) |
| 29 | buf.Flush() |
| 30 | return |
| 31 | } |
| 32 | if handshake != nil { |
| 33 | err = handshake(config, req) |
| 34 | if err != nil { |
| 35 | code = http.StatusForbidden |
| 36 | fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| 37 | buf.WriteString("\r\n") |
| 38 | buf.Flush() |
| 39 | return |
| 40 | } |
| 41 | } |
| 42 | err = hs.AcceptHandshake(buf.Writer) |
| 43 | if err != nil { |
| 44 | code = http.StatusBadRequest |
| 45 | fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) |
| 46 | buf.WriteString("\r\n") |
| 47 | buf.Flush() |
| 48 | return |
| 49 | } |
| 50 | conn = hs.NewServerConn(buf, rwc, req) |
| 51 | return |
| 52 | } |
| 53 | |
// Server represents a server of a WebSocket.
//
// A Server is used as an http.Handler. For each request it hijacks the
// connection, runs the WebSocket handshake and, if that succeeds, calls
// Handler. The hijacked connection is closed once Handler returns.
type Server struct {
	// Config is the WebSocket configuration used for each new WebSocket
	// connection.
	Config

	// Handshake is an optional function called during the WebSocket
	// handshake, after the client's opening request has been read
	// successfully and before the handshake response is sent. Returning a
	// non-nil error rejects the connection with 403 Forbidden.
	// For example, it can check (or deliberately not check) the Origin
	// header, or select config.Protocol.
	Handshake func(*Config, *http.Request) error

	// Handler handles a WebSocket connection once the handshake succeeds.
	Handler
}
| 67 | |
// ServeHTTP implements the http.Handler interface for a WebSocket server.
// It hijacks the underlying connection, performs the WebSocket handshake
// and, if it succeeds, runs s.Handler on the resulting connection.
//
// Because the receiver is a value, s is a copy of the Server made for this
// call. The &s.Config handed to the handshake therefore points into that
// copy, and plain field assignments made during the handshake do not affect
// other requests. The copy is shallow: data reached through pointers, slices
// or maps inside Config is still shared. Keep this in mind before changing
// to a pointer receiver.
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	s.serveWebSocket(w, req)
}
| 72 | |
| 73 | func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { |
| 74 | rwc, buf, err := w.(http.Hijacker).Hijack() |
| 75 | if err != nil { |
| 76 | panic("Hijack failed: " + err.Error()) |
| 77 | } |
| 78 | // The server should abort the WebSocket connection if it finds |
| 79 | // the client did not send a handshake that matches with protocol |
| 80 | // specification. |
| 81 | defer rwc.Close() |
| 82 | conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) |
| 83 | if err != nil { |
| 84 | return |
| 85 | } |
| 86 | if conn == nil { |
| 87 | panic("unexpected nil conn") |
| 88 | } |
| 89 | s.Handler(conn) |
| 90 | } |
| 91 | |
// Handler is a simple interface to a WebSocket browser client.
// When used directly as an http.Handler it checks the origin by default:
// its ServeHTTP rejects the handshake (with 403 Forbidden) whenever
// websocket.Origin returns an error or a nil origin for the request.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
| 100 | |
| 101 | func checkOrigin(config *Config, req *http.Request) (err error) { |
| 102 | config.Origin, err = Origin(config, req) |
| 103 | if err == nil && config.Origin == nil { |
| 104 | return fmt.Errorf("null origin") |
| 105 | } |
| 106 | return err |
| 107 | } |
| 108 | |
| 109 | // ServeHTTP implements the http.Handler interface for a WebSocket |
| 110 | func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| 111 | s := Server{Handler: h, Handshake: checkOrigin} |
| 112 | s.serveWebSocket(w, req) |
| 113 | } |