Matteo Scandolo | a428586 | 2020-12-01 18:10:10 -0800 | [diff] [blame] | 1 | // Copyright 2014 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package internal |
| 6 | |
| 7 | import ( |
| 8 | "context" |
| 9 | "encoding/json" |
| 10 | "errors" |
| 11 | "fmt" |
| 12 | "io" |
| 13 | "io/ioutil" |
| 14 | "math" |
| 15 | "mime" |
| 16 | "net/http" |
| 17 | "net/url" |
| 18 | "strconv" |
| 19 | "strings" |
| 20 | "sync" |
| 21 | "time" |
| 22 | |
| 23 | "golang.org/x/net/context/ctxhttp" |
| 24 | ) |
| 25 | |
// Token holds the credentials that authorize requests for protected
// resources on an OAuth 2.0 provider's backend.
//
// It mirrors oauth2.Token and lives in this package only to break an
// otherwise-circular dependency. Other internal packages should convert
// it into an oauth2.Token before using it.
type Token struct {
	// AccessToken authorizes and authenticates the requests.
	AccessToken string

	// TokenType names the kind of token.
	// The Type method returns either this value or "Bearer", the default.
	TokenType string

	// RefreshToken is used by the application (as opposed to the user)
	// to obtain a fresh access token once the current one expires.
	RefreshToken string

	// Expiry is the optional time at which the access token expires.
	//
	// When it is the zero value, TokenSource implementations reuse the
	// same token forever, and RefreshToken (or the equivalent mechanism
	// of that TokenSource) is not used.
	Expiry time.Time

	// Raw optionally carries extra metadata returned by the server
	// when the token was updated.
	Raw interface{}
}
| 58 | |
| 59 | // tokenJSON is the struct representing the HTTP response from OAuth2 |
| 60 | // providers returning a token in JSON form. |
| 61 | type tokenJSON struct { |
| 62 | AccessToken string `json:"access_token"` |
| 63 | TokenType string `json:"token_type"` |
| 64 | RefreshToken string `json:"refresh_token"` |
| 65 | ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number |
| 66 | } |
| 67 | |
| 68 | func (e *tokenJSON) expiry() (t time.Time) { |
| 69 | if v := e.ExpiresIn; v != 0 { |
| 70 | return time.Now().Add(time.Duration(v) * time.Second) |
| 71 | } |
| 72 | return |
| 73 | } |
| 74 | |
// expirationTime is a token lifetime as decoded from the "expires_in"
// field of a JSON token response. It is stored in 32 bits and saturates
// at the int32 limits rather than overflowing.
type expirationTime int32

// UnmarshalJSON decodes b into e.
//
// b may be a JSON number or a JSON string containing a number. Empty
// input and the literal null leave e untouched and return nil. Values
// outside the int32 range are clamped to math.MinInt32 or math.MaxInt32.
// An error is returned when b is not a valid number or is not an integer
// representable as an int64.
func (e *expirationTime) UnmarshalJSON(b []byte) error {
	if len(b) == 0 || string(b) == "null" {
		return nil
	}
	var n json.Number
	if err := json.Unmarshal(b, &n); err != nil {
		return err
	}
	i, err := n.Int64()
	if err != nil {
		return err
	}
	// Clamp on both sides: converting an out-of-range int64 to int32
	// wraps around, which would turn a large negative lifetime into a
	// positive one (and vice versa).
	switch {
	case i > math.MaxInt32:
		i = math.MaxInt32
	case i < math.MinInt32:
		i = math.MinInt32
	}
	*e = expirationTime(i)
	return nil
}
| 96 | |
// RegisterBrokenAuthHeaderProvider used to do something and is now a
// no-op; tokenURL is ignored.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
	// Intentionally empty.
}
| 104 | |
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
type AuthStyle int

// The auth styles, numbered 0, 1 and 2 in declaration order.
const (
	// AuthStyleUnknown is the zero value; RetrieveToken treats it as a
	// request to probe which style the server accepts.
	AuthStyleUnknown AuthStyle = iota
	// AuthStyleInParams sends client_id and client_secret in the POST body.
	AuthStyleInParams
	// AuthStyleInHeader sends the client credentials using HTTP Basic
	// authentication.
	AuthStyleInHeader
)
| 113 | |
| 114 | // authStyleCache is the set of tokenURLs we've successfully used via |
| 115 | // RetrieveToken and which style auth we ended up using. |
| 116 | // It's called a cache, but it doesn't (yet?) shrink. It's expected that |
| 117 | // the set of OAuth2 servers a program contacts over time is fixed and |
| 118 | // small. |
| 119 | var authStyleCache struct { |
| 120 | sync.Mutex |
| 121 | m map[string]AuthStyle // keyed by tokenURL |
| 122 | } |
| 123 | |
| 124 | // ResetAuthCache resets the global authentication style cache used |
| 125 | // for AuthStyleUnknown token requests. |
| 126 | func ResetAuthCache() { |
| 127 | authStyleCache.Lock() |
| 128 | defer authStyleCache.Unlock() |
| 129 | authStyleCache.m = nil |
| 130 | } |
| 131 | |
| 132 | // lookupAuthStyle reports which auth style we last used with tokenURL |
| 133 | // when calling RetrieveToken and whether we have ever done so. |
| 134 | func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { |
| 135 | authStyleCache.Lock() |
| 136 | defer authStyleCache.Unlock() |
| 137 | style, ok = authStyleCache.m[tokenURL] |
| 138 | return |
| 139 | } |
| 140 | |
| 141 | // setAuthStyle adds an entry to authStyleCache, documented above. |
| 142 | func setAuthStyle(tokenURL string, v AuthStyle) { |
| 143 | authStyleCache.Lock() |
| 144 | defer authStyleCache.Unlock() |
| 145 | if authStyleCache.m == nil { |
| 146 | authStyleCache.m = make(map[string]AuthStyle) |
| 147 | } |
| 148 | authStyleCache.m[tokenURL] = v |
| 149 | } |
| 150 | |
| 151 | // newTokenRequest returns a new *http.Request to retrieve a new token |
| 152 | // from tokenURL using the provided clientID, clientSecret, and POST |
| 153 | // body parameters. |
| 154 | // |
| 155 | // inParams is whether the clientID & clientSecret should be encoded |
| 156 | // as the POST body. An 'inParams' value of true means to send it in |
| 157 | // the POST body (along with any values in v); false means to send it |
| 158 | // in the Authorization header. |
| 159 | func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { |
| 160 | if authStyle == AuthStyleInParams { |
| 161 | v = cloneURLValues(v) |
| 162 | if clientID != "" { |
| 163 | v.Set("client_id", clientID) |
| 164 | } |
| 165 | if clientSecret != "" { |
| 166 | v.Set("client_secret", clientSecret) |
| 167 | } |
| 168 | } |
| 169 | req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) |
| 170 | if err != nil { |
| 171 | return nil, err |
| 172 | } |
| 173 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 174 | if authStyle == AuthStyleInHeader { |
| 175 | req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) |
| 176 | } |
| 177 | return req, nil |
| 178 | } |
| 179 | |
// cloneURLValues returns a deep copy of v: the returned map and each of
// its value slices are newly allocated, so changes to the copy do not
// affect v. The result is never nil, even when v is.
func cloneURLValues(v url.Values) url.Values {
	out := make(url.Values, len(v))
	for key := range v {
		var dup []string
		out[key] = append(dup, v[key]...)
	}
	return out
}
| 187 | |
| 188 | func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { |
| 189 | needsAuthStyleProbe := authStyle == 0 |
| 190 | if needsAuthStyleProbe { |
| 191 | if style, ok := lookupAuthStyle(tokenURL); ok { |
| 192 | authStyle = style |
| 193 | needsAuthStyleProbe = false |
| 194 | } else { |
| 195 | authStyle = AuthStyleInHeader // the first way we'll try |
| 196 | } |
| 197 | } |
| 198 | req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) |
| 199 | if err != nil { |
| 200 | return nil, err |
| 201 | } |
| 202 | token, err := doTokenRoundTrip(ctx, req) |
| 203 | if err != nil && needsAuthStyleProbe { |
| 204 | // If we get an error, assume the server wants the |
| 205 | // clientID & clientSecret in a different form. |
| 206 | // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. |
| 207 | // In summary: |
| 208 | // - Reddit only accepts client secret in the Authorization header |
| 209 | // - Dropbox accepts either it in URL param or Auth header, but not both. |
| 210 | // - Google only accepts URL param (not spec compliant?), not Auth header |
| 211 | // - Stripe only accepts client secret in Auth header with Bearer method, not Basic |
| 212 | // |
| 213 | // We used to maintain a big table in this code of all the sites and which way |
| 214 | // they went, but maintaining it didn't scale & got annoying. |
| 215 | // So just try both ways. |
| 216 | authStyle = AuthStyleInParams // the second way we'll try |
| 217 | req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) |
| 218 | token, err = doTokenRoundTrip(ctx, req) |
| 219 | } |
| 220 | if needsAuthStyleProbe && err == nil { |
| 221 | setAuthStyle(tokenURL, authStyle) |
| 222 | } |
| 223 | // Don't overwrite `RefreshToken` with an empty value |
| 224 | // if this was a token refreshing request. |
| 225 | if token != nil && token.RefreshToken == "" { |
| 226 | token.RefreshToken = v.Get("refresh_token") |
| 227 | } |
| 228 | return token, err |
| 229 | } |
| 230 | |
| 231 | func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { |
| 232 | r, err := ctxhttp.Do(ctx, ContextClient(ctx), req) |
| 233 | if err != nil { |
| 234 | return nil, err |
| 235 | } |
| 236 | body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) |
| 237 | r.Body.Close() |
| 238 | if err != nil { |
| 239 | return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) |
| 240 | } |
| 241 | if code := r.StatusCode; code < 200 || code > 299 { |
| 242 | return nil, &RetrieveError{ |
| 243 | Response: r, |
| 244 | Body: body, |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | var token *Token |
| 249 | content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) |
| 250 | switch content { |
| 251 | case "application/x-www-form-urlencoded", "text/plain": |
| 252 | vals, err := url.ParseQuery(string(body)) |
| 253 | if err != nil { |
| 254 | return nil, err |
| 255 | } |
| 256 | token = &Token{ |
| 257 | AccessToken: vals.Get("access_token"), |
| 258 | TokenType: vals.Get("token_type"), |
| 259 | RefreshToken: vals.Get("refresh_token"), |
| 260 | Raw: vals, |
| 261 | } |
| 262 | e := vals.Get("expires_in") |
| 263 | expires, _ := strconv.Atoi(e) |
| 264 | if expires != 0 { |
| 265 | token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) |
| 266 | } |
| 267 | default: |
| 268 | var tj tokenJSON |
| 269 | if err = json.Unmarshal(body, &tj); err != nil { |
| 270 | return nil, err |
| 271 | } |
| 272 | token = &Token{ |
| 273 | AccessToken: tj.AccessToken, |
| 274 | TokenType: tj.TokenType, |
| 275 | RefreshToken: tj.RefreshToken, |
| 276 | Expiry: tj.expiry(), |
| 277 | Raw: make(map[string]interface{}), |
| 278 | } |
| 279 | json.Unmarshal(body, &token.Raw) // no error checks for optional fields |
| 280 | } |
| 281 | if token.AccessToken == "" { |
| 282 | return nil, errors.New("oauth2: server response missing access_token") |
| 283 | } |
| 284 | return token, nil |
| 285 | } |
| 286 | |
// RetrieveError is the error doTokenRoundTrip returns when the token
// endpoint answers with a status code outside the 200-299 range.
type RetrieveError struct {
	// Response is the HTTP response that was received. Its body has
	// already been read and closed.
	Response *http.Response
	// Body holds the bytes that were read from the response body.
	Body []byte
}

// Error reports the response status line followed by the response body.
// Response must be non-nil.
func (r *RetrieveError) Error() string {
	status, payload := r.Response.Status, r.Body
	return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", status, payload)
}