blob: 722f6043d373a5fb86fb020d0cb8853bf52330ad [file] [log] [blame]
// Copyright 2018 by David A. Golden. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package scram
8
9import (
10 "encoding/base64"
11 "errors"
12 "fmt"
13 "strconv"
14 "strings"
15)
16
// c1Msg holds the pieces of a SCRAM client-first-message as extracted by
// parseClientFirst.
type c1Msg struct {
	gs2Header string // GS2 header (cbind flag and authzid part) as recorded by parseClientFirst
	authzID   string // value of the optional "a=" attribute; empty when absent
	username  string // value of the "n=" attribute
	nonce     string // client nonce from the "r=" attribute
	c1b       string // client-first-message-bare: every field after the GS2 header, comma-joined
}
24
// c2Msg holds the pieces of a SCRAM client-final-message as extracted by
// parseClientFinal.
type c2Msg struct {
	cbind []byte // base64-decoded "c=" attribute (channel-binding data)
	nonce string // nonce from the "r=" attribute
	proof []byte // base64-decoded "p=" attribute, always taken from the last field
	c2wop string // the message up to (not including) its last comma, i.e. without the proof
}
31
// s1Msg holds the pieces of a SCRAM server-first-message as extracted by
// parseServerFirst.
type s1Msg struct {
	nonce string // nonce from the "r=" attribute
	salt  []byte // base64-decoded "s=" attribute
	iters int    // iteration count from the "i=" attribute
}
37
// s2Msg holds the pieces of a SCRAM server-final-message as extracted by
// parseServerFinal. Which field is filled in depends on whether the server
// answered with a "v=" or an "e=" attribute.
type s2Msg struct {
	verifier []byte // base64-decoded "v=" attribute (per RFC 5802, the server signature)
	err      string // value of the "e=" attribute
}
42
// parseField returns the value of a SCRAM attribute written as "k=value".
// The value may be empty. An error is returned when s does not start with
// the attribute name k immediately followed by '='.
func parseField(s, k string) (string, error) {
	prefix := k + "="
	if !strings.HasPrefix(s, prefix) {
		return "", fmt.Errorf("error parsing '%s' for field '%s'", s, k)
	}
	return s[len(prefix):], nil
}
50
// parseGS2Flag validates the gs2-cbind-flag that opens a client-first-message.
// It returns "n" or "y" unchanged. Channel binding ("p=...") is rejected as
// unsupported, and any other value, including an empty string, is a parse
// error.
func parseGS2Flag(s string) (string, error) {
	// HasPrefix instead of s[0]: an empty flag (e.g. a message starting with
	// ",") must be reported as an error, not cause an index-out-of-range panic.
	if strings.HasPrefix(s, "p") {
		return "", errors.New("channel binding requested but not supported")
	}

	if s == "n" || s == "y" {
		return s, nil
	}

	return "", fmt.Errorf("error parsing '%s' for gs2 flag", s)
}
62
63func parseFieldBase64(s, k string) ([]byte, error) {
64 raw, err := parseField(s, k)
65 if err != nil {
66 return nil, err
67 }
68
69 dec, err := base64.StdEncoding.DecodeString(raw)
70 if err != nil {
71 return nil, err
72 }
73
74 return dec, nil
75}
76
77func parseFieldInt(s, k string) (int, error) {
78 raw, err := parseField(s, k)
79 if err != nil {
80 return 0, err
81 }
82
83 num, err := strconv.Atoi(raw)
84 if err != nil {
85 return 0, fmt.Errorf("error parsing field '%s': %v", k, err)
86 }
87
88 return num, nil
89}
90
91func parseClientFirst(c1 string) (msg c1Msg, err error) {
92
93 fields := strings.Split(c1, ",")
94 if len(fields) < 4 {
95 err = errors.New("not enough fields in first server message")
96 return
97 }
98
99 gs2flag, err := parseGS2Flag(fields[0])
100 if err != nil {
101 return
102 }
103
104 // 'a' field is optional
105 if len(fields[1]) > 0 {
106 msg.authzID, err = parseField(fields[1], "a")
107 if err != nil {
108 return
109 }
110 }
111
112 // Recombine and save the gs2 header
113 msg.gs2Header = gs2flag + "," + msg.authzID + ","
114
115 // Check for unsupported extensions field "m".
116 if strings.HasPrefix(fields[2], "m=") {
117 err = errors.New("SCRAM message extensions are not supported")
118 return
119 }
120
121 msg.username, err = parseField(fields[2], "n")
122 if err != nil {
123 return
124 }
125
126 msg.nonce, err = parseField(fields[3], "r")
127 if err != nil {
128 return
129 }
130
131 msg.c1b = strings.Join(fields[2:], ",")
132
133 return
134}
135
136func parseClientFinal(c2 string) (msg c2Msg, err error) {
137 fields := strings.Split(c2, ",")
138 if len(fields) < 3 {
139 err = errors.New("not enough fields in first server message")
140 return
141 }
142
143 msg.cbind, err = parseFieldBase64(fields[0], "c")
144 if err != nil {
145 return
146 }
147
148 msg.nonce, err = parseField(fields[1], "r")
149 if err != nil {
150 return
151 }
152
153 // Extension fields may come between nonce and proof, so we
154 // grab the *last* fields as proof.
155 msg.proof, err = parseFieldBase64(fields[len(fields)-1], "p")
156 if err != nil {
157 return
158 }
159
160 msg.c2wop = c2[:strings.LastIndex(c2, ",")]
161
162 return
163}
164
165func parseServerFirst(s1 string) (msg s1Msg, err error) {
166
167 // Check for unsupported extensions field "m".
168 if strings.HasPrefix(s1, "m=") {
169 err = errors.New("SCRAM message extensions are not supported")
170 return
171 }
172
173 fields := strings.Split(s1, ",")
174 if len(fields) < 3 {
175 err = errors.New("not enough fields in first server message")
176 return
177 }
178
179 msg.nonce, err = parseField(fields[0], "r")
180 if err != nil {
181 return
182 }
183
184 msg.salt, err = parseFieldBase64(fields[1], "s")
185 if err != nil {
186 return
187 }
188
189 msg.iters, err = parseFieldInt(fields[2], "i")
190
191 return
192}
193
194func parseServerFinal(s2 string) (msg s2Msg, err error) {
195 fields := strings.Split(s2, ",")
196
197 msg.verifier, err = parseFieldBase64(fields[0], "v")
198 if err == nil {
199 return
200 }
201
202 msg.err, err = parseField(fields[0], "e")
203
204 return
205}