Commit a5aa91b9 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: make client await response concurrently with writing request

If the server replies with an HTTP response before we're done
writing our body (for instance "401 Unauthorized" response), we
were previously ignoring that, since we returned our write
error ("broken pipe", etc) before ever reading the response.
Now we read and write at the same time.

Fixes #3595

R=rsc, adg
CC=golang-dev
https://golang.org/cl/6238043
parent e1d9fcd2
...@@ -323,6 +323,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { ...@@ -323,6 +323,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
cacheKey: cm.String(), cacheKey: cm.String(),
conn: conn, conn: conn,
reqch: make(chan requestAndChan, 50), reqch: make(chan requestAndChan, 50),
writech: make(chan writeRequest, 50),
} }
switch { switch {
...@@ -380,6 +381,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { ...@@ -380,6 +381,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
pconn.br = bufio.NewReader(pconn.conn) pconn.br = bufio.NewReader(pconn.conn)
pconn.bw = bufio.NewWriter(pconn.conn) pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop() go pconn.readLoop()
go pconn.writeLoop()
return pconn, nil return pconn, nil
} }
...@@ -487,7 +489,8 @@ type persistConn struct { ...@@ -487,7 +489,8 @@ type persistConn struct {
closed bool // whether conn has been closed closed bool // whether conn has been closed
br *bufio.Reader // from conn br *bufio.Reader // from conn
bw *bufio.Writer // to conn bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip(); read by readLoop() reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop
isProxy bool isProxy bool
// mutateHeaderFunc is an optional func to modify extra // mutateHeaderFunc is an optional func to modify extra
...@@ -519,6 +522,7 @@ func remoteSideClosed(err error) bool { ...@@ -519,6 +522,7 @@ func remoteSideClosed(err error) bool {
} }
func (pc *persistConn) readLoop() { func (pc *persistConn) readLoop() {
defer close(pc.writech)
alive := true alive := true
var lastbody io.ReadCloser // last response body, if any, read on this connection var lastbody io.ReadCloser // last response body, if any, read on this connection
...@@ -579,7 +583,7 @@ func (pc *persistConn) readLoop() { ...@@ -579,7 +583,7 @@ func (pc *persistConn) readLoop() {
if alive && !pc.t.putIdleConn(pc) { if alive && !pc.t.putIdleConn(pc) {
alive = false alive = false
} }
if !alive { if !alive || pc.isBroken() {
pc.close() pc.close()
} }
waitForBodyRead <- true waitForBodyRead <- true
...@@ -615,6 +619,23 @@ func (pc *persistConn) readLoop() { ...@@ -615,6 +619,23 @@ func (pc *persistConn) readLoop() {
} }
} }
func (pc *persistConn) writeLoop() {
for wr := range pc.writech {
if pc.isBroken() {
wr.ch <- errors.New("http: can't write HTTP request on broken connection")
continue
}
err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
if err == nil {
err = pc.bw.Flush()
}
if err != nil {
pc.markBroken()
}
wr.ch <- err
}
}
type responseAndError struct { type responseAndError struct {
res *Response res *Response
err error err error
...@@ -630,6 +651,15 @@ type requestAndChan struct { ...@@ -630,6 +651,15 @@ type requestAndChan struct {
addedGzip bool addedGzip bool
} }
// A writeRequest is sent by the readLoop's goroutine to the
// writeLoop's goroutine to write a request while the read loop
// concurrently waits on both the write response and the server's
// reply.
type writeRequest struct {
req *transportRequest
ch chan<- error
}
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
if pc.mutateHeaderFunc != nil { if pc.mutateHeaderFunc != nil {
pc.mutateHeaderFunc(req.extraHeaders()) pc.mutateHeaderFunc(req.extraHeaders())
...@@ -652,16 +682,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err ...@@ -652,16 +682,29 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
pc.numExpectedResponses++ pc.numExpectedResponses++
pc.lk.Unlock() pc.lk.Unlock()
err = req.Request.write(pc.bw, pc.isProxy, req.extra) // Write the request concurrently with waiting for a response,
if err != nil { // in case the server decides to reply before reading our full
pc.close() // request body.
return writeErrCh := make(chan error, 1)
pc.writech <- writeRequest{req, writeErrCh}
resc := make(chan responseAndError, 1)
pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
var re responseAndError
WaitResponse:
for {
select {
case err := <-writeErrCh:
if err != nil {
re = responseAndError{nil, err}
break WaitResponse
}
case re = <-resc:
break WaitResponse
}
} }
pc.bw.Flush()
ch := make(chan responseAndError, 1)
pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
re := <-ch
pc.lk.Lock() pc.lk.Lock()
pc.numExpectedResponses-- pc.numExpectedResponses--
pc.lk.Unlock() pc.lk.Unlock()
...@@ -669,6 +712,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err ...@@ -669,6 +712,15 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
return re.res, re.err return re.res, re.err
} }
// markBroken marks a connection as broken (so it's not reused).
// It differs from close in that it doesn't close the underlying
// connection for use when it's still being read.
func (pc *persistConn) markBroken() {
pc.lk.Lock()
defer pc.lk.Unlock()
pc.broken = true
}
func (pc *persistConn) close() { func (pc *persistConn) close() {
pc.lk.Lock() pc.lk.Lock()
defer pc.lk.Unlock() defer pc.lk.Unlock()
......
...@@ -833,6 +833,30 @@ func TestIssue3644(t *testing.T) { ...@@ -833,6 +833,30 @@ func TestIssue3644(t *testing.T) {
} }
} }
// Test that a client receives a server's reply, even if the server doesn't read
// the entire request body.
func TestIssue3595(t *testing.T) {
const deniedMsg = "sorry, denied."
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
Error(w, deniedMsg, StatusUnauthorized)
}))
defer ts.Close()
tr := &Transport{}
c := &Client{Transport: tr}
res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
if err != nil {
t.Errorf("Post: %v", err)
return
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body ReadAll: %v", err)
}
if !strings.Contains(string(got), deniedMsg) {
t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
}
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment