blob: 43a87c753bfce1f378ba1bd7462890b18cf30b01 [file] [log] [blame]
khenaidoo59ce9dd2019-11-11 13:05:32 -05001// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8 "bufio"
9 "bytes"
10 "crypto/tls"
11 "encoding/base64"
12 "errors"
13 "io"
14 "io/ioutil"
15 "net"
16 "net/http"
17 "net/url"
18 "strings"
19 "time"
20)
21
var (
	// ErrBadHandshake is returned when the server's response to the opening
	// handshake is invalid.
	ErrBadHandshake = errors.New("websocket: bad handshake")

	// errInvalidCompression is returned when the server accepts
	// permessage-deflate without both of the no_context_takeover parameters
	// this client requires.
	errInvalidCompression = errors.New("websocket: invalid compression negotiation")
)
27
28// NewClient creates a new client connection using the given net connection.
29// The URL u specifies the host and request URI. Use requestHeader to specify
30// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
31// (Cookie). Use the response.Header to get the selected subprotocol
32// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
33//
34// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
35// non-nil *http.Response so that callers can handle redirects, authentication,
36// etc.
37//
38// Deprecated: Use Dialer instead.
39func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
40 d := Dialer{
41 ReadBufferSize: readBufSize,
42 WriteBufferSize: writeBufSize,
43 NetDial: func(net, addr string) (net.Conn, error) {
44 return netConn, nil
45 },
46 }
47 return d.Dial(u.String(), requestHeader)
48}
49
// A Dialer contains options for connecting to a WebSocket server.
type Dialer struct {
	// NetDial, when non-nil, is used to create the TCP connection. When it
	// is nil, net.Dial is used.
	NetDial func(network, addr string) (net.Conn, error)

	// Proxy, when non-nil, selects a proxy for a given Request. If it
	// returns a non-nil error, the request is aborted with that error. A nil
	// Proxy, or a nil *URL result, means no proxy is used.
	Proxy func(*http.Request) (*url.URL, error)

	// TLSClientConfig is the TLS configuration handed to tls.Client. When it
	// is nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// HandshakeTimeout bounds how long the opening handshake may take.
	// Zero means the handshake has no time limit.
	HandshakeTimeout time.Duration

	// ReadBufferSize is the size of the read I/O buffer. Zero selects a
	// useful default. It does not limit the size of received messages.
	ReadBufferSize int

	// WriteBufferSize is the size of the write I/O buffer. Zero selects a
	// useful default. It does not limit the size of sent messages.
	WriteBufferSize int

	// Subprotocols lists the subprotocols the client requests.
	Subprotocols []string

	// EnableCompression asks the client to negotiate per-message compression
	// (RFC 7692). Setting it does not guarantee the server will agree.
	// Currently only the "no context takeover" modes are supported.
	EnableCompression bool

	// Jar is the cookie jar. When it is nil, no cookies are sent with
	// requests and cookies in responses are ignored.
	Jar http.CookieJar
}
88
// errMalformedURL is returned when a URL is not a well-formed ws or wss URI.
var errMalformedURL = errors.New("malformed ws or wss URL")

// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
//
// The path is stored verbatim in Opaque, defaulting to "/" when absent, so it
// is sent unchanged as the request URI. The query (without "?") goes into
// RawQuery. The scheme is matched case-insensitively and is always returned
// in lower case ("ws" or "wss").
//
// errMalformedURL is returned in these cases:
//   - the scheme is not ws/wss;
//   - the authority contains user information;
//   - the host is empty.
func parseURL(s string) (*url.URL, error) {
	// From the RFC:
	//
	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
	var u url.URL
	for _, scheme := range [...]string{"ws", "wss"} {
		prefix := scheme + "://"
		// RFC 3986 section 3.1: schemes are case-insensitive.
		if len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix) {
			u.Scheme = scheme
			s = s[len(prefix):]
			break
		}
	}
	if u.Scheme == "" {
		return nil, errMalformedURL
	}

	if i := strings.Index(s, "?"); i >= 0 {
		u.RawQuery = s[i+1:]
		s = s[:i]
	}

	u.Opaque = "/"
	if i := strings.Index(s, "/"); i >= 0 {
		u.Opaque = s[i:]
		s = s[:i]
	}

	u.Host = s

	if strings.Contains(u.Host, "@") {
		// Don't bother parsing user information because user information is
		// not allowed in websocket URIs.
		return nil, errMalformedURL
	}

	if u.Host == "" {
		// A WebSocket URI requires a host. Without one, the dialer would end
		// up dialing ":80" (or ":443"), that is, the local system.
		return nil, errMalformedURL
	}

	return &u, nil
}
134
// hostPortNoPort returns u.Host in two forms. hostPort always carries a port:
// when u.Host has none, ":443" is appended for the wss/https schemes and
// ":80" for any other scheme. hostNoPort has any explicit port stripped.
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
	host := u.Host

	// A port separator must come after the closing bracket of an IPv6
	// literal; colons inside the brackets are part of the address.
	colon := strings.LastIndex(host, ":")
	if colon > strings.LastIndex(host, "]") {
		return host, host[:colon]
	}

	defaultPort := ":80"
	if u.Scheme == "wss" || u.Scheme == "https" {
		defaultPort = ":443"
	}
	return host + defaultPort, host
}
152
153// DefaultDialer is a dialer with all fields set to the default zero values.
154var DefaultDialer = &Dialer{
155 Proxy: http.ProxyFromEnvironment,
156}
157
158// Dial creates a new client connection. Use requestHeader to specify the
159// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
160// Use the response.Header to get the selected subprotocol
161// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
162//
163// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
164// non-nil *http.Response so that callers can handle redirects, authentication,
165// etcetera. The response body may not contain the entire response and does not
166// need to be closed by the application.
167func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
168
169 if d == nil {
170 d = &Dialer{
171 Proxy: http.ProxyFromEnvironment,
172 }
173 }
174
175 challengeKey, err := generateChallengeKey()
176 if err != nil {
177 return nil, nil, err
178 }
179
180 u, err := parseURL(urlStr)
181 if err != nil {
182 return nil, nil, err
183 }
184
185 switch u.Scheme {
186 case "ws":
187 u.Scheme = "http"
188 case "wss":
189 u.Scheme = "https"
190 default:
191 return nil, nil, errMalformedURL
192 }
193
194 if u.User != nil {
195 // User name and password are not allowed in websocket URIs.
196 return nil, nil, errMalformedURL
197 }
198
199 req := &http.Request{
200 Method: "GET",
201 URL: u,
202 Proto: "HTTP/1.1",
203 ProtoMajor: 1,
204 ProtoMinor: 1,
205 Header: make(http.Header),
206 Host: u.Host,
207 }
208
209 // Set the cookies present in the cookie jar of the dialer
210 if d.Jar != nil {
211 for _, cookie := range d.Jar.Cookies(u) {
212 req.AddCookie(cookie)
213 }
214 }
215
216 // Set the request headers using the capitalization for names and values in
217 // RFC examples. Although the capitalization shouldn't matter, there are
218 // servers that depend on it. The Header.Set method is not used because the
219 // method canonicalizes the header names.
220 req.Header["Upgrade"] = []string{"websocket"}
221 req.Header["Connection"] = []string{"Upgrade"}
222 req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
223 req.Header["Sec-WebSocket-Version"] = []string{"13"}
224 if len(d.Subprotocols) > 0 {
225 req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
226 }
227 for k, vs := range requestHeader {
228 switch {
229 case k == "Host":
230 if len(vs) > 0 {
231 req.Host = vs[0]
232 }
233 case k == "Upgrade" ||
234 k == "Connection" ||
235 k == "Sec-Websocket-Key" ||
236 k == "Sec-Websocket-Version" ||
237 k == "Sec-Websocket-Extensions" ||
238 (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
239 return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
240 default:
241 req.Header[k] = vs
242 }
243 }
244
245 if d.EnableCompression {
246 req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
247 }
248
249 hostPort, hostNoPort := hostPortNoPort(u)
250
251 var proxyURL *url.URL
252 // Check wether the proxy method has been configured
253 if d.Proxy != nil {
254 proxyURL, err = d.Proxy(req)
255 }
256 if err != nil {
257 return nil, nil, err
258 }
259
260 var targetHostPort string
261 if proxyURL != nil {
262 targetHostPort, _ = hostPortNoPort(proxyURL)
263 } else {
264 targetHostPort = hostPort
265 }
266
267 var deadline time.Time
268 if d.HandshakeTimeout != 0 {
269 deadline = time.Now().Add(d.HandshakeTimeout)
270 }
271
272 netDial := d.NetDial
273 if netDial == nil {
274 netDialer := &net.Dialer{Deadline: deadline}
275 netDial = netDialer.Dial
276 }
277
278 netConn, err := netDial("tcp", targetHostPort)
279 if err != nil {
280 return nil, nil, err
281 }
282
283 defer func() {
284 if netConn != nil {
285 netConn.Close()
286 }
287 }()
288
289 if err := netConn.SetDeadline(deadline); err != nil {
290 return nil, nil, err
291 }
292
293 if proxyURL != nil {
294 connectHeader := make(http.Header)
295 if user := proxyURL.User; user != nil {
296 proxyUser := user.Username()
297 if proxyPassword, passwordSet := user.Password(); passwordSet {
298 credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
299 connectHeader.Set("Proxy-Authorization", "Basic "+credential)
300 }
301 }
302 connectReq := &http.Request{
303 Method: "CONNECT",
304 URL: &url.URL{Opaque: hostPort},
305 Host: hostPort,
306 Header: connectHeader,
307 }
308
309 connectReq.Write(netConn)
310
311 // Read response.
312 // Okay to use and discard buffered reader here, because
313 // TLS server will not speak until spoken to.
314 br := bufio.NewReader(netConn)
315 resp, err := http.ReadResponse(br, connectReq)
316 if err != nil {
317 return nil, nil, err
318 }
319 if resp.StatusCode != 200 {
320 f := strings.SplitN(resp.Status, " ", 2)
321 return nil, nil, errors.New(f[1])
322 }
323 }
324
325 if u.Scheme == "https" {
326 cfg := cloneTLSConfig(d.TLSClientConfig)
327 if cfg.ServerName == "" {
328 cfg.ServerName = hostNoPort
329 }
330 tlsConn := tls.Client(netConn, cfg)
331 netConn = tlsConn
332 if err := tlsConn.Handshake(); err != nil {
333 return nil, nil, err
334 }
335 if !cfg.InsecureSkipVerify {
336 if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
337 return nil, nil, err
338 }
339 }
340 }
341
342 conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
343
344 if err := req.Write(netConn); err != nil {
345 return nil, nil, err
346 }
347
348 resp, err := http.ReadResponse(conn.br, req)
349 if err != nil {
350 return nil, nil, err
351 }
352
353 if d.Jar != nil {
354 if rc := resp.Cookies(); len(rc) > 0 {
355 d.Jar.SetCookies(u, rc)
356 }
357 }
358
359 if resp.StatusCode != 101 ||
360 !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
361 !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
362 resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
363 // Before closing the network connection on return from this
364 // function, slurp up some of the response to aid application
365 // debugging.
366 buf := make([]byte, 1024)
367 n, _ := io.ReadFull(resp.Body, buf)
368 resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
369 return nil, resp, ErrBadHandshake
370 }
371
372 for _, ext := range parseExtensions(resp.Header) {
373 if ext[""] != "permessage-deflate" {
374 continue
375 }
376 _, snct := ext["server_no_context_takeover"]
377 _, cnct := ext["client_no_context_takeover"]
378 if !snct || !cnct {
379 return nil, resp, errInvalidCompression
380 }
381 conn.newCompressionWriter = compressNoContextTakeover
382 conn.newDecompressionReader = decompressNoContextTakeover
383 break
384 }
385
386 resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
387 conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
388
389 netConn.SetDeadline(time.Time{})
390 netConn = nil // to avoid close in defer.
391 return conn, resp, nil
392}