blob: b41a63d1ff5a6a33715d611f3e21e2b2d117b7c3 [file] [log] [blame]
// Package httpcache provides an http.RoundTripper implementation that works as a
// mostly RFC-compliant cache for HTTP responses.
//
// It is only suitable for use as a 'private' cache (i.e. for a web browser or an API client
// and not for a shared proxy).
6//
7package httpcache
8
9import (
10 "bufio"
11 "bytes"
12 "errors"
13 "io"
14 "io/ioutil"
15 "net/http"
16 "net/http/httputil"
17 "strings"
18 "sync"
19 "time"
20)
21
22const (
23 stale = iota
24 fresh
25 transparent
26 // XFromCache is the header added to responses that are returned from the cache
27 XFromCache = "X-From-Cache"
28)
29
30// A Cache interface is used by the Transport to store and retrieve responses.
31type Cache interface {
32 // Get returns the []byte representation of a cached response and a bool
33 // set to true if the value isn't empty
34 Get(key string) (responseBytes []byte, ok bool)
35 // Set stores the []byte representation of a response against a key
36 Set(key string, responseBytes []byte)
37 // Delete removes the value associated with the key
38 Delete(key string)
39}
40
41// cacheKey returns the cache key for req.
42func cacheKey(req *http.Request) string {
43 if req.Method == http.MethodGet {
44 return req.URL.String()
45 } else {
46 return req.Method + " " + req.URL.String()
47 }
48}
49
50// CachedResponse returns the cached http.Response for req if present, and nil
51// otherwise.
52func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
53 cachedVal, ok := c.Get(cacheKey(req))
54 if !ok {
55 return
56 }
57
58 b := bytes.NewBuffer(cachedVal)
59 return http.ReadResponse(bufio.NewReader(b), req)
60}
61
62// MemoryCache is an implemtation of Cache that stores responses in an in-memory map.
63type MemoryCache struct {
64 mu sync.RWMutex
65 items map[string][]byte
66}
67
68// Get returns the []byte representation of the response and true if present, false if not
69func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
70 c.mu.RLock()
71 resp, ok = c.items[key]
72 c.mu.RUnlock()
73 return resp, ok
74}
75
76// Set saves response resp to the cache with key
77func (c *MemoryCache) Set(key string, resp []byte) {
78 c.mu.Lock()
79 c.items[key] = resp
80 c.mu.Unlock()
81}
82
83// Delete removes key from the cache
84func (c *MemoryCache) Delete(key string) {
85 c.mu.Lock()
86 delete(c.items, key)
87 c.mu.Unlock()
88}
89
90// NewMemoryCache returns a new Cache that will store items in an in-memory map
91func NewMemoryCache() *MemoryCache {
92 c := &MemoryCache{items: map[string][]byte{}}
93 return c
94}
95
96// Transport is an implementation of http.RoundTripper that will return values from a cache
97// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since)
98// to repeated requests allowing servers to return 304 / Not Modified
99type Transport struct {
100 // The RoundTripper interface actually used to make requests
101 // If nil, http.DefaultTransport is used
102 Transport http.RoundTripper
103 Cache Cache
104 // If true, responses returned from the cache will be given an extra header, X-From-Cache
105 MarkCachedResponses bool
106}
107
108// NewTransport returns a new Transport with the
109// provided Cache implementation and MarkCachedResponses set to true
110func NewTransport(c Cache) *Transport {
111 return &Transport{Cache: c, MarkCachedResponses: true}
112}
113
114// Client returns an *http.Client that caches responses.
115func (t *Transport) Client() *http.Client {
116 return &http.Client{Transport: t}
117}
118
119// varyMatches will return false unless all of the cached values for the headers listed in Vary
120// match the new request
121func varyMatches(cachedResp *http.Response, req *http.Request) bool {
122 for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
123 header = http.CanonicalHeaderKey(header)
124 if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
125 return false
126 }
127 }
128 return true
129}
130
131// RoundTrip takes a Request and returns a Response
132//
133// If there is a fresh Response already in cache, then it will be returned without connecting to
134// the server.
135//
136// If there is a stale Response, then any validators it contains will be set on the new request
137// to give the server a chance to respond with NotModified. If this happens, then the cached Response
138// will be returned.
139func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
140 cacheKey := cacheKey(req)
141 cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
142 var cachedResp *http.Response
143 if cacheable {
144 cachedResp, err = CachedResponse(t.Cache, req)
145 } else {
146 // Need to invalidate an existing value
147 t.Cache.Delete(cacheKey)
148 }
149
150 transport := t.Transport
151 if transport == nil {
152 transport = http.DefaultTransport
153 }
154
155 if cacheable && cachedResp != nil && err == nil {
156 if t.MarkCachedResponses {
157 cachedResp.Header.Set(XFromCache, "1")
158 }
159
160 if varyMatches(cachedResp, req) {
161 // Can only use cached value if the new request doesn't Vary significantly
162 freshness := getFreshness(cachedResp.Header, req.Header)
163 if freshness == fresh {
164 return cachedResp, nil
165 }
166
167 if freshness == stale {
168 var req2 *http.Request
169 // Add validators if caller hasn't already done so
170 etag := cachedResp.Header.Get("etag")
171 if etag != "" && req.Header.Get("etag") == "" {
172 req2 = cloneRequest(req)
173 req2.Header.Set("if-none-match", etag)
174 }
175 lastModified := cachedResp.Header.Get("last-modified")
176 if lastModified != "" && req.Header.Get("last-modified") == "" {
177 if req2 == nil {
178 req2 = cloneRequest(req)
179 }
180 req2.Header.Set("if-modified-since", lastModified)
181 }
182 if req2 != nil {
183 req = req2
184 }
185 }
186 }
187
188 resp, err = transport.RoundTrip(req)
189 if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
190 // Replace the 304 response with the one from cache, but update with some new headers
191 endToEndHeaders := getEndToEndHeaders(resp.Header)
192 for _, header := range endToEndHeaders {
193 cachedResp.Header[header] = resp.Header[header]
194 }
195 resp = cachedResp
196 } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
197 req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
198 // In case of transport failure and stale-if-error activated, returns cached content
199 // when available
200 return cachedResp, nil
201 } else {
202 if err != nil || resp.StatusCode != http.StatusOK {
203 t.Cache.Delete(cacheKey)
204 }
205 if err != nil {
206 return nil, err
207 }
208 }
209 } else {
210 reqCacheControl := parseCacheControl(req.Header)
211 if _, ok := reqCacheControl["only-if-cached"]; ok {
212 resp = newGatewayTimeoutResponse(req)
213 } else {
214 resp, err = transport.RoundTrip(req)
215 if err != nil {
216 return nil, err
217 }
218 }
219 }
220
221 if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
222 for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
223 varyKey = http.CanonicalHeaderKey(varyKey)
224 fakeHeader := "X-Varied-" + varyKey
225 reqValue := req.Header.Get(varyKey)
226 if reqValue != "" {
227 resp.Header.Set(fakeHeader, reqValue)
228 }
229 }
230 switch req.Method {
231 case "GET":
232 // Delay caching until EOF is reached.
233 resp.Body = &cachingReadCloser{
234 R: resp.Body,
235 OnEOF: func(r io.Reader) {
236 resp := *resp
237 resp.Body = ioutil.NopCloser(r)
238 respBytes, err := httputil.DumpResponse(&resp, true)
239 if err == nil {
240 t.Cache.Set(cacheKey, respBytes)
241 }
242 },
243 }
244 default:
245 respBytes, err := httputil.DumpResponse(resp, true)
246 if err == nil {
247 t.Cache.Set(cacheKey, respBytes)
248 }
249 }
250 } else {
251 t.Cache.Delete(cacheKey)
252 }
253 return resp, nil
254}
255
256// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
257var ErrNoDateHeader = errors.New("no Date header")
258
259// Date parses and returns the value of the Date header.
260func Date(respHeaders http.Header) (date time.Time, err error) {
261 dateHeader := respHeaders.Get("date")
262 if dateHeader == "" {
263 err = ErrNoDateHeader
264 return
265 }
266
267 return time.Parse(time.RFC1123, dateHeader)
268}
269
270type realClock struct{}
271
272func (c *realClock) since(d time.Time) time.Duration {
273 return time.Since(d)
274}
275
276type timer interface {
277 since(d time.Time) time.Duration
278}
279
280var clock timer = &realClock{}
281
282// getFreshness will return one of fresh/stale/transparent based on the cache-control
283// values of the request and the response
284//
285// fresh indicates the response can be returned
286// stale indicates that the response needs validating before it is returned
287// transparent indicates the response should not be used to fulfil the request
288//
289// Because this is only a private cache, 'public' and 'private' in cache-control aren't
290// signficant. Similarly, smax-age isn't used.
291func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
292 respCacheControl := parseCacheControl(respHeaders)
293 reqCacheControl := parseCacheControl(reqHeaders)
294 if _, ok := reqCacheControl["no-cache"]; ok {
295 return transparent
296 }
297 if _, ok := respCacheControl["no-cache"]; ok {
298 return stale
299 }
300 if _, ok := reqCacheControl["only-if-cached"]; ok {
301 return fresh
302 }
303
304 date, err := Date(respHeaders)
305 if err != nil {
306 return stale
307 }
308 currentAge := clock.since(date)
309
310 var lifetime time.Duration
311 var zeroDuration time.Duration
312
313 // If a response includes both an Expires header and a max-age directive,
314 // the max-age directive overrides the Expires header, even if the Expires header is more restrictive.
315 if maxAge, ok := respCacheControl["max-age"]; ok {
316 lifetime, err = time.ParseDuration(maxAge + "s")
317 if err != nil {
318 lifetime = zeroDuration
319 }
320 } else {
321 expiresHeader := respHeaders.Get("Expires")
322 if expiresHeader != "" {
323 expires, err := time.Parse(time.RFC1123, expiresHeader)
324 if err != nil {
325 lifetime = zeroDuration
326 } else {
327 lifetime = expires.Sub(date)
328 }
329 }
330 }
331
332 if maxAge, ok := reqCacheControl["max-age"]; ok {
333 // the client is willing to accept a response whose age is no greater than the specified time in seconds
334 lifetime, err = time.ParseDuration(maxAge + "s")
335 if err != nil {
336 lifetime = zeroDuration
337 }
338 }
339 if minfresh, ok := reqCacheControl["min-fresh"]; ok {
340 // the client wants a response that will still be fresh for at least the specified number of seconds.
341 minfreshDuration, err := time.ParseDuration(minfresh + "s")
342 if err == nil {
343 currentAge = time.Duration(currentAge + minfreshDuration)
344 }
345 }
346
347 if maxstale, ok := reqCacheControl["max-stale"]; ok {
348 // Indicates that the client is willing to accept a response that has exceeded its expiration time.
349 // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded
350 // its expiration time by no more than the specified number of seconds.
351 // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age.
352 //
353 // Responses served only because of a max-stale value are supposed to have a Warning header added to them,
354 // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different
355 // return-value available here.
356 if maxstale == "" {
357 return fresh
358 }
359 maxstaleDuration, err := time.ParseDuration(maxstale + "s")
360 if err == nil {
361 currentAge = time.Duration(currentAge - maxstaleDuration)
362 }
363 }
364
365 if lifetime > currentAge {
366 return fresh
367 }
368
369 return stale
370}
371
372// Returns true if either the request or the response includes the stale-if-error
373// cache control extension: https://tools.ietf.org/html/rfc5861
374func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
375 respCacheControl := parseCacheControl(respHeaders)
376 reqCacheControl := parseCacheControl(reqHeaders)
377
378 var err error
379 lifetime := time.Duration(-1)
380
381 if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
382 if staleMaxAge != "" {
383 lifetime, err = time.ParseDuration(staleMaxAge + "s")
384 if err != nil {
385 return false
386 }
387 } else {
388 return true
389 }
390 }
391 if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
392 if staleMaxAge != "" {
393 lifetime, err = time.ParseDuration(staleMaxAge + "s")
394 if err != nil {
395 return false
396 }
397 } else {
398 return true
399 }
400 }
401
402 if lifetime >= 0 {
403 date, err := Date(respHeaders)
404 if err != nil {
405 return false
406 }
407 currentAge := clock.since(date)
408 if lifetime > currentAge {
409 return true
410 }
411 }
412
413 return false
414}
415
416func getEndToEndHeaders(respHeaders http.Header) []string {
417 // These headers are always hop-by-hop
418 hopByHopHeaders := map[string]struct{}{
419 "Connection": {},
420 "Keep-Alive": {},
421 "Proxy-Authenticate": {},
422 "Proxy-Authorization": {},
423 "Te": {},
424 "Trailers": {},
425 "Transfer-Encoding": {},
426 "Upgrade": {},
427 }
428
429 for _, extra := range strings.Split(respHeaders.Get("connection"), ",") {
430 // any header listed in connection, if present, is also considered hop-by-hop
431 if strings.Trim(extra, " ") != "" {
432 hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{}
433 }
434 }
435 endToEndHeaders := []string{}
436 for respHeader := range respHeaders {
437 if _, ok := hopByHopHeaders[respHeader]; !ok {
438 endToEndHeaders = append(endToEndHeaders, respHeader)
439 }
440 }
441 return endToEndHeaders
442}
443
444func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
445 if _, ok := respCacheControl["no-store"]; ok {
446 return false
447 }
448 if _, ok := reqCacheControl["no-store"]; ok {
449 return false
450 }
451 return true
452}
453
454func newGatewayTimeoutResponse(req *http.Request) *http.Response {
455 var braw bytes.Buffer
456 braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
457 resp, err := http.ReadResponse(bufio.NewReader(&braw), req)
458 if err != nil {
459 panic(err)
460 }
461 return resp
462}
463
464// cloneRequest returns a clone of the provided *http.Request.
465// The clone is a shallow copy of the struct and its Header map.
466// (This function copyright goauth2 authors: https://code.google.com/p/goauth2)
467func cloneRequest(r *http.Request) *http.Request {
468 // shallow copy of the struct
469 r2 := new(http.Request)
470 *r2 = *r
471 // deep copy of the Header
472 r2.Header = make(http.Header)
473 for k, s := range r.Header {
474 r2.Header[k] = s
475 }
476 return r2
477}
478
479type cacheControl map[string]string
480
481func parseCacheControl(headers http.Header) cacheControl {
482 cc := cacheControl{}
483 ccHeader := headers.Get("Cache-Control")
484 for _, part := range strings.Split(ccHeader, ",") {
485 part = strings.Trim(part, " ")
486 if part == "" {
487 continue
488 }
489 if strings.ContainsRune(part, '=') {
490 keyval := strings.Split(part, "=")
491 cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",")
492 } else {
493 cc[part] = ""
494 }
495 }
496 return cc
497}
498
499// headerAllCommaSepValues returns all comma-separated values (each
500// with whitespace trimmed) for header name in headers. According to
501// Section 4.2 of the HTTP/1.1 spec
502// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2),
503// values from multiple occurrences of a header should be concatenated, if
504// the header's value is a comma-separated list.
505func headerAllCommaSepValues(headers http.Header, name string) []string {
506 var vals []string
507 for _, val := range headers[http.CanonicalHeaderKey(name)] {
508 fields := strings.Split(val, ",")
509 for i, f := range fields {
510 fields[i] = strings.TrimSpace(f)
511 }
512 vals = append(vals, fields...)
513 }
514 return vals
515}
516
517// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF
518// handler with a full copy of the content read from R when EOF is
519// reached.
520type cachingReadCloser struct {
521 // Underlying ReadCloser.
522 R io.ReadCloser
523 // OnEOF is called with a copy of the content of R when EOF is reached.
524 OnEOF func(io.Reader)
525
526 buf bytes.Buffer // buf stores a copy of the content of R.
527}
528
529// Read reads the next len(p) bytes from R or until R is drained. The
530// return value n is the number of bytes read. If R has no data to
531// return, err is io.EOF and OnEOF is called with a full copy of what
532// has been read so far.
533func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
534 n, err = r.R.Read(p)
535 r.buf.Write(p[:n])
536 if err == io.EOF {
537 r.OnEOF(bytes.NewReader(r.buf.Bytes()))
538 }
539 return n, err
540}
541
542func (r *cachingReadCloser) Close() error {
543 return r.R.Close()
544}
545
546// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
547func NewMemoryCacheTransport() *Transport {
548 c := NewMemoryCache()
549 t := NewTransport(c)
550 return t
551}