Commit 723f8653 authored by Daniel Morsing's avatar Daniel Morsing

net/http: fix race between dialing and canceling

In the brief window between getConn and persistConn.roundTrip,
a cancel could end up going missing.

Fix by making it possible to inspect if a cancel function was cleared
and checking if we were canceled before entering roundTrip.

Fixes #10511

Change-Id: If6513e63fbc2edb703e36d6356ccc95a1dc33144
Reviewed-on: https://go-review.googlesource.com/9181Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 5fa2d991
......@@ -82,6 +82,10 @@ func SetInstallConnClosedHook(f func()) {
testHookPersistConnClosedGotRes = f
}
func SetEnterRoundTripHook(f func()) {
testHookEnterRoundTrip = f
}
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
f := func() <-chan time.Time {
return ch
......
......@@ -475,6 +475,25 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
}
}
// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r]
if !ok {
return false
}
if fn != nil {
t.reqCanceler[r] = fn
} else {
delete(t.reqCanceler, r)
}
return true
}
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
if t.Dial != nil {
return t.Dial(network, addr)
......@@ -491,6 +510,10 @@ var prePendingDial, postPendingDial func()
// is ready to write requests to.
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
if pc := t.getIdleConn(cm); pc != nil {
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
t.setReqCanceler(req, func() {})
return pc, nil
}
......@@ -1063,10 +1086,20 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
var testHookPersistConnClosedGotRes func() // nil except for tests
// nil except for tests
var (
testHookPersistConnClosedGotRes func()
testHookEnterRoundTrip func()
)
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
pc.t.setReqCanceler(req.Request, pc.cancelRequest)
if hook := testHookEnterRoundTrip; hook != nil {
hook()
}
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
pc.t.putIdleConn(pc)
return nil, errRequestCanceled
}
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
......
......@@ -2400,6 +2400,32 @@ func TestTransportResponseCancelRace(t *testing.T) {
res.Body.Close()
}
func TestTransportDialCancelRace(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
SetEnterRoundTripHook(func() {
tr.CancelRequest(req)
})
defer SetEnterRoundTripHook(nil)
res, err := tr.RoundTrip(req)
if err != ExportErrRequestCanceled {
t.Errorf("expected canceled request error; got %v", err)
if err == nil {
res.Body.Close()
}
}
}
func wantBody(res *http.Response, err error, want string) error {
if err != nil {
return err
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment