Commit 127d2bf7 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix Transport races & deadlocks

Thanks to Dustin Sallings for exposing the most frustrating
bug ever, and for providing repro cases (which formed the
basis of the new tests in this CL), and to Dave Cheney and
Dmitry Vyukov for help debugging and fixing.

This CL depends on submited pollster CLs ffd1e075c260 (Unix)
and 14b544194509 (Windows), as well as unsubmitted 6852085.
Some operating systems (OpenBSD, NetBSD, ?) may still require
more pollster work, fixing races (Issue 4434 and
http://goo.gl/JXB6W).

Tested on linux-amd64 and darwin-amd64, both with GOMAXPROCS 1
and 4 (all combinations of which previously failed differently)

Fixes #4191
Update #4434 (related fallout from this bug)

R=dave, bradfitz, dsallings, rsc, fullung
CC=golang-dev
https://golang.org/cl/6851061
parent 5188c0b5
...@@ -7,7 +7,14 @@ ...@@ -7,7 +7,14 @@
package http package http
import "time" import (
"net"
"time"
)
func NewLoggingConn(baseName string, c net.Conn) net.Conn {
return newLoggingConn(baseName, c)
}
func (t *Transport) IdleConnKeysForTesting() (keys []string) { func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0) keys = make([]string, 0)
......
...@@ -170,16 +170,23 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { ...@@ -170,16 +170,23 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
// noLimit is an effective infinite upper bound for io.LimitedReader // noLimit is an effective infinite upper bound for io.LimitedReader
const noLimit int64 = (1 << 63) - 1 const noLimit int64 = (1 << 63) - 1
// debugServerConnections controls whether all server connections are wrapped
// with a verbose logging wrapper.
const debugServerConnections = false
// Create new connection from rwc. // Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
c = new(conn) c = new(conn)
c.remoteAddr = rwc.RemoteAddr().String() c.remoteAddr = rwc.RemoteAddr().String()
c.server = srv c.server = srv
c.rwc = rwc c.rwc = rwc
if debugServerConnections {
c.rwc = newLoggingConn("server", c.rwc)
}
c.body = make([]byte, sniffLen) c.body = make([]byte, sniffLen)
c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader) c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr) br := bufio.NewReader(c.lr)
bw := bufio.NewWriter(rwc) bw := bufio.NewWriter(c.rwc)
c.buf = bufio.NewReadWriter(br, bw) c.buf = bufio.NewReadWriter(br, bw)
return c, nil return c, nil
} }
...@@ -495,7 +502,7 @@ func (w *response) Write(data []byte) (n int, err error) { ...@@ -495,7 +502,7 @@ func (w *response) Write(data []byte) (n int, err error) {
// then there would be fewer chunk headers. // then there would be fewer chunk headers.
// On the other hand, it would make hijacking more difficult. // On the other hand, it would make hijacking more difficult.
if w.chunking { if w.chunking {
fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
} }
n, err = w.conn.buf.Write(data) n, err = w.conn.buf.Write(data)
if err == nil && w.chunking { if err == nil && w.chunking {
...@@ -1309,3 +1316,45 @@ func (tw *timeoutWriter) WriteHeader(code int) { ...@@ -1309,3 +1316,45 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Unlock() tw.mu.Unlock()
tw.w.WriteHeader(code) tw.w.WriteHeader(code)
} }
// loggingConn is used for debugging.
type loggingConn struct {
name string
net.Conn
}
var (
uniqNameMu sync.Mutex
uniqNameNext = make(map[string]int)
)
func newLoggingConn(baseName string, c net.Conn) net.Conn {
uniqNameMu.Lock()
defer uniqNameMu.Unlock()
uniqNameNext[baseName]++
return &loggingConn{
name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
Conn: c,
}
}
func (c *loggingConn) Write(p []byte) (n int, err error) {
log.Printf("%s.Write(%d) = ....", c.name, len(p))
n, err = c.Conn.Write(p)
log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Read(p []byte) (n int, err error) {
log.Printf("%s.Read(%d) = ....", c.name, len(p))
n, err = c.Conn.Read(p)
log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Close() (err error) {
log.Printf("%s.Close() = ...", c.name)
err = c.Conn.Close()
log.Printf("%s.Close() = %v", c.name, err)
return
}
...@@ -24,7 +24,6 @@ import ( ...@@ -24,7 +24,6 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
...@@ -613,14 +612,18 @@ func (pc *persistConn) readLoop() { ...@@ -613,14 +612,18 @@ func (pc *persistConn) readLoop() {
if hasBody { if hasBody {
lastbody = resp.Body lastbody = resp.Body
waitForBodyRead = make(chan bool, 1) waitForBodyRead = make(chan bool, 1)
resp.Body.(*bodyEOFSignal).fn = func() { resp.Body.(*bodyEOFSignal).fn = func(err error) {
if alive && !pc.t.putIdleConn(pc) { alive1 := alive
alive = false if err != nil {
alive1 = false
}
if alive1 && !pc.t.putIdleConn(pc) {
alive1 = false
} }
if !alive || pc.isBroken() { if !alive1 || pc.isBroken() {
pc.close() pc.close()
} }
waitForBodyRead <- true waitForBodyRead <- alive1
} }
} }
...@@ -644,7 +647,7 @@ func (pc *persistConn) readLoop() { ...@@ -644,7 +647,7 @@ func (pc *persistConn) readLoop() {
// Wait for the just-returned response body to be fully consumed // Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader. // before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil { if waitForBodyRead != nil {
<-waitForBodyRead alive = <-waitForBodyRead
} }
if !alive { if !alive {
...@@ -810,50 +813,61 @@ func canonicalAddr(url *url.URL) string { ...@@ -810,50 +813,61 @@ func canonicalAddr(url *url.URL) string {
} }
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before the final Read() or Close() call returns, but after // once, right before its final (error-producing) Read or Close call
// EOF has been seen. // returns.
type bodyEOFSignal struct { type bodyEOFSignal struct {
body io.ReadCloser body io.ReadCloser
fn func() mu sync.Mutex // guards closed, rerr and fn
isClosed uint32 // atomic bool, non-zero if true closed bool // whether Close has been called
once sync.Once rerr error // sticky Read error
fn func(error) // error will be nil on Read io.EOF
} }
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
n, err = es.body.Read(p) es.mu.Lock()
if es.closed() && n > 0 { closed, rerr := es.closed, es.rerr
panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") es.mu.Unlock()
if closed {
return 0, errors.New("http: read on closed response body")
} }
if err == io.EOF { if rerr != nil {
es.condfn() return 0, rerr
} }
return
}
func (es *bodyEOFSignal) Close() (err error) { n, err = es.body.Read(p)
if !es.setClosed() { if err != nil {
// already closed es.mu.Lock()
return nil defer es.mu.Unlock()
if es.rerr == nil {
es.rerr = err
} }
err = es.body.Close() es.condfn(err)
if err == nil {
es.condfn()
} }
return return
} }
func (es *bodyEOFSignal) condfn() { func (es *bodyEOFSignal) Close() error {
if es.fn != nil { es.mu.Lock()
es.once.Do(es.fn) defer es.mu.Unlock()
if es.closed {
return nil
} }
es.closed = true
err := es.body.Close()
es.condfn(err)
return err
} }
func (es *bodyEOFSignal) closed() bool { // caller must hold es.mu.
return atomic.LoadUint32(&es.isClosed) != 0 func (es *bodyEOFSignal) condfn(err error) {
} if es.fn == nil {
return
func (es *bodyEOFSignal) setClosed() bool { }
return atomic.CompareAndSwapUint32(&es.isClosed, 0, 1) if err == io.EOF {
err = nil
}
es.fn(err)
es.fn = nil
} }
type readFirstCloseBoth struct { type readFirstCloseBoth struct {
......
...@@ -901,6 +901,111 @@ func TestTransportConcurrency(t *testing.T) { ...@@ -901,6 +901,111 @@ func TestTransportConcurrency(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
ts := httptest.NewServer(mux)
client := &Client{
Transport: &Transport{
Dial: func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
},
DisableKeepAlives: true,
},
}
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := client.Get(ts.URL + "/get")
if err != nil {
t.Errorf("Error issuing GET: %v", err)
break
}
_, err = io.Copy(ioutil.Discard, sres.Body)
if err == nil {
t.Errorf("Unexpected successful copy")
break
}
}
if debug {
println("tests complete; waiting for handlers to finish")
}
ts.Close()
}
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
defer r.Body.Close()
io.Copy(ioutil.Discard, r.Body)
})
ts := httptest.NewServer(mux)
client := &Client{
Transport: &Transport{
Dial: func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
},
DisableKeepAlives: true,
},
}
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := client.Get(ts.URL + "/get")
if err != nil {
t.Errorf("Error issuing GET: %v", err)
break
}
req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
_, err = client.Do(req)
if err == nil {
t.Errorf("Unexpected successful PUT")
break
}
}
if debug {
println("tests complete; waiting for handlers to finish")
}
ts.Close()
}
type fooProto struct{} type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) { func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment