blob: f730c39759ebd08577f137cc07e1b4ec90e2b071 [file] [log] [blame]
Matteo Scandoloa4285862020-12-01 18:10:10 -08001/*
2Copyright 2018 The Kubernetes Authors.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package transport
18
19import (
20 "fmt"
21 "io/ioutil"
22 "net/http"
23 "strings"
24 "sync"
25 "time"
26
27 "golang.org/x/oauth2"
28
29 "k8s.io/klog/v2"
30)
31
32// TokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
33// authentication from an oauth2.TokenSource.
34func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) http.RoundTripper {
35 return func(rt http.RoundTripper) http.RoundTripper {
36 return &tokenSourceTransport{
37 base: rt,
38 ort: &oauth2.Transport{
39 Source: ts,
40 Base: rt,
41 },
42 }
43 }
44}
45
46// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a
47// file at a specified path and periodically reloads it.
48func NewCachedFileTokenSource(path string) oauth2.TokenSource {
49 return &cachingTokenSource{
50 now: time.Now,
51 leeway: 10 * time.Second,
52 base: &fileTokenSource{
53 path: path,
54 // This period was picked because it is half of the duration between when the kubelet
55 // refreshes a projected service account token and when the original token expires.
56 // Default token lifetime is 10 minutes, and the kubelet starts refreshing at 80% of lifetime.
57 // This should induce re-reading at a frequency that works with the token volume source.
58 period: time.Minute,
59 },
60 }
61}
62
63// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a
64// designed TokenSource. The ts would provide the source of token.
65func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
66 return &cachingTokenSource{
67 now: time.Now,
68 base: ts,
69 }
70}
71
// tokenSourceTransport sends each request through one of two transports,
// depending on whether the request already carries credentials.
type tokenSourceTransport struct {
	// base receives requests that already have an Authorization header.
	base http.RoundTripper
	// ort receives all other requests; TokenSourceWrapTransport sets it to an
	// *oauth2.Transport wrapping base.
	ort http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
//
// A request with a non-empty Authorization header goes straight to base, so
// explicitly supplied credentials (e.g. --token) win over other bearer token
// providers. Every other request is handed to ort.
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	next := tst.ort
	if len(req.Header.Get("Authorization")) > 0 {
		next = tst.base
	}
	return next.RoundTrip(req)
}
84
85func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
86 if req.Header.Get("Authorization") != "" {
87 tryCancelRequest(tst.base, req)
88 return
89 }
90 tryCancelRequest(tst.ort, req)
91}
92
93type fileTokenSource struct {
94 path string
95 period time.Duration
96}
97
98var _ = oauth2.TokenSource(&fileTokenSource{})
99
100func (ts *fileTokenSource) Token() (*oauth2.Token, error) {
101 tokb, err := ioutil.ReadFile(ts.path)
102 if err != nil {
103 return nil, fmt.Errorf("failed to read token file %q: %v", ts.path, err)
104 }
105 tok := strings.TrimSpace(string(tokb))
106 if len(tok) == 0 {
107 return nil, fmt.Errorf("read empty token from file %q", ts.path)
108 }
109
110 return &oauth2.Token{
111 AccessToken: tok,
112 Expiry: time.Now().Add(ts.period),
113 }, nil
114}
115
116type cachingTokenSource struct {
117 base oauth2.TokenSource
118 leeway time.Duration
119
120 sync.RWMutex
121 tok *oauth2.Token
122
123 // for testing
124 now func() time.Time
125}
126
127var _ = oauth2.TokenSource(&cachingTokenSource{})
128
129func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
130 now := ts.now()
131 // fast path
132 ts.RLock()
133 tok := ts.tok
134 ts.RUnlock()
135
136 if tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
137 return tok, nil
138 }
139
140 // slow path
141 ts.Lock()
142 defer ts.Unlock()
143 if tok := ts.tok; tok != nil && tok.Expiry.Add(-1*ts.leeway).After(now) {
144 return tok, nil
145 }
146
147 tok, err := ts.base.Token()
148 if err != nil {
149 if ts.tok == nil {
150 return nil, err
151 }
152 klog.Errorf("Unable to rotate token: %v", err)
153 return ts.tok, nil
154 }
155
156 ts.tok = tok
157 return tok, nil
158}