sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 1 | // Copyright 2014 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package internal |
| 6 | |
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"mime"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/context/ctxhttp"
)
| 24 | |
// Token holds the credentials an OAuth 2.0 provider issues so that
// requests may access protected resources on the provider's backend.
//
// Token mirrors oauth2.Token; it exists only to break an otherwise-circular
// dependency. Other internal packages should convert a Token into an
// oauth2.Token before using it.
type Token struct {
	// AccessToken is the credential that authorizes and authenticates
	// requests.
	AccessToken string

	// TokenType names the kind of token. The Type method returns either
	// this value or the default, "Bearer".
	TokenType string

	// RefreshToken is used by the application (not the user) to obtain a
	// new access token once the current one expires.
	RefreshToken string

	// Expiry is the optional point in time at which the access token
	// expires.
	//
	// A zero Expiry means TokenSource implementations reuse the same token
	// forever and never invoke RefreshToken or an equivalent refresh
	// mechanism for that TokenSource.
	Expiry time.Time

	// Raw optionally carries additional metadata returned by the server
	// when the token was updated.
	Raw interface{}
}
| 57 | |
| 58 | // tokenJSON is the struct representing the HTTP response from OAuth2 |
| 59 | // providers returning a token in JSON form. |
| 60 | type tokenJSON struct { |
| 61 | AccessToken string `json:"access_token"` |
| 62 | TokenType string `json:"token_type"` |
| 63 | RefreshToken string `json:"refresh_token"` |
| 64 | ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number |
| 65 | Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in |
| 66 | } |
| 67 | |
| 68 | func (e *tokenJSON) expiry() (t time.Time) { |
| 69 | if v := e.ExpiresIn; v != 0 { |
| 70 | return time.Now().Add(time.Duration(v) * time.Second) |
| 71 | } |
| 72 | if v := e.Expires; v != 0 { |
| 73 | return time.Now().Add(time.Duration(v) * time.Second) |
| 74 | } |
| 75 | return |
| 76 | } |
| 77 | |
// expirationTime is a token lifetime in seconds, as carried by the
// expires_in / expires fields of a JSON token response. It decodes from
// either a JSON number or a numeric JSON string, because some providers
// send the value as a string.
type expirationTime int32

// UnmarshalJSON decodes a JSON number (or numeric string) into e.
//
// Empty input and JSON null leave e unchanged and are not errors.
// Non-integer values (e.g. 1.5) are reported as errors.
//
// Values outside the int32 range are clamped to math.MinInt32 /
// math.MaxInt32. A plain int32 conversion would silently truncate them;
// for example, 99999999999 would become 1215752191, and other large
// lifetimes can even wrap to a negative, already-expired value.
func (e *expirationTime) UnmarshalJSON(b []byte) error {
	if len(b) == 0 || string(b) == "null" {
		return nil
	}
	var n json.Number
	if err := json.Unmarshal(b, &n); err != nil {
		return err
	}
	i, err := n.Int64()
	if err != nil {
		return err
	}
	switch {
	case i > math.MaxInt32:
		i = math.MaxInt32
	case i < math.MinInt32:
		i = math.MinInt32
	}
	*e = expirationTime(i)
	return nil
}
| 96 | |
// RegisterBrokenAuthHeaderProvider is kept only for API compatibility.
// It ignores tokenURL and has no effect.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
	// Intentionally empty: RetrieveToken now probes the auth style per
	// token URL and caches the result, so no registration is needed.
}
| 104 | |
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
// It selects how a token request presents the client ID and secret.
//
// NOTE(review): the numeric values are presumably expected to match the
// public oauth2.AuthStyle constants; confirm before renumbering.
type AuthStyle int

const (
	// AuthStyleUnknown means the caller did not choose a style.
	// RetrieveToken then probes (header first, then params) and caches
	// whichever style succeeds.
	AuthStyleUnknown AuthStyle = 0

	// AuthStyleInParams sends client_id and client_secret as
	// form-encoded POST body parameters.
	AuthStyleInParams AuthStyle = 1

	// AuthStyleInHeader sends the credentials with HTTP Basic
	// authentication.
	AuthStyleInHeader AuthStyle = 2
)
| 113 | |
| 114 | // authStyleCache is the set of tokenURLs we've successfully used via |
| 115 | // RetrieveToken and which style auth we ended up using. |
| 116 | // It's called a cache, but it doesn't (yet?) shrink. It's expected that |
| 117 | // the set of OAuth2 servers a program contacts over time is fixed and |
| 118 | // small. |
| 119 | var authStyleCache struct { |
| 120 | sync.Mutex |
| 121 | m map[string]AuthStyle // keyed by tokenURL |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 122 | } |
| 123 | |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 124 | // ResetAuthCache resets the global authentication style cache used |
| 125 | // for AuthStyleUnknown token requests. |
| 126 | func ResetAuthCache() { |
| 127 | authStyleCache.Lock() |
| 128 | defer authStyleCache.Unlock() |
| 129 | authStyleCache.m = nil |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 130 | } |
| 131 | |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 132 | // lookupAuthStyle reports which auth style we last used with tokenURL |
| 133 | // when calling RetrieveToken and whether we have ever done so. |
| 134 | func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { |
| 135 | authStyleCache.Lock() |
| 136 | defer authStyleCache.Unlock() |
| 137 | style, ok = authStyleCache.m[tokenURL] |
| 138 | return |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 139 | } |
| 140 | |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 141 | // setAuthStyle adds an entry to authStyleCache, documented above. |
| 142 | func setAuthStyle(tokenURL string, v AuthStyle) { |
| 143 | authStyleCache.Lock() |
| 144 | defer authStyleCache.Unlock() |
| 145 | if authStyleCache.m == nil { |
| 146 | authStyleCache.m = make(map[string]AuthStyle) |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 147 | } |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 148 | authStyleCache.m[tokenURL] = v |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 149 | } |
| 150 | |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 151 | // newTokenRequest returns a new *http.Request to retrieve a new token |
| 152 | // from tokenURL using the provided clientID, clientSecret, and POST |
| 153 | // body parameters. |
| 154 | // |
| 155 | // inParams is whether the clientID & clientSecret should be encoded |
| 156 | // as the POST body. An 'inParams' value of true means to send it in |
| 157 | // the POST body (along with any values in v); false means to send it |
| 158 | // in the Authorization header. |
| 159 | func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { |
| 160 | if authStyle == AuthStyleInParams { |
| 161 | v = cloneURLValues(v) |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 162 | if clientID != "" { |
| 163 | v.Set("client_id", clientID) |
| 164 | } |
| 165 | if clientSecret != "" { |
| 166 | v.Set("client_secret", clientSecret) |
| 167 | } |
| 168 | } |
| 169 | req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) |
| 170 | if err != nil { |
| 171 | return nil, err |
| 172 | } |
| 173 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 174 | if authStyle == AuthStyleInHeader { |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 175 | req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) |
| 176 | } |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 177 | return req, nil |
| 178 | } |
| 179 | |
// cloneURLValues returns a deep copy of v: a fresh map whose value slices
// do not share backing arrays with v's. Keys with empty value slices are
// kept, but their slice becomes nil. A nil v yields an empty, non-nil map.
func cloneURLValues(v url.Values) url.Values {
	dup := make(url.Values, len(v))
	for key := range v {
		dup[key] = append([]string(nil), v[key]...)
	}
	return dup
}
| 187 | |
| 188 | func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { |
| 189 | needsAuthStyleProbe := authStyle == 0 |
| 190 | if needsAuthStyleProbe { |
| 191 | if style, ok := lookupAuthStyle(tokenURL); ok { |
| 192 | authStyle = style |
| 193 | needsAuthStyleProbe = false |
| 194 | } else { |
| 195 | authStyle = AuthStyleInHeader // the first way we'll try |
| 196 | } |
| 197 | } |
| 198 | req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) |
| 199 | if err != nil { |
| 200 | return nil, err |
| 201 | } |
| 202 | token, err := doTokenRoundTrip(ctx, req) |
| 203 | if err != nil && needsAuthStyleProbe { |
| 204 | // If we get an error, assume the server wants the |
| 205 | // clientID & clientSecret in a different form. |
| 206 | // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. |
| 207 | // In summary: |
| 208 | // - Reddit only accepts client secret in the Authorization header |
| 209 | // - Dropbox accepts either it in URL param or Auth header, but not both. |
| 210 | // - Google only accepts URL param (not spec compliant?), not Auth header |
| 211 | // - Stripe only accepts client secret in Auth header with Bearer method, not Basic |
| 212 | // |
| 213 | // We used to maintain a big table in this code of all the sites and which way |
| 214 | // they went, but maintaining it didn't scale & got annoying. |
| 215 | // So just try both ways. |
| 216 | authStyle = AuthStyleInParams // the second way we'll try |
| 217 | req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) |
| 218 | token, err = doTokenRoundTrip(ctx, req) |
| 219 | } |
| 220 | if needsAuthStyleProbe && err == nil { |
| 221 | setAuthStyle(tokenURL, authStyle) |
| 222 | } |
| 223 | // Don't overwrite `RefreshToken` with an empty value |
| 224 | // if this was a token refreshing request. |
| 225 | if token != nil && token.RefreshToken == "" { |
| 226 | token.RefreshToken = v.Get("refresh_token") |
| 227 | } |
| 228 | return token, err |
| 229 | } |
| 230 | |
| 231 | func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 232 | r, err := ctxhttp.Do(ctx, ContextClient(ctx), req) |
| 233 | if err != nil { |
| 234 | return nil, err |
| 235 | } |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 236 | body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 237 | r.Body.Close() |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 238 | if err != nil { |
| 239 | return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
| 240 | } |
| 241 | if code := r.StatusCode; code < 200 || code > 299 { |
| 242 | return nil, &RetrieveError{ |
| 243 | Response: r, |
| 244 | Body: body, |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | var token *Token |
| 249 | content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) |
| 250 | switch content { |
| 251 | case "application/x-www-form-urlencoded", "text/plain": |
| 252 | vals, err := url.ParseQuery(string(body)) |
| 253 | if err != nil { |
| 254 | return nil, err |
| 255 | } |
| 256 | token = &Token{ |
| 257 | AccessToken: vals.Get("access_token"), |
| 258 | TokenType: vals.Get("token_type"), |
| 259 | RefreshToken: vals.Get("refresh_token"), |
| 260 | Raw: vals, |
| 261 | } |
| 262 | e := vals.Get("expires_in") |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 263 | if e == "" || e == "null" { |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 264 | // TODO(jbd): Facebook's OAuth2 implementation is broken and |
| 265 | // returns expires_in field in expires. Remove the fallback to expires, |
| 266 | // when Facebook fixes their implementation. |
| 267 | e = vals.Get("expires") |
| 268 | } |
| 269 | expires, _ := strconv.Atoi(e) |
| 270 | if expires != 0 { |
| 271 | token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) |
| 272 | } |
| 273 | default: |
| 274 | var tj tokenJSON |
| 275 | if err = json.Unmarshal(body, &tj); err != nil { |
| 276 | return nil, err |
| 277 | } |
| 278 | token = &Token{ |
| 279 | AccessToken: tj.AccessToken, |
| 280 | TokenType: tj.TokenType, |
| 281 | RefreshToken: tj.RefreshToken, |
| 282 | Expiry: tj.expiry(), |
| 283 | Raw: make(map[string]interface{}), |
| 284 | } |
| 285 | json.Unmarshal(body, &token.Raw) // no error checks for optional fields |
| 286 | } |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 287 | if token.AccessToken == "" { |
Stephane Barbarie | 260a563 | 2019-02-26 16:12:49 -0500 | [diff] [blame] | 288 | return nil, errors.New("oauth2: server response missing access_token") |
sslobodr | d046be8 | 2019-01-16 10:02:22 -0500 | [diff] [blame] | 289 | } |
| 290 | return token, nil |
| 291 | } |
| 292 | |
// RetrieveError is the error doTokenRoundTrip reports when the token
// endpoint answers with a non-2xx HTTP status.
type RetrieveError struct {
	// Response is the HTTP response received from the token endpoint.
	Response *http.Response
	// Body holds the response body. doTokenRoundTrip stores at most the
	// first 1 MiB of it.
	Body []byte
}

// Error describes the failure as the HTTP status followed by the raw body.
func (r *RetrieveError) Error() string {
	// All operands are strings, so Sprint inserts no separators.
	return fmt.Sprint("oauth2: cannot fetch token: ", r.Response.Status, "\nResponse: ", string(r.Body))
}