Commit 49d79d7e authored by Matt Holt's avatar Matt Holt Committed by GitHub

Merge pull request #1598 from tw4452852/1589

proxy: recognize client's cancellation
parents 4c034f6a 0146bb4e
...@@ -9,8 +9,6 @@ import ( ...@@ -9,8 +9,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"errors"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -155,7 +153,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { ...@@ -155,7 +153,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
return pusher.Push(target, opts) return pusher.Push(target, opts)
} }
return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") return httpserver.NonFlusherError{Underlying: w.ResponseWriter}
} }
// Interface guards // Interface guards
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"errors"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
...@@ -141,7 +140,7 @@ func (rww *responseWriterWrapper) Push(target string, opts *http.PushOptions) er ...@@ -141,7 +140,7 @@ func (rww *responseWriterWrapper) Push(target string, opts *http.PushOptions) er
return pusher.Push(target, opts) return pusher.Push(target, opts)
} }
return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") return httpserver.NonPusherError{Underlying: rww.ResponseWriter}
} }
// Interface guards // Interface guards
......
...@@ -8,6 +8,7 @@ var ( ...@@ -8,6 +8,7 @@ var (
_ error = NonHijackerError{} _ error = NonHijackerError{}
_ error = NonFlusherError{} _ error = NonFlusherError{}
_ error = NonCloseNotifierError{} _ error = NonCloseNotifierError{}
_ error = NonPusherError{}
) )
// NonHijackerError is more descriptive error caused by a non hijacker // NonHijackerError is more descriptive error caused by a non hijacker
...@@ -42,3 +43,14 @@ type NonCloseNotifierError struct { ...@@ -42,3 +43,14 @@ type NonCloseNotifierError struct {
func (c NonCloseNotifierError) Error() string { func (c NonCloseNotifierError) Error() string {
return fmt.Sprintf("%T is not a closeNotifier", c.Underlying) return fmt.Sprintf("%T is not a closeNotifier", c.Underlying)
} }
// NonPusherError is more descriptive error caused by a non pusher
type NonPusherError struct {
// underlying type which doesn't implement pusher
Underlying interface{}
}
// Implement Error
func (c NonPusherError) Error() string {
return fmt.Sprintf("%T is not a pusher", c.Underlying)
}
...@@ -2,7 +2,6 @@ package httpserver ...@@ -2,7 +2,6 @@ package httpserver
import ( import (
"bufio" "bufio"
"errors"
"net" "net"
"net/http" "net/http"
"time" "time"
...@@ -103,7 +102,7 @@ func (r *ResponseRecorder) Push(target string, opts *http.PushOptions) error { ...@@ -103,7 +102,7 @@ func (r *ResponseRecorder) Push(target string, opts *http.PushOptions) error {
return pusher.Push(target, opts) return pusher.Push(target, opts)
} }
return errors.New("push is unavailable (probably chained http.ResponseWriter does not implement http.Pusher)") return NonPusherError{Underlying: r.ResponseWriter}
} }
// Interface guards // Interface guards
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package proxy package proxy
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/http" "net/http"
...@@ -103,7 +104,8 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -103,7 +104,8 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
replacer := httpserver.NewReplacer(r, nil, "") replacer := httpserver.NewReplacer(r, nil, "")
// outreq is the request that makes a roundtrip to the backend // outreq is the request that makes a roundtrip to the backend
outreq := createUpstreamRequest(r) outreq, cancel := createUpstreamRequest(w, r)
defer cancel()
// If we have more than one upstream host defined and if retrying is enabled // If we have more than one upstream host defined and if retrying is enabled
// by setting try_duration to a non-zero value, caddy will try to // by setting try_duration to a non-zero value, caddy will try to
...@@ -131,7 +133,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -131,7 +133,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// loop and try to select another host, or false if we // loop and try to select another host, or false if we
// should break and stop retrying. // should break and stop retrying.
start := time.Now() start := time.Now()
keepRetrying := func() bool { keepRetrying := func(backendErr error) bool {
// if downstream has canceled the request, break
if backendErr == context.Canceled {
return false
}
// if we've tried long enough, break // if we've tried long enough, break
if time.Since(start) >= upstream.GetTryDuration() { if time.Since(start) >= upstream.GetTryDuration() {
return false return false
...@@ -150,7 +156,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -150,7 +156,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if backendErr == nil { if backendErr == nil {
backendErr = errors.New("no hosts available upstream") backendErr = errors.New("no hosts available upstream")
} }
if !keepRetrying() { if !keepRetrying(backendErr) {
break break
} }
continue continue
...@@ -238,7 +244,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -238,7 +244,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
// if we've tried long enough, break // if we've tried long enough, break
if !keepRetrying() { if !keepRetrying(backendErr) {
break break
} }
} }
...@@ -267,9 +273,23 @@ func (p Proxy) match(r *http.Request) Upstream { ...@@ -267,9 +273,23 @@ func (p Proxy) match(r *http.Request) Upstream {
// that can be sent upstream. // that can be sent upstream.
// //
// Derived from reverseproxy.go in the standard Go httputil package. // Derived from reverseproxy.go in the standard Go httputil package.
func createUpstreamRequest(r *http.Request) *http.Request { func createUpstreamRequest(rw http.ResponseWriter, r *http.Request) (*http.Request, context.CancelFunc) {
outreq := new(http.Request) // Original incoming server request may be canceled by the
*outreq = *r // includes shallow copies of maps, but okay // user or by std lib(e.g. too many idle connections).
ctx, cancel := context.WithCancel(r.Context())
if cn, ok := rw.(http.CloseNotifier); ok {
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := r.WithContext(ctx) // includes shallow copies of maps, but okay
// We should set body to nil explicitly if request body is empty. // We should set body to nil explicitly if request body is empty.
// For server requests the Request Body is always non-nil. // For server requests the Request Body is always non-nil.
if r.ContentLength == 0 { if r.ContentLength == 0 {
...@@ -319,7 +339,7 @@ func createUpstreamRequest(r *http.Request) *http.Request { ...@@ -319,7 +339,7 @@ func createUpstreamRequest(r *http.Request) *http.Request {
outreq.Header.Set("X-Forwarded-For", clientIP) outreq.Header.Set("X-Forwarded-For", clientIP)
} }
return outreq return outreq, cancel
} }
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httptrace"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
...@@ -101,7 +100,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -101,7 +100,7 @@ func TestReverseProxy(t *testing.T) {
// Make sure {upstream} placeholder is set // Make sure {upstream} placeholder is set
r.Body = ioutil.NopCloser(strings.NewReader("test")) r.Body = ioutil.NopCloser(strings.NewReader("test"))
rr := httpserver.NewResponseRecorder(httptest.NewRecorder()) rr := httpserver.NewResponseRecorder(testResponseRecorder{httptest.NewRecorder()})
rr.Replacer = httpserver.NewReplacer(r, rr, "-") rr.Replacer = httpserver.NewReplacer(r, rr, "-")
p.ServeHTTP(rr, r) p.ServeHTTP(rr, r)
...@@ -1123,7 +1122,18 @@ func TestReverseProxyLargeBody(t *testing.T) { ...@@ -1123,7 +1122,18 @@ func TestReverseProxyLargeBody(t *testing.T) {
} }
func TestCancelRequest(t *testing.T) { func TestCancelRequest(t *testing.T) {
reqInFlight := make(chan struct{})
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(reqInFlight) // cause the client to cancel its request
select {
case <-time.After(10 * time.Second):
t.Error("Handler never saw CloseNotify")
return
case <-w.(http.CloseNotifier).CloseNotify():
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
})) }))
defer backend.Close() defer backend.Close()
...@@ -1140,26 +1150,21 @@ func TestCancelRequest(t *testing.T) { ...@@ -1140,26 +1150,21 @@ func TestCancelRequest(t *testing.T) {
defer cancel() defer cancel()
req = req.WithContext(ctx) req = req.WithContext(ctx)
// add GotConn hook to cancel the request
gotC := make(chan struct{})
defer close(gotC)
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
gotC <- struct{}{}
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
// wait for canceling the request // wait for canceling the request
go func() { go func() {
<-gotC <-reqInFlight
cancel() cancel()
}() }()
status, err := p.ServeHTTP(httptest.NewRecorder(), req) rec := httptest.NewRecorder()
if status != 0 || err != nil { status, err := p.ServeHTTP(rec, req)
t.Errorf("expect proxy handle normally, but not, status:%d, err:%q", expectedStatus, expectErr := http.StatusBadGateway, context.Canceled
status, err) if status != expectedStatus || err != expectErr {
t.Errorf("expect proxy handle return status[%d] with error[%v], but got status[%d] with error[%v]",
expectedStatus, expectErr, status, err)
}
if body := rec.Body.String(); body != "" {
t.Errorf("expect a blank response, but got %q", body)
} }
} }
...@@ -1310,6 +1315,28 @@ func (c *fakeConn) Close() error { return nil } ...@@ -1310,6 +1315,28 @@ func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
// testResponseRecorder wraps `httptest.ResponseRecorder`,
// also implements `http.CloseNotifier`, `http.Hijacker` and `http.Pusher`.
type testResponseRecorder struct {
*httptest.ResponseRecorder
}
func (testResponseRecorder) CloseNotify() <-chan bool { return nil }
func (t testResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, httpserver.NonHijackerError{Underlying: t}
}
func (t testResponseRecorder) Push(target string, opts *http.PushOptions) error {
return httpserver.NonPusherError{Underlying: t}
}
// Interface guards
var (
_ http.Pusher = testResponseRecorder{}
_ http.Flusher = testResponseRecorder{}
_ http.CloseNotifier = testResponseRecorder{}
_ http.Hijacker = testResponseRecorder{}
)
func BenchmarkProxy(b *testing.B) { func BenchmarkProxy(b *testing.B) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client")) w.Write([]byte("Hello, client"))
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
package proxy package proxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"io" "io"
"net" "net"
...@@ -252,14 +251,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, ...@@ -252,14 +251,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
rp.Director(outreq) rp.Director(outreq)
// Original incoming server request may be canceled by the
// user or by std lib(e.g. too many idle connections).
// Now we issue the new outgoing client request which
// doesn't depend on the original one. (issue 1345)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
outreq = outreq.WithContext(ctx)
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
return err return err
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment