Commit c40a73d8 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make hidden http2 Transport respect remaining Transport fields

Updates x/net/http2 to git rev 72aa00c6 for https://golang.org/cl/18721
(but actually at https://golang.org/cl/18722 now)

Fixes #14008

Change-Id: If05d5ad51ec0ba5ba7e4fe16605c0a83f0484bc8
Reviewed-on: https://go-review.googlesource.com/18723
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 3208d92b
...@@ -47,7 +47,7 @@ const ( ...@@ -47,7 +47,7 @@ const (
h2Mode = true h2Mode = true
) )
func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest { func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
cst := &clientServerTest{ cst := &clientServerTest{
t: t, t: t,
h2: h2, h2: h2,
...@@ -55,6 +55,16 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest { ...@@ -55,6 +55,16 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler) *clientServerTest {
tr: &Transport{}, tr: &Transport{},
} }
cst.c = &Client{Transport: cst.tr} cst.c = &Client{Transport: cst.tr}
for _, opt := range opts {
switch opt := opt.(type) {
case func(*Transport):
opt(cst.tr)
default:
t.Fatalf("unhandled option type %T", opt)
}
}
if !h2 { if !h2 {
cst.ts = httptest.NewServer(h) cst.ts = httptest.NewServer(h)
return cst return cst
...@@ -139,6 +149,7 @@ type h12Compare struct { ...@@ -139,6 +149,7 @@ type h12Compare struct {
Handler func(ResponseWriter, *Request) // required Handler func(ResponseWriter, *Request) // required
ReqFunc reqFunc // optional ReqFunc reqFunc // optional
CheckResponse func(proto string, res *Response) // optional CheckResponse func(proto string, res *Response) // optional
Opts []interface{}
} }
func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) reqFunc() reqFunc {
...@@ -149,9 +160,9 @@ func (tt h12Compare) reqFunc() reqFunc { ...@@ -149,9 +160,9 @@ func (tt h12Compare) reqFunc() reqFunc {
} }
func (tt h12Compare) run(t *testing.T) { func (tt h12Compare) run(t *testing.T) {
cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler)) cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
defer cst1.close() defer cst1.close()
cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler)) cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
defer cst2.close() defer cst2.close()
res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
...@@ -380,6 +391,20 @@ func TestH12_AutoGzip(t *testing.T) { ...@@ -380,6 +391,20 @@ func TestH12_AutoGzip(t *testing.T) {
}.run(t) }.run(t)
} }
func TestH12_AutoGzip_Disabled(t *testing.T) {
h12Compare{
Opts: []interface{}{
func(tr *Transport) { tr.DisableCompression = true },
},
Handler: func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
if ae := r.Header.Get("Accept-Encoding"); ae != "" {
t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
}
},
}.run(t)
}
// Test304Responses verifies that 304s don't declare that they're // Test304Responses verifies that 304s don't declare that they're
// chunking in their response headers and aren't allowed to produce // chunking in their response headers and aren't allowed to produce
// output. // output.
......
...@@ -24,7 +24,6 @@ import ( ...@@ -24,7 +24,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/http2/hpack"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -38,6 +37,8 @@ import ( ...@@ -38,6 +37,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"golang.org/x/net/http2/hpack"
) )
// ClientConnPool manages a pool of HTTP/2 client connections. // ClientConnPool manages a pool of HTTP/2 client connections.
...@@ -248,7 +249,11 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ ...@@ -248,7 +249,11 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [
func http2configureTransport(t1 *Transport) (*http2Transport, error) { func http2configureTransport(t1 *Transport) (*http2Transport, error) {
connPool := new(http2clientConnPool) connPool := new(http2clientConnPool)
t2 := &http2Transport{ConnPool: http2noDialClientConnPool{connPool}} t2 := &http2Transport{
ConnPool: http2noDialClientConnPool{connPool},
t1: t1,
}
connPool.t = t2
if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil { if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
return nil, err return nil, err
} }
...@@ -2184,6 +2189,19 @@ func http2bodyAllowedForStatus(status int) bool { ...@@ -2184,6 +2189,19 @@ func http2bodyAllowedForStatus(status int) bool {
return true return true
} }
type http2httpError struct {
msg string
timeout bool
}
func (e *http2httpError) Error() string { return e.msg }
func (e *http2httpError) Timeout() bool { return e.timeout }
func (e *http2httpError) Temporary() bool { return true }
var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the // io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered) // underlying buffer is an interface. (io.Pipe is always unbuffered)
...@@ -4320,6 +4338,11 @@ type http2Transport struct { ...@@ -4320,6 +4338,11 @@ type http2Transport struct {
// to mean no limit. // to mean no limit.
MaxHeaderListSize uint32 MaxHeaderListSize uint32
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
t1 *Transport
connPoolOnce sync.Once connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
} }
...@@ -4335,11 +4358,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 { ...@@ -4335,11 +4358,7 @@ func (t *http2Transport) maxHeaderListSize() uint32 {
} }
func (t *http2Transport) disableCompression() bool { func (t *http2Transport) disableCompression() bool {
if t.DisableCompression { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
return true
}
return false
} }
var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6") var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
...@@ -4395,7 +4414,7 @@ type http2ClientConn struct { ...@@ -4395,7 +4414,7 @@ type http2ClientConn struct {
henc *hpack.Encoder henc *hpack.Encoder
freeBuf [][]byte freeBuf [][]byte
wmu sync.Mutex // held while writing; acquire AFTER wmu if holding both wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
werr error // first write error that has occurred werr error // first write error that has occurred
} }
...@@ -4413,7 +4432,7 @@ type http2clientStream struct { ...@@ -4413,7 +4432,7 @@ type http2clientStream struct {
inflow http2flow // guarded by cc.mu inflow http2flow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read
stopReqBody bool // stop writing req body; guarded by cc.mu stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
peerReset chan struct{} // closed on peer reset peerReset chan struct{} // closed on peer reset
resetErr error // populated before peerReset is closed resetErr error // populated before peerReset is closed
...@@ -4456,10 +4475,13 @@ func (cs *http2clientStream) checkReset() error { ...@@ -4456,10 +4475,13 @@ func (cs *http2clientStream) checkReset() error {
} }
} }
func (cs *http2clientStream) abortRequestBodyWrite() { func (cs *http2clientStream) abortRequestBodyWrite(err error) {
if err == nil {
panic("nil error")
}
cc := cs.cc cc := cs.cc
cc.mu.Lock() cc.mu.Lock()
cs.stopReqBody = true cs.stopReqBody = err
cc.cond.Broadcast() cc.cond.Broadcast()
cc.mu.Unlock() cc.mu.Unlock()
} }
...@@ -4598,6 +4620,12 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) ( ...@@ -4598,6 +4620,12 @@ func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (
return cn, nil return cn, nil
} }
// disableKeepAlives reports whether connections should be closed as
// soon as possible after handling the first request.
func (t *http2Transport) disableKeepAlives() bool {
return t.t1 != nil && t.t1.DisableKeepAlives
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
if http2VerboseLogs { if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
...@@ -4692,7 +4720,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool { ...@@ -4692,7 +4720,7 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
} }
func (cc *http2ClientConn) canTakeNewRequestLocked() bool { func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
return cc.goAway == nil && return cc.goAway == nil && !cc.closed &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647 cc.nextStreamID < 2147483647
} }
...@@ -4772,6 +4800,14 @@ func http2commaSeparatedTrailers(req *Request) (string, error) { ...@@ -4772,6 +4800,14 @@ func http2commaSeparatedTrailers(req *Request) (string, error) {
return "", nil return "", nil
} }
func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
}
return 0
}
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
trailers, err := http2commaSeparatedTrailers(req) trailers, err := http2commaSeparatedTrailers(req)
if err != nil { if err != nil {
...@@ -4832,24 +4868,32 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4832,24 +4868,32 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
return nil, werr return nil, werr
} }
var respHeaderTimer <-chan time.Time
var bodyCopyErrc chan error // result of body copy var bodyCopyErrc chan error // result of body copy
if hasBody { if hasBody {
bodyCopyErrc = make(chan error, 1) bodyCopyErrc = make(chan error, 1)
go func() { go func() {
bodyCopyErrc <- cs.writeRequestBody(body, req.Body) bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
}() }()
} else {
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
} }
readLoopResCh := cs.resc readLoopResCh := cs.resc
requestCanceledCh := http2requestCancel(req) requestCanceledCh := http2requestCancel(req)
requestCanceled := false bodyWritten := false
for { for {
select { select {
case re := <-readLoopResCh: case re := <-readLoopResCh:
res := re.res res := re.res
if re.err != nil || res.StatusCode > 299 { if re.err != nil || res.StatusCode > 299 {
cs.abortRequestBodyWrite() cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
} }
if re.err != nil { if re.err != nil {
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
...@@ -4858,32 +4902,35 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { ...@@ -4858,32 +4902,35 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
res.Request = req res.Request = req
res.TLS = cc.tlsState res.TLS = cc.tlsState
return res, nil return res, nil
case <-respHeaderTimer:
cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
} else {
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
}
return nil, http2errTimeout
case <-requestCanceledCh: case <-requestCanceledCh:
cc.forgetStreamID(cs.ID) cc.forgetStreamID(cs.ID)
cs.abortRequestBodyWrite() if !hasBody || bodyWritten {
if !hasBody {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return nil, http2errRequestCanceled } else {
cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
} }
requestCanceled = true
requestCanceledCh = nil
readLoopResCh = nil
case <-cs.peerReset:
if requestCanceled {
return nil, http2errRequestCanceled return nil, http2errRequestCanceled
} case <-cs.peerReset:
return nil, cs.resetErr return nil, cs.resetErr
case err := <-bodyCopyErrc: case err := <-bodyCopyErrc:
if requestCanceled {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return nil, http2errRequestCanceled
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
bodyWritten = true
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
}
} }
} }
} }
...@@ -4916,9 +4963,14 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs [] ...@@ -4916,9 +4963,14 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []
return cc.werr return cc.werr
} }
// errAbortReqBodyWrite is an internal error value. // internal error values; they don't escape to callers
// It doesn't escape to callers. var (
var http2errAbortReqBodyWrite = errors.New("http2: aborting request body write") // abort request body write; don't send cancel
http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
// abort request body write, but send stream reset of cancel.
http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
)
func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
cc := cs.cc cc := cs.cc
...@@ -4951,7 +5003,13 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos ...@@ -4951,7 +5003,13 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
for len(remain) > 0 && err == nil { for len(remain) > 0 && err == nil {
var allowed int32 var allowed int32
allowed, err = cs.awaitFlowControl(len(remain)) allowed, err = cs.awaitFlowControl(len(remain))
if err != nil { switch {
case err == http2errStopReqBodyWrite:
return err
case err == http2errStopReqBodyWriteAndCancel:
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
return err
case err != nil:
return err return err
} }
cc.wmu.Lock() cc.wmu.Lock()
...@@ -5005,8 +5063,8 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er ...@@ -5005,8 +5063,8 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er
if cc.closed { if cc.closed {
return 0, http2errClientConnClosed return 0, http2errClientConnClosed
} }
if cs.stopReqBody { if cs.stopReqBody != nil {
return 0, http2errAbortReqBodyWrite return 0, cs.stopReqBody
} }
if err := cs.checkReset(); err != nil { if err := cs.checkReset(); err != nil {
return 0, err return 0, err
...@@ -5074,7 +5132,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail ...@@ -5074,7 +5132,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
cc.writeHeader(lowKey, v) cc.writeHeader(lowKey, v)
} }
} }
if contentLength >= 0 { if http2shouldSendReqContentLength(req.Method, contentLength) {
cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
} }
if addGzipHeader { if addGzipHeader {
...@@ -5086,6 +5144,27 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail ...@@ -5086,6 +5144,27 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
return cc.hbuf.Bytes() return cc.hbuf.Bytes()
} }
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func http2shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// requires cc.mu be held. // requires cc.mu be held.
func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { func (cc *http2ClientConn) encodeTrailers(req *Request) []byte {
cc.hbuf.Reset() cc.hbuf.Reset()
...@@ -5204,6 +5283,8 @@ func (rl *http2clientConnReadLoop) cleanup() { ...@@ -5204,6 +5283,8 @@ func (rl *http2clientConnReadLoop) cleanup() {
func (rl *http2clientConnReadLoop) run() error { func (rl *http2clientConnReadLoop) run() error {
cc := rl.cc cc := rl.cc
closeWhenIdle := cc.t.disableKeepAlives()
gotReply := false
for { for {
f, err := cc.fr.ReadFrame() f, err := cc.fr.ReadFrame()
if err != nil { if err != nil {
...@@ -5218,18 +5299,25 @@ func (rl *http2clientConnReadLoop) run() error { ...@@ -5218,18 +5299,25 @@ func (rl *http2clientConnReadLoop) run() error {
if http2VerboseLogs { if http2VerboseLogs {
cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
} }
maybeIdle := false
switch f := f.(type) { switch f := f.(type) {
case *http2HeadersFrame: case *http2HeadersFrame:
err = rl.processHeaders(f) err = rl.processHeaders(f)
maybeIdle = true
gotReply = true
case *http2ContinuationFrame: case *http2ContinuationFrame:
err = rl.processContinuation(f) err = rl.processContinuation(f)
maybeIdle = true
case *http2DataFrame: case *http2DataFrame:
err = rl.processData(f) err = rl.processData(f)
maybeIdle = true
case *http2GoAwayFrame: case *http2GoAwayFrame:
err = rl.processGoAway(f) err = rl.processGoAway(f)
maybeIdle = true
case *http2RSTStreamFrame: case *http2RSTStreamFrame:
err = rl.processResetStream(f) err = rl.processResetStream(f)
maybeIdle = true
case *http2SettingsFrame: case *http2SettingsFrame:
err = rl.processSettings(f) err = rl.processSettings(f)
case *http2PushPromiseFrame: case *http2PushPromiseFrame:
...@@ -5244,6 +5332,9 @@ func (rl *http2clientConnReadLoop) run() error { ...@@ -5244,6 +5332,9 @@ func (rl *http2clientConnReadLoop) run() error {
if err != nil { if err != nil {
return err return err
} }
if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
cc.closeIfIdle()
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment