gRPC migration
Change-Id: I3129ae27d7ee12a23c7046f0d877e8064f2fd7f4
diff --git a/vendor/golang.org/x/net/http2/transport.go b/vendor/golang.org/x/net/http2/transport.go
index 76a92e0..b97adff 100644
--- a/vendor/golang.org/x/net/http2/transport.go
+++ b/vendor/golang.org/x/net/http2/transport.go
@@ -154,12 +154,21 @@
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
+//
+// Use ConfigureTransports instead to configure the HTTP/2 Transport.
func ConfigureTransport(t1 *http.Transport) error {
- _, err := configureTransport(t1)
+ _, err := ConfigureTransports(t1)
return err
}
-func configureTransport(t1 *http.Transport) (*Transport, error) {
+// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2.
+// It returns a new HTTP/2 Transport for further configuration.
+// It returns an error if t1 has already been HTTP/2-enabled.
+func ConfigureTransports(t1 *http.Transport) (*Transport, error) {
+ return configureTransports(t1)
+}
+
+func configureTransports(t1 *http.Transport) (*Transport, error) {
connPool := new(clientConnPool)
t2 := &Transport{
ConnPool: noDialClientConnPool{connPool},
@@ -255,9 +264,8 @@
peerMaxHeaderListSize uint64
initialWindowSize uint32
- hbuf bytes.Buffer // HPACK encoder writes into this
- henc *hpack.Encoder
- freeBuf [][]byte
+ hbuf bytes.Buffer // HPACK encoder writes into this
+ henc *hpack.Encoder
wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
werr error // first write error that has occurred
@@ -555,12 +563,12 @@
return false
}
-func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
+func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
- tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
+ tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
@@ -581,34 +589,24 @@
return cfg
}
-func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
- return t.dialTLSDefault
-}
-
-func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
- cn, err := tls.Dial(network, addr, cfg)
- if err != nil {
- return nil, err
- }
- if err := cn.Handshake(); err != nil {
- return nil, err
- }
- if !cfg.InsecureSkipVerify {
- if err := cn.VerifyHostname(cfg.ServerName); err != nil {
+ return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg)
+ if err != nil {
return nil, err
}
+ state := tlsCn.ConnectionState()
+ if p := state.NegotiatedProtocol; p != NextProtoTLS {
+ return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
+ }
+ if !state.NegotiatedProtocolIsMutual {
+ return nil, errors.New("http2: could not negotiate protocol mutually")
+ }
+ return tlsCn, nil
}
- state := cn.ConnectionState()
- if p := state.NegotiatedProtocol; p != NextProtoTLS {
- return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
- }
- if !state.NegotiatedProtocolIsMutual {
- return nil, errors.New("http2: could not negotiate protocol mutually")
- }
- return cn, nil
}
// disableKeepAlives reports whether connections should be closed as
@@ -689,6 +687,7 @@
cc.inflow.add(transportDefaultConnFlow + initialWindowSize)
cc.bw.Flush()
if cc.werr != nil {
+ cc.Close()
return nil, cc.werr
}
@@ -913,46 +912,6 @@
return cc.closeForError(err)
}
-const maxAllocFrameSize = 512 << 10
-
-// frameBuffer returns a scratch buffer suitable for writing DATA frames.
-// They're capped at the min of the peer's max frame size or 512KB
-// (kinda arbitrarily), but definitely capped so we don't allocate 4GB
-// bufers.
-func (cc *ClientConn) frameScratchBuffer() []byte {
- cc.mu.Lock()
- size := cc.maxFrameSize
- if size > maxAllocFrameSize {
- size = maxAllocFrameSize
- }
- for i, buf := range cc.freeBuf {
- if len(buf) >= int(size) {
- cc.freeBuf[i] = nil
- cc.mu.Unlock()
- return buf[:size]
- }
- }
- cc.mu.Unlock()
- return make([]byte, size)
-}
-
-func (cc *ClientConn) putFrameScratchBuffer(buf []byte) {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
- if len(cc.freeBuf) < maxBufs {
- cc.freeBuf = append(cc.freeBuf, buf)
- return
- }
- for i, old := range cc.freeBuf {
- if old == nil {
- cc.freeBuf[i] = buf
- return
- }
- }
- // forget about it.
-}
-
// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests.
var errRequestCanceled = errors.New("net/http: request canceled")
@@ -995,7 +954,7 @@
if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
}
- if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !strings.EqualFold(vv[0], "close") && !strings.EqualFold(vv[0], "keep-alive")) {
+ if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("http2: invalid Connection request header: %q", vv)
}
return nil
@@ -1080,6 +1039,15 @@
bodyWriter := cc.t.getBodyWriterState(cs, body)
cs.on100 = bodyWriter.on100
+ defer func() {
+ cc.wmu.Lock()
+ werr := cc.werr
+ cc.wmu.Unlock()
+ if werr != nil {
+ cc.Close()
+ }
+ }()
+
cc.wmu.Lock()
endStream := !hasBody && !hasTrailers
werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
@@ -1129,6 +1097,9 @@
// we can keep it.
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWrite)
+ if hasBody && !bodyWritten {
+ <-bodyWriter.resc
+ }
}
if re.err != nil {
cc.forgetStreamID(cs.ID)
@@ -1149,6 +1120,7 @@
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
+ <-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errTimeout
@@ -1158,6 +1130,7 @@
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
+ <-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), ctx.Err()
@@ -1167,6 +1140,7 @@
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
+ <-bodyWriter.resc
}
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), errRequestCanceled
@@ -1176,6 +1150,7 @@
// forgetStreamID.
return nil, cs.getStartedWrite(), cs.resetErr
case err := <-bodyWriter.resc:
+ bodyWritten = true
// Prefer the read loop's response, if available. Issue 16102.
select {
case re := <-readLoopResCh:
@@ -1186,7 +1161,6 @@
cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err
}
- bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
@@ -1280,11 +1254,35 @@
errReqBodyTooLong = errors.New("http2: request body larger than specified content length")
)
+// frameScratchBufferLen returns the length of a buffer to use for
+// outgoing request bodies to read/write to/from.
+//
+// It returns max(1, min(peer's advertised max frame size,
+// Request.ContentLength+1, 512KB)).
+func (cs *clientStream) frameScratchBufferLen(maxFrameSize int) int {
+ const max = 512 << 10
+ n := int64(maxFrameSize)
+ if n > max {
+ n = max
+ }
+ if cl := actualContentLength(cs.req); cl != -1 && cl+1 < n {
+ // Add an extra byte past the declared content-length to
+ // give the caller's Request.Body io.Reader a chance to
+ // give us more bytes than they declared, so we can catch it
+ // early.
+ n = cl + 1
+ }
+ if n < 1 {
+ return 1
+ }
+ return int(n) // doesn't truncate; max is 512K
+}
+
+var bufPool sync.Pool // of *[]byte
+
func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc
sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
- buf := cc.frameScratchBuffer()
- defer cc.putFrameScratchBuffer(buf)
defer func() {
traceWroteRequest(cs.trace, err)
@@ -1303,9 +1301,24 @@
remainLen := actualContentLength(req)
hasContentLen := remainLen != -1
+ cc.mu.Lock()
+ maxFrameSize := int(cc.maxFrameSize)
+ cc.mu.Unlock()
+
+ // Scratch buffer for reading into & writing from.
+ scratchLen := cs.frameScratchBufferLen(maxFrameSize)
+ var buf []byte
+ if bp, ok := bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
+ defer bufPool.Put(bp)
+ buf = *bp
+ } else {
+ buf = make([]byte, scratchLen)
+ defer bufPool.Put(&buf)
+ }
+
var sawEOF bool
for !sawEOF {
- n, err := body.Read(buf[:len(buf)-1])
+ n, err := body.Read(buf[:len(buf)])
if hasContentLen {
remainLen -= int64(n)
if remainLen == 0 && err == nil {
@@ -1316,8 +1329,9 @@
// to send the END_STREAM bit early, double-check that we're actually
// at EOF. Subsequent reads should return (0, EOF) at this point.
// If either value is different, we return an error in one of two ways below.
+ var scratch [1]byte
var n1 int
- n1, err = body.Read(buf[n:])
+ n1, err = body.Read(scratch[:])
remainLen -= int64(n1)
}
if remainLen < 0 {
@@ -1387,10 +1401,6 @@
}
}
- cc.mu.Lock()
- maxFrameSize := int(cc.maxFrameSize)
- cc.mu.Unlock()
-
cc.wmu.Lock()
defer cc.wmu.Unlock()
@@ -1506,19 +1516,21 @@
var didUA bool
for k, vv := range req.Header {
- if strings.EqualFold(k, "host") || strings.EqualFold(k, "content-length") {
+ if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
- } else if strings.EqualFold(k, "connection") || strings.EqualFold(k, "proxy-connection") ||
- strings.EqualFold(k, "transfer-encoding") || strings.EqualFold(k, "upgrade") ||
- strings.EqualFold(k, "keep-alive") {
+ } else if asciiEqualFold(k, "connection") ||
+ asciiEqualFold(k, "proxy-connection") ||
+ asciiEqualFold(k, "transfer-encoding") ||
+ asciiEqualFold(k, "upgrade") ||
+ asciiEqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
- } else if strings.EqualFold(k, "user-agent") {
+ } else if asciiEqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
@@ -1531,7 +1543,7 @@
if vv[0] == "" {
continue
}
- } else if strings.EqualFold(k, "cookie") {
+ } else if asciiEqualFold(k, "cookie") {
// Per 8.1.2.5 To allow for better compression efficiency, the
// Cookie header field MAY be split into separate header fields,
// each with one or more cookie-pairs.
@@ -1590,7 +1602,12 @@
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
- name = strings.ToLower(name)
+ name, ascii := asciiToLower(name)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ return
+ }
cc.writeHeader(name, value)
if traceHeaders {
traceWroteHeaderField(trace, name, value)
@@ -1638,9 +1655,14 @@
}
for k, vv := range req.Trailer {
+ lowKey, ascii := asciiToLower(k)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ continue
+ }
// Transfer-Encoding, etc.. have already been filtered at the
// start of RoundTrip
- lowKey := strings.ToLower(k)
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
@@ -2006,8 +2028,8 @@
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
- if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
- res.ContentLength = clen64
+ if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil {
+ res.ContentLength = int64(cl)
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
@@ -2525,6 +2547,7 @@
type erringRoundTripper struct{ err error }
+func (rt erringRoundTripper) RoundTripErr() error { return rt.err }
func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err }
// gzipReader wraps a response body so it can lazily
@@ -2606,7 +2629,9 @@
func (s bodyWriterState) cancel() {
if s.timer != nil {
- s.timer.Stop()
+ if s.timer.Stop() {
+ s.resc <- nil
+ }
}
}