David K. Bainbridge | 215e024 | 2017-09-05 23:18:24 -0700 | [diff] [blame] | 1 | // Copyright 2015 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | // +build go1.6 |
| 6 | |
| 7 | package http2 |
| 8 | |
| 9 | import ( |
| 10 | "crypto/tls" |
| 11 | "fmt" |
| 12 | "net/http" |
| 13 | ) |
| 14 | |
| 15 | func configureTransport(t1 *http.Transport) (*Transport, error) { |
| 16 | connPool := new(clientConnPool) |
| 17 | t2 := &Transport{ |
| 18 | ConnPool: noDialClientConnPool{connPool}, |
| 19 | t1: t1, |
| 20 | } |
| 21 | connPool.t = t2 |
| 22 | if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { |
| 23 | return nil, err |
| 24 | } |
| 25 | if t1.TLSClientConfig == nil { |
| 26 | t1.TLSClientConfig = new(tls.Config) |
| 27 | } |
| 28 | if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") { |
| 29 | t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...) |
| 30 | } |
| 31 | if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") { |
| 32 | t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") |
| 33 | } |
| 34 | upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper { |
| 35 | addr := authorityAddr("https", authority) |
| 36 | if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { |
| 37 | go c.Close() |
| 38 | return erringRoundTripper{err} |
| 39 | } else if !used { |
| 40 | // Turns out we don't need this c. |
| 41 | // For example, two goroutines made requests to the same host |
| 42 | // at the same time, both kicking off TCP dials. (since protocol |
| 43 | // was unknown) |
| 44 | go c.Close() |
| 45 | } |
| 46 | return t2 |
| 47 | } |
| 48 | if m := t1.TLSNextProto; len(m) == 0 { |
| 49 | t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{ |
| 50 | "h2": upgradeFn, |
| 51 | } |
| 52 | } else { |
| 53 | m["h2"] = upgradeFn |
| 54 | } |
| 55 | return t2, nil |
| 56 | } |
| 57 | |
// registerHTTPSProtocol registers rt for the "https" scheme on t via
// Transport.RegisterProtocol. RegisterProtocol signals failure (such as a
// scheme that is already registered) by panicking. Any such panic is
// recovered here and returned as an error whose text is the panic value.
func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = fmt.Errorf("%v", r)
	}()
	t.RegisterProtocol("https", rt)
	return nil
}
| 69 | |
| 70 | // noDialH2RoundTripper is a RoundTripper which only tries to complete the request |
| 71 | // if there's already has a cached connection to the host. |
| 72 | type noDialH2RoundTripper struct{ t *Transport } |
| 73 | |
| 74 | func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
| 75 | res, err := rt.t.RoundTrip(req) |
| 76 | if err == ErrNoCachedConn { |
| 77 | return nil, http.ErrSkipAltProtocol |
| 78 | } |
| 79 | return res, err |
| 80 | } |