blob: 354001e1ede76a9f1e28a4e77b5add750a5ef959 [file] [log] [blame]
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8 "crypto/rand"
9 "crypto/sha1"
10 "encoding/base64"
11 "io"
12 "net/http"
13 "strings"
14 "unicode/utf8"
15)
16
// keyGUID is the fixed GUID that RFC 6455 appends to the client's
// Sec-WebSocket-Key before hashing.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// computeAcceptKey returns the Sec-WebSocket-Accept value for challengeKey:
// the standard base64 encoding of SHA-1(challengeKey || keyGUID).
func computeAcceptKey(challengeKey string) string {
	input := make([]byte, 0, len(challengeKey)+len(keyGUID))
	input = append(input, challengeKey...)
	input = append(input, keyGUID...)
	digest := sha1.Sum(input)
	return base64.StdEncoding.EncodeToString(digest[:])
}
25
// generateChallengeKey returns a new Sec-WebSocket-Key value: 16 bytes read
// from crypto/rand.Reader, encoded with standard base64. Any error from the
// random source is returned unchanged along with an empty key.
func generateChallengeKey() (string, error) {
	var nonce [16]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
33
// Octet types from RFC 2616. octetTypes maps every byte value to a bit set of
// the isXxxOctet flags below; init fills it in once.
var octetTypes [256]byte

const (
	isTokenOctet = 1 << iota // byte may appear in an RFC 2616 token
	isSpaceOctet             // byte is SP, HT, CR or LF
)

func init() {
	// From RFC 2616
	//
	// OCTET      = <any 8-bit sequence of data>
	// CHAR       = <any US-ASCII character (octets 0 - 127)>
	// CTL        = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
	// CR         = <US-ASCII CR, carriage return (13)>
	// LF         = <US-ASCII LF, linefeed (10)>
	// SP         = <US-ASCII SP, space (32)>
	// HT         = <US-ASCII HT, horizontal-tab (9)>
	// <">        = <US-ASCII double-quote mark (34)>
	// CRLF       = CR LF
	// LWS        = [CRLF] 1*( SP | HT )
	// TEXT       = <any OCTET except CTLs, but including LWS>
	// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
	//              | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
	// token      = 1*<any CHAR except CTLs or separators>
	// qdtext     = <any TEXT except <">>
	const (
		separators = " \t\"(),/:;<=>?@[]\\{}"
		// CR and LF are treated as whitespace alongside SP and HT so that
		// skipping space also skips line breaks.
		spaces = " \t\r\n"
	)

	for i := range octetTypes {
		r := rune(i)
		if strings.ContainsRune(spaces, r) {
			octetTypes[i] |= isSpaceOctet
		}
		ascii := i <= 127
		control := i <= 31 || i == 127
		if ascii && !control && !strings.ContainsRune(separators, r) {
			octetTypes[i] |= isTokenOctet
		}
	}
}
75
76func skipSpace(s string) (rest string) {
77 i := 0
78 for ; i < len(s); i++ {
79 if octetTypes[s[i]]&isSpaceOctet == 0 {
80 break
81 }
82 }
83 return s[i:]
84}
85
86func nextToken(s string) (token, rest string) {
87 i := 0
88 for ; i < len(s); i++ {
89 if octetTypes[s[i]]&isTokenOctet == 0 {
90 break
91 }
92 }
93 return s[:i], s[i:]
94}
95
96func nextTokenOrQuoted(s string) (value string, rest string) {
97 if !strings.HasPrefix(s, "\"") {
98 return nextToken(s)
99 }
100 s = s[1:]
101 for i := 0; i < len(s); i++ {
102 switch s[i] {
103 case '"':
104 return s[:i], s[i+1:]
105 case '\\':
106 p := make([]byte, len(s)-1)
107 j := copy(p, s[:i])
108 escape := true
109 for i = i + 1; i < len(s); i++ {
110 b := s[i]
111 switch {
112 case escape:
113 escape = false
114 p[j] = b
115 j++
116 case b == '\\':
117 escape = true
118 case b == '"':
119 return string(p[:j]), s[i+1:]
120 default:
121 p[j] = b
122 j++
123 }
124 }
125 return "", ""
126 }
127 }
128 return "", ""
129}
130
// equalASCIIFold returns true if s is equal to t with ASCII case folding:
// the letters A-Z match their lowercase counterparts, and every other rune
// must match exactly (no Unicode case folding is applied).
//
// Malformed UTF-8 is compared byte for byte. utf8.DecodeRuneInString reports
// every invalid byte as utf8.RuneError, so comparing decoded runes alone would
// make any two invalid bytes (or an invalid byte and a literal U+FFFD) compare
// equal.
func equalASCIIFold(s, t string) bool {
	for s != "" && t != "" {
		sr, ssize := utf8.DecodeRuneInString(s)
		tr, tsize := utf8.DecodeRuneInString(t)
		if s[:ssize] != t[:tsize] {
			// The encodings differ. Only two valid runes that are the same
			// ASCII letter in different case can still match.
			if sr == utf8.RuneError || tr == utf8.RuneError {
				return false
			}
			if 'A' <= sr && sr <= 'Z' {
				sr += 'a' - 'A'
			}
			if 'A' <= tr && tr <= 'Z' {
				tr += 'a' - 'A'
			}
			if sr != tr {
				return false
			}
		}
		s, t = s[ssize:], t[tsize:]
	}
	// Equal only if both inputs were consumed together.
	return s == t
}
153
154// tokenListContainsValue returns true if the 1#token header with the given
155// name contains a token equal to value with ASCII case folding.
156func tokenListContainsValue(header http.Header, name string, value string) bool {
157headers:
158 for _, s := range header[name] {
159 for {
160 var t string
161 t, s = nextToken(skipSpace(s))
162 if t == "" {
163 continue headers
164 }
165 s = skipSpace(s)
166 if s != "" && s[0] != ',' {
167 continue headers
168 }
169 if equalASCIIFold(t, value) {
170 return true
171 }
172 if s == "" {
173 continue headers
174 }
175 s = s[1:]
176 }
177 }
178 return false
179}
180
181// parseExtensions parses WebSocket extensions from a header.
182func parseExtensions(header http.Header) []map[string]string {
183 // From RFC 6455:
184 //
185 // Sec-WebSocket-Extensions = extension-list
186 // extension-list = 1#extension
187 // extension = extension-token *( ";" extension-param )
188 // extension-token = registered-token
189 // registered-token = token
190 // extension-param = token [ "=" (token | quoted-string) ]
191 // ;When using the quoted-string syntax variant, the value
192 // ;after quoted-string unescaping MUST conform to the
193 // ;'token' ABNF.
194
195 var result []map[string]string
196headers:
197 for _, s := range header["Sec-Websocket-Extensions"] {
198 for {
199 var t string
200 t, s = nextToken(skipSpace(s))
201 if t == "" {
202 continue headers
203 }
204 ext := map[string]string{"": t}
205 for {
206 s = skipSpace(s)
207 if !strings.HasPrefix(s, ";") {
208 break
209 }
210 var k string
211 k, s = nextToken(skipSpace(s[1:]))
212 if k == "" {
213 continue headers
214 }
215 s = skipSpace(s)
216 var v string
217 if strings.HasPrefix(s, "=") {
218 v, s = nextTokenOrQuoted(skipSpace(s[1:]))
219 s = skipSpace(s)
220 }
221 if s != "" && s[0] != ',' && s[0] != ';' {
222 continue headers
223 }
224 ext[k] = v
225 }
226 if s != "" && s[0] != ',' {
227 continue headers
228 }
229 result = append(result, ext)
230 if s == "" {
231 continue headers
232 }
233 s = s[1:]
234 }
235 }
236 return result
237}