Commit 7448eb41 authored by Emmanuel Odeke's avatar Emmanuel Odeke Committed by Brad Fitzpatrick

net/http: don't wrap request cancellation errors in timeouts

Based on Filippo Valsorda's https://golang.org/cl/24230

Fixes #16094

Change-Id: Ie39b0834e220f0a0f4fbfb3bbb271e70837718c3
Reviewed-on: https://go-review.googlesource.com/32478Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 47bdae94
......@@ -218,7 +218,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
if !deadline.IsZero() {
forkReq()
}
stopTimer, wasCanceled := setRequestCancel(req, rt, deadline)
stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
resp, err := rt.RoundTrip(req)
if err != nil {
......@@ -238,9 +238,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
}
if !deadline.IsZero() {
resp.Body = &cancelTimerBody{
stop: stopTimer,
rc: resp.Body,
reqWasCanceled: wasCanceled,
stop: stopTimer,
rc: resp.Body,
reqDidTimeout: didTimeout,
}
}
return resp, nil
......@@ -249,7 +249,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
// setRequestCancel sets the Cancel field of req, if deadline is
// non-zero. The RoundTripper's type is used to determine whether the legacy
// CancelRequest behavior should be used.
func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) {
func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) {
if deadline.IsZero() {
return nop, alwaysFalse
}
......@@ -259,15 +259,6 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
cancel := make(chan struct{})
req.Cancel = cancel
wasCanceled = func() bool {
select {
case <-cancel:
return true
default:
return false
}
}
doCancel := func() {
// The new way:
close(cancel)
......@@ -292,19 +283,22 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
stopTimer = func() { once.Do(func() { close(stopTimerCh) }) }
timer := time.NewTimer(time.Until(deadline))
var timedOut atomicBool
go func() {
select {
case <-initialReqCancel:
doCancel()
timer.Stop()
case <-timer.C:
timedOut.setTrue()
doCancel()
case <-stopTimerCh:
timer.Stop()
}
}()
return stopTimer, wasCanceled
return stopTimer, timedOut.isSet
}
// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
......@@ -728,12 +722,12 @@ func (c *Client) Head(url string) (resp *Response, err error) {
// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
// 1) on Read error or close, the stop func is called.
// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and
// marked as net.Error that hit its timeout.
type cancelTimerBody struct {
stop func() // stops the time.Timer waiting to cancel the request
rc io.ReadCloser
reqWasCanceled func() bool
stop func() // stops the time.Timer waiting to cancel the request
rc io.ReadCloser
reqDidTimeout func() bool
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
......@@ -745,7 +739,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
if err == io.EOF {
return n, err
}
if b.reqWasCanceled() {
if b.reqDidTimeout() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
......
......@@ -1185,10 +1185,12 @@ func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
func testClientTimeout(t *testing.T, h2 bool) {
if testing.Short() {
t.Skip("skipping in short mode")
}
setParallel(t)
defer afterTest(t)
testDone := make(chan struct{})
const timeout = 100 * time.Millisecond
sawRoot := make(chan bool, 1)
sawSlow := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
......@@ -1201,16 +1203,22 @@ func testClientTimeout(t *testing.T, h2 bool) {
w.Write([]byte("Hello"))
w.(Flusher).Flush()
sawSlow <- true
time.Sleep(2 * time.Second)
select {
case <-testDone:
case <-time.After(timeout * 10):
}
return
}
}))
defer cst.close()
const timeout = 500 * time.Millisecond
defer close(testDone)
cst.c.Timeout = timeout
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
if strings.Contains(err.Error(), "Client.Timeout") {
t.Skip("host too slow to get fast resource in 100ms")
}
t.Fatal(err)
}
......@@ -1260,11 +1268,9 @@ func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h
// Client.Timeout firing before getting to the body
func testClientTimeout_Headers(t *testing.T, h2 bool) {
if testing.Short() {
t.Skip("skipping in short mode")
}
setParallel(t)
defer afterTest(t)
donec := make(chan bool)
donec := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
<-donec
}))
......@@ -1278,9 +1284,10 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
// doesn't know this, so synchronize explicitly.
defer func() { donec <- true }()
cst.c.Timeout = 500 * time.Millisecond
_, err := cst.c.Get(cst.ts.URL)
cst.c.Timeout = 5 * time.Millisecond
res, err := cst.c.Get(cst.ts.URL)
if err == nil {
res.Body.Close()
t.Fatal("got response from Get; expected error")
}
if _, ok := err.(*url.Error); !ok {
......@@ -1298,6 +1305,36 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
}
}
// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be
// returned.
func TestClientTimeoutCancel(t *testing.T) {
setParallel(t)
defer afterTest(t)
testDone := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush()
<-testDone
}))
defer cst.close()
defer close(testDone)
cst.c.Timeout = 1 * time.Hour
req, _ := NewRequest("GET", cst.ts.URL, nil)
req.Cancel = ctx.Done()
res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
cancel()
_, err = io.Copy(ioutil.Discard, res.Body)
if err != ExportErrRequestCanceled {
t.Fatal("error = %v; want errRequestCanceled")
}
}
func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
func testClientRedirectEatsBody(t *testing.T, h2 bool) {
......@@ -1317,10 +1354,10 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) {
t.Fatal(err)
}
_, err = ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatal(err)
}
res.Body.Close()
var first string
select {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment