Commit 0b5f2f0d authored by Ian Lance Taylor's avatar Ian Lance Taylor

net/http: if context is canceled, return its error

This permits the error message to distinguish between a context that was
canceled and a context that timed out.

Updates #16381.

Change-Id: I3994b98e32952abcd7ddb5fee08fa1535999be6d
Reviewed-on: https://go-review.googlesource.com/24978
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 643b9ec0
......@@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) {
if !ok {
t.Fatalf("got error %T; want *url.Error", err)
}
if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn {
t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err)
if ue.Err != context.Canceled {
t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
}
}
......
......@@ -76,7 +76,7 @@ type Transport struct {
idleLRU connLRU
reqMu sync.Mutex
reqCanceler map[*Request]func()
reqCanceler map[*Request]func(error)
altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
......@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(req, errRequestCanceled)
}
// Cancel an in-flight request, recording the error value.
func (t *Transport) cancelRequest(req *Request, err error) {
t.reqMu.Lock()
cancel := t.reqCanceler[req]
delete(t.reqCanceler, req)
t.reqMu.Unlock()
if cancel != nil {
cancel()
cancel(err)
}
}
......@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
}
}
func (t *Transport) setReqCanceler(r *Request, fn func()) {
func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]func())
t.reqCanceler = make(map[*Request]func(error))
}
if fn != nil {
t.reqCanceler[r] = fn
......@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r]
......@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
t.setReqCanceler(req, func() {})
t.setReqCanceler(req, func(error) {})
return pc, nil
}
......@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}()
}
cancelc := make(chan struct{})
t.setReqCanceler(req, func() { close(cancelc) })
cancelc := make(chan error, 1)
t.setReqCanceler(req, func(err error) { cancelc <- err })
go func() {
pc, err := t.dialConn(ctx, cm)
......@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
select {
case <-req.Cancel:
case <-req.Context().Done():
case <-cancelc:
return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
default:
// It wasn't an error due to cancelation, so
// return the original error message:
......@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
return nil, errRequestCanceledConn
case <-req.Context().Done():
handlePendingDial()
return nil, errRequestCanceledConn
case <-cancelc:
return nil, req.Context().Err()
case err := <-cancelc:
handlePendingDial()
return nil, errRequestCanceledConn
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
}
}
......@@ -1231,8 +1244,8 @@ type persistConn struct {
mu sync.Mutex // guards following fields
numExpectedResponses int
closed error // set non-nil when conn is closed, before closech is closed
canceledErr error // set non-nil if conn is canceled
broken bool // an error has happened on this connection; marked broken so it's not reused.
canceled bool // whether this conn was broken due a CancelRequest
reused bool // whether conn has had successful request/response and is being reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
......@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool {
return b
}
// isCanceled reports whether this connection was closed due to CancelRequest.
func (pc *persistConn) isCanceled() bool {
// canceled returns non-nil if the connection was closed due to
// CancelRequest or due to context cancelation.
func (pc *persistConn) canceled() error {
pc.mu.Lock()
defer pc.mu.Unlock()
return pc.canceled
return pc.canceledErr
}
// isReused reports whether this connection is in a known broken state.
......@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
return
}
func (pc *persistConn) cancelRequest() {
func (pc *persistConn) cancelRequest(err error) {
pc.mu.Lock()
defer pc.mu.Unlock()
pc.canceled = true
pc.canceledErr = err
pc.closeLocked(errRequestCanceled)
}
......@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
if err == nil {
return nil
}
if pc.isCanceled() {
return errRequestCanceled
if err := pc.canceled(); err != nil {
return err
}
if err == errServerClosedIdle {
return err
......@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
// its pc.closech channel close, indicating the persistConn is dead.
// (after closech is closed, pc.closed is valid).
func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error {
if pc.isCanceled() {
return errRequestCanceled
if err := pc.canceled(); err != nil {
return err
}
err := pc.closed
if err == errServerClosedIdle {
......@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() {
waitForBodyRead <- isEOF
if isEOF {
<-eofc // see comment above eofc declaration
} else if err != nil && pc.isCanceled() {
return errRequestCanceled
} else if err != nil {
if cerr := pc.canceled(); cerr != nil {
return cerr
}
}
return err
},
......@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
pc.t.CancelRequest(rc.req)
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
case <-pc.closech:
alive = false
}
......@@ -1836,8 +1852,8 @@ WaitResponse:
select {
case err := <-writeErrCh:
if err != nil {
if pc.isCanceled() {
err = errRequestCanceled
if cerr := pc.canceled(); cerr != nil {
err = cerr
}
re = responseAndError{err: err}
pc.close(fmt.Errorf("write error: %v", err))
......@@ -1861,9 +1877,8 @@ WaitResponse:
case <-cancelChan:
pc.t.CancelRequest(req.Request)
cancelChan = nil
ctxDoneChan = nil
case <-ctxDoneChan:
pc.t.CancelRequest(req.Request)
pc.t.cancelRequest(req.Request, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
......
......@@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
}
_, err := c.Do(req)
if err == nil || !strings.Contains(err.Error(), "canceled") {
t.Errorf("Do error = %v; want cancelation", err)
if ue, ok := err.(*url.Error); ok {
err = ue.Err
}
if withCtx {
if err != context.Canceled {
t.Errorf("Do error = %v; want %v", err, context.Canceled)
}
} else {
if err == nil || !strings.Contains(err.Error(), "canceled") {
t.Errorf("Do error = %v; want cancelation", err)
}
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment