Commit 127d2bf7 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix Transport races & deadlocks

Thanks to Dustin Sallings for exposing the most frustrating
bug ever, and for providing repro cases (which formed the
basis of the new tests in this CL), and to Dave Cheney and
Dmitry Vyukov for help debugging and fixing.

This CL depends on submited pollster CLs ffd1e075c260 (Unix)
and 14b544194509 (Windows), as well as unsubmitted 6852085.
Some operating systems (OpenBSD, NetBSD, ?) may still require
more pollster work, fixing races (Issue 4434 and
http://goo.gl/JXB6W).

Tested on linux-amd64 and darwin-amd64, both with GOMAXPROCS 1
and 4 (all combinations of which previously failed differently)

Fixes #4191
Update #4434 (related fallout from this bug)

R=dave, bradfitz, dsallings, rsc, fullung
CC=golang-dev
https://golang.org/cl/6851061
parent 5188c0b5
......@@ -7,7 +7,14 @@
package http
import "time"
import (
"net"
"time"
)
func NewLoggingConn(baseName string, c net.Conn) net.Conn {
return newLoggingConn(baseName, c)
}
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
keys = make([]string, 0)
......
......@@ -170,16 +170,23 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
// noLimit is an effective infinite upper bound for io.LimitedReader
const noLimit int64 = (1 << 63) - 1
// debugServerConnections controls whether all server connections are wrapped
// with a verbose logging wrapper.
const debugServerConnections = false
// Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
c = new(conn)
c.remoteAddr = rwc.RemoteAddr().String()
c.server = srv
c.rwc = rwc
if debugServerConnections {
c.rwc = newLoggingConn("server", c.rwc)
}
c.body = make([]byte, sniffLen)
c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr)
bw := bufio.NewWriter(rwc)
bw := bufio.NewWriter(c.rwc)
c.buf = bufio.NewReadWriter(br, bw)
return c, nil
}
......@@ -495,7 +502,7 @@ func (w *response) Write(data []byte) (n int, err error) {
// then there would be fewer chunk headers.
// On the other hand, it would make hijacking more difficult.
if w.chunking {
fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
}
n, err = w.conn.buf.Write(data)
if err == nil && w.chunking {
......@@ -1309,3 +1316,45 @@ func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Unlock()
tw.w.WriteHeader(code)
}
// loggingConn is used for debugging.
type loggingConn struct {
name string
net.Conn
}
var (
uniqNameMu sync.Mutex
uniqNameNext = make(map[string]int)
)
func newLoggingConn(baseName string, c net.Conn) net.Conn {
uniqNameMu.Lock()
defer uniqNameMu.Unlock()
uniqNameNext[baseName]++
return &loggingConn{
name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
Conn: c,
}
}
func (c *loggingConn) Write(p []byte) (n int, err error) {
log.Printf("%s.Write(%d) = ....", c.name, len(p))
n, err = c.Conn.Write(p)
log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Read(p []byte) (n int, err error) {
log.Printf("%s.Read(%d) = ....", c.name, len(p))
n, err = c.Conn.Read(p)
log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Close() (err error) {
log.Printf("%s.Close() = ...", c.name)
err = c.Conn.Close()
log.Printf("%s.Close() = %v", c.name, err)
return
}
......@@ -24,7 +24,6 @@ import (
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
......@@ -613,14 +612,18 @@ func (pc *persistConn) readLoop() {
if hasBody {
lastbody = resp.Body
waitForBodyRead = make(chan bool, 1)
resp.Body.(*bodyEOFSignal).fn = func() {
if alive && !pc.t.putIdleConn(pc) {
alive = false
resp.Body.(*bodyEOFSignal).fn = func(err error) {
alive1 := alive
if err != nil {
alive1 = false
}
if !alive || pc.isBroken() {
if alive1 && !pc.t.putIdleConn(pc) {
alive1 = false
}
if !alive1 || pc.isBroken() {
pc.close()
}
waitForBodyRead <- true
waitForBodyRead <- alive1
}
}
......@@ -644,7 +647,7 @@ func (pc *persistConn) readLoop() {
// Wait for the just-returned response body to be fully consumed
// before we race and peek on the underlying bufio reader.
if waitForBodyRead != nil {
<-waitForBodyRead
alive = <-waitForBodyRead
}
if !alive {
......@@ -810,50 +813,61 @@ func canonicalAddr(url *url.URL) string {
}
// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before the final Read() or Close() call returns, but after
// EOF has been seen.
// once, right before its final (error-producing) Read or Close call
// returns.
type bodyEOFSignal struct {
body io.ReadCloser
fn func()
isClosed uint32 // atomic bool, non-zero if true
once sync.Once
body io.ReadCloser
mu sync.Mutex // guards closed, rerr and fn
closed bool // whether Close has been called
rerr error // sticky Read error
fn func(error) // error will be nil on Read io.EOF
}
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
n, err = es.body.Read(p)
if es.closed() && n > 0 {
panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
es.mu.Lock()
closed, rerr := es.closed, es.rerr
es.mu.Unlock()
if closed {
return 0, errors.New("http: read on closed response body")
}
if err == io.EOF {
es.condfn()
if rerr != nil {
return 0, rerr
}
return
}
func (es *bodyEOFSignal) Close() (err error) {
if !es.setClosed() {
// already closed
return nil
}
err = es.body.Close()
if err == nil {
es.condfn()
n, err = es.body.Read(p)
if err != nil {
es.mu.Lock()
defer es.mu.Unlock()
if es.rerr == nil {
es.rerr = err
}
es.condfn(err)
}
return
}
func (es *bodyEOFSignal) condfn() {
if es.fn != nil {
es.once.Do(es.fn)
func (es *bodyEOFSignal) Close() error {
es.mu.Lock()
defer es.mu.Unlock()
if es.closed {
return nil
}
es.closed = true
err := es.body.Close()
es.condfn(err)
return err
}
func (es *bodyEOFSignal) closed() bool {
return atomic.LoadUint32(&es.isClosed) != 0
}
func (es *bodyEOFSignal) setClosed() bool {
return atomic.CompareAndSwapUint32(&es.isClosed, 0, 1)
// caller must hold es.mu.
func (es *bodyEOFSignal) condfn(err error) {
if es.fn == nil {
return
}
if err == io.EOF {
err = nil
}
es.fn(err)
es.fn = nil
}
type readFirstCloseBoth struct {
......
......@@ -901,6 +901,111 @@ func TestTransportConcurrency(t *testing.T) {
wg.Wait()
}
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
ts := httptest.NewServer(mux)
client := &Client{
Transport: &Transport{
Dial: func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
},
DisableKeepAlives: true,
},
}
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := client.Get(ts.URL + "/get")
if err != nil {
t.Errorf("Error issuing GET: %v", err)
break
}
_, err = io.Copy(ioutil.Discard, sres.Body)
if err == nil {
t.Errorf("Unexpected successful copy")
break
}
}
if debug {
println("tests complete; waiting for handlers to finish")
}
ts.Close()
}
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
const debug = false
mux := NewServeMux()
mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
io.Copy(w, neverEnding('a'))
})
mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
defer r.Body.Close()
io.Copy(ioutil.Discard, r.Body)
})
ts := httptest.NewServer(mux)
client := &Client{
Transport: &Transport{
Dial: func(n, addr string) (net.Conn, error) {
conn, err := net.Dial(n, addr)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
if debug {
conn = NewLoggingConn("client", conn)
}
return conn, nil
},
DisableKeepAlives: true,
},
}
nRuns := 5
if testing.Short() {
nRuns = 1
}
for i := 0; i < nRuns; i++ {
if debug {
println("run", i+1, "of", nRuns)
}
sres, err := client.Get(ts.URL + "/get")
if err != nil {
t.Errorf("Error issuing GET: %v", err)
break
}
req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
_, err = client.Do(req)
if err == nil {
t.Errorf("Unexpected successful PUT")
break
}
}
if debug {
println("tests complete; waiting for handlers to finish")
}
ts.Close()
}
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment