Commit 8e7a36de authored by Tw's avatar Tw Committed by Matt Holt

ResponseWriterWrapper and HTTPInterfaces (#1644)

Signed-off-by: default avatarTw <tw19881113@gmail.com>
parent 86d107f6
...@@ -3,9 +3,7 @@ ...@@ -3,9 +3,7 @@
package gzip package gzip
import ( import (
"bufio"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
...@@ -58,7 +56,10 @@ outer: ...@@ -58,7 +56,10 @@ outer:
// original form. // original form.
gzipWriter := getWriter(c.Level) gzipWriter := getWriter(c.Level)
defer putWriter(c.Level, gzipWriter) defer putWriter(c.Level, gzipWriter)
gz := &gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} gz := &gzipResponseWriter{
Writer: gzipWriter,
ResponseWriterWrapper: &httpserver.ResponseWriterWrapper{ResponseWriter: w},
}
var rw http.ResponseWriter var rw http.ResponseWriter
// if no response filter is used // if no response filter is used
...@@ -92,7 +93,7 @@ outer: ...@@ -92,7 +93,7 @@ outer:
// with a gzip.Writer to compress the output. // with a gzip.Writer to compress the output.
type gzipResponseWriter struct { type gzipResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter *httpserver.ResponseWriterWrapper
statusCodeWritten bool statusCodeWritten bool
} }
...@@ -104,7 +105,7 @@ func (w *gzipResponseWriter) WriteHeader(code int) { ...@@ -104,7 +105,7 @@ func (w *gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") w.Header().Del("Content-Length")
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
w.Header().Add("Vary", "Accept-Encoding") w.Header().Add("Vary", "Accept-Encoding")
w.ResponseWriter.WriteHeader(code) w.ResponseWriterWrapper.WriteHeader(code)
w.statusCodeWritten = true w.statusCodeWritten = true
} }
...@@ -120,44 +121,5 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { ...@@ -120,44 +121,5 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return n, err return n, err
} }
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, httpserver.NonHijackerError{Underlying: w.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or panics.
func (w *gzipResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(httpserver.NonFlusherError{Underlying: w.ResponseWriter}) // should be recovered at the beginning of middleware stack
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
func (w *gzipResponseWriter) CloseNotify() <-chan bool {
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(httpserver.NonCloseNotifierError{Underlying: w.ResponseWriter})
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := w.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return httpserver.NonFlusherError{Underlying: w.ResponseWriter}
}
// Interface guards // Interface guards
var _ http.Pusher = (*gzipResponseWriter)(nil) var _ httpserver.HTTPInterfaces = (*gzipResponseWriter)(nil)
var _ http.Flusher = (*gzipResponseWriter)(nil)
var _ http.CloseNotifier = (*gzipResponseWriter)(nil)
var _ http.Hijacker = (*gzipResponseWriter)(nil)
...@@ -33,7 +33,7 @@ func TestLengthFilter(t *testing.T) { ...@@ -33,7 +33,7 @@ func TestLengthFilter(t *testing.T) {
for j, filter := range filters { for j, filter := range filters {
r := httptest.NewRecorder() r := httptest.NewRecorder()
r.Header().Set("Content-Length", fmt.Sprint(ts.length)) r.Header().Set("Content-Length", fmt.Sprint(ts.length))
wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, &gzipResponseWriter{gzip.NewWriter(r), r, false}) wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, &gzipResponseWriter{gzip.NewWriter(r), &httpserver.ResponseWriterWrapper{ResponseWriter: r}, false})
if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] { if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] {
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
} }
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
package header package header
import ( import (
"bufio"
"net"
"net/http" "net/http"
"strings" "strings"
...@@ -23,7 +21,9 @@ type Headers struct { ...@@ -23,7 +21,9 @@ type Headers struct {
// setting headers on the response according to the configured rules. // setting headers on the response according to the configured rules.
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
replacer := httpserver.NewReplacer(r, nil, "") replacer := httpserver.NewReplacer(r, nil, "")
rww := &responseWriterWrapper{ResponseWriter: w} rww := &responseWriterWrapper{
ResponseWriterWrapper: &httpserver.ResponseWriterWrapper{ResponseWriter: w},
}
for _, rule := range h.Rules { for _, rule := range h.Rules {
if httpserver.Path(r.URL.Path).Matches(rule.Path) { if httpserver.Path(r.URL.Path).Matches(rule.Path) {
for name := range rule.Headers { for name := range rule.Headers {
...@@ -62,20 +62,20 @@ type headerOperation func(http.Header) ...@@ -62,20 +62,20 @@ type headerOperation func(http.Header)
// responseWriterWrapper wraps the real ResponseWriter. // responseWriterWrapper wraps the real ResponseWriter.
// It defers header operations until writeHeader // It defers header operations until writeHeader
type responseWriterWrapper struct { type responseWriterWrapper struct {
http.ResponseWriter *httpserver.ResponseWriterWrapper
ops []headerOperation ops []headerOperation
wroteHeader bool wroteHeader bool
} }
func (rww *responseWriterWrapper) Header() http.Header { func (rww *responseWriterWrapper) Header() http.Header {
return rww.ResponseWriter.Header() return rww.ResponseWriterWrapper.Header()
} }
func (rww *responseWriterWrapper) Write(d []byte) (int, error) { func (rww *responseWriterWrapper) Write(d []byte) (int, error) {
if !rww.wroteHeader { if !rww.wroteHeader {
rww.WriteHeader(http.StatusOK) rww.WriteHeader(http.StatusOK)
} }
return rww.ResponseWriter.Write(d) return rww.ResponseWriterWrapper.Write(d)
} }
func (rww *responseWriterWrapper) WriteHeader(status int) { func (rww *responseWriterWrapper) WriteHeader(status int) {
...@@ -91,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) { ...@@ -91,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) {
op(h) op(h)
} }
rww.ResponseWriter.WriteHeader(status) rww.ResponseWriterWrapper.WriteHeader(status)
} }
// delHeader deletes the existing header according to the key // delHeader deletes the existing header according to the key
...@@ -106,45 +106,5 @@ func (rww *responseWriterWrapper) delHeader(key string) { ...@@ -106,45 +106,5 @@ func (rww *responseWriterWrapper) delHeader(key string) {
}) })
} }
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (rww *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rww.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, httpserver.NonHijackerError{Underlying: rww.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or panics.
func (rww *responseWriterWrapper) Flush() {
if f, ok := rww.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(httpserver.NonFlusherError{Underlying: rww.ResponseWriter}) // should be recovered at the beginning of middleware stack
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (rww *responseWriterWrapper) CloseNotify() <-chan bool {
if cn, ok := rww.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(httpserver.NonCloseNotifierError{Underlying: rww.ResponseWriter})
}
func (rww *responseWriterWrapper) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := rww.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return httpserver.NonPusherError{Underlying: rww.ResponseWriter}
}
// Interface guards // Interface guards
var _ http.Pusher = (*responseWriterWrapper)(nil) var _ httpserver.HTTPInterfaces = (*responseWriterWrapper)(nil)
var _ http.Flusher = (*responseWriterWrapper)(nil)
var _ http.CloseNotifier = (*responseWriterWrapper)(nil)
var _ http.Hijacker = (*responseWriterWrapper)(nil)
package httpserver package httpserver
import ( import (
"bufio"
"net"
"net/http" "net/http"
"time" "time"
) )
...@@ -20,7 +18,7 @@ import ( ...@@ -20,7 +18,7 @@ import (
// //
// Beware when accessing the Replacer value; it may be nil! // Beware when accessing the Replacer value; it may be nil!
type ResponseRecorder struct { type ResponseRecorder struct {
http.ResponseWriter *ResponseWriterWrapper
Replacer Replacer Replacer Replacer
status int status int
size int size int
...@@ -35,9 +33,9 @@ type ResponseRecorder struct { ...@@ -35,9 +33,9 @@ type ResponseRecorder struct {
// of 200 to cover the default case. // of 200 to cover the default case.
func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{ return &ResponseRecorder{
ResponseWriter: w, ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
status: http.StatusOK, status: http.StatusOK,
start: time.Now(), start: time.Now(),
} }
} }
...@@ -45,13 +43,13 @@ func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { ...@@ -45,13 +43,13 @@ func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
// underlying ResponseWriter's WriteHeader method. // underlying ResponseWriter's WriteHeader method.
func (r *ResponseRecorder) WriteHeader(status int) { func (r *ResponseRecorder) WriteHeader(status int) {
r.status = status r.status = status
r.ResponseWriter.WriteHeader(status) r.ResponseWriterWrapper.WriteHeader(status)
} }
// Write is a wrapper that records the size of the body // Write is a wrapper that records the size of the body
// that gets written. // that gets written.
func (r *ResponseRecorder) Write(buf []byte) (int, error) { func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf) n, err := r.ResponseWriterWrapper.Write(buf)
if err == nil { if err == nil {
r.size += n r.size += n
} }
...@@ -68,45 +66,5 @@ func (r *ResponseRecorder) Status() int { ...@@ -68,45 +66,5 @@ func (r *ResponseRecorder) Status() int {
return r.status return r.status
} }
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := r.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, NonHijackerError{Underlying: r.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or does nothing.
func (r *ResponseRecorder) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(NonFlusherError{Underlying: r.ResponseWriter}) // should be recovered at the beginning of middleware stack
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
func (r *ResponseRecorder) CloseNotify() <-chan bool {
if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(NonCloseNotifierError{Underlying: r.ResponseWriter})
}
// Push resource to client
func (r *ResponseRecorder) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := r.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return NonPusherError{Underlying: r.ResponseWriter}
}
// Interface guards // Interface guards
var _ http.Pusher = (*ResponseRecorder)(nil) var _ HTTPInterfaces = (*ResponseRecorder)(nil)
var _ http.Flusher = (*ResponseRecorder)(nil)
var _ http.CloseNotifier = (*ResponseRecorder)(nil)
var _ http.Hijacker = (*ResponseRecorder)(nil)
package httpserver
import (
"bufio"
"net"
"net/http"
)
// ResponseWriterWrapper wrappers underlying ResponseWriter
// and inherits its Hijacker/Pusher/CloseNotifier/Flusher as well.
type ResponseWriterWrapper struct {
http.ResponseWriter
}
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (rww *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rww.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, NonHijackerError{Underlying: rww.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or panics.
func (rww *ResponseWriterWrapper) Flush() {
if f, ok := rww.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(NonFlusherError{Underlying: rww.ResponseWriter})
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (rww *ResponseWriterWrapper) CloseNotify() <-chan bool {
if cn, ok := rww.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(NonCloseNotifierError{Underlying: rww.ResponseWriter})
}
// Push implements http.Pusher.
// It just inherits the underlying ResponseWriter's Push method.
// It panics if the underlying ResponseWriter is not a Pusher.
func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := rww.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return NonPusherError{Underlying: rww.ResponseWriter}
}
// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface {
http.ResponseWriter
http.Pusher
http.Flusher
http.CloseNotifier
http.Hijacker
}
// Interface guards
var _ HTTPInterfaces = (*ResponseWriterWrapper)(nil)
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
package internalsrv package internalsrv
import ( import (
"bufio"
"net"
"net/http" "net/http"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -44,7 +42,7 @@ func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -44,7 +42,7 @@ func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// Use internal response writer to ignore responses that will be // Use internal response writer to ignore responses that will be
// redirected to internal locations // redirected to internal locations
iw := internalResponseWriter{ResponseWriter: w} iw := internalResponseWriter{ResponseWriterWrapper: &httpserver.ResponseWriterWrapper{ResponseWriter: w}}
status, err := i.Next.ServeHTTP(iw, r) status, err := i.Next.ServeHTTP(iw, r)
for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ { for c := 0; c < maxRedirectCount && isInternalRedirect(iw); c++ {
...@@ -69,7 +67,7 @@ func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -69,7 +67,7 @@ func (i Internal) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// calls to Write and WriteHeader if the response should be redirected to an // calls to Write and WriteHeader if the response should be redirected to an
// internal location. // internal location.
type internalResponseWriter struct { type internalResponseWriter struct {
http.ResponseWriter *httpserver.ResponseWriterWrapper
} }
// ClearHeader removes script headers that would interfere with follow up // ClearHeader removes script headers that would interfere with follow up
...@@ -84,7 +82,7 @@ func (w internalResponseWriter) ClearHeader() { ...@@ -84,7 +82,7 @@ func (w internalResponseWriter) ClearHeader() {
// internal location. // internal location.
func (w internalResponseWriter) WriteHeader(code int) { func (w internalResponseWriter) WriteHeader(code int) {
if !isInternalRedirect(w) { if !isInternalRedirect(w) {
w.ResponseWriter.WriteHeader(code) w.ResponseWriterWrapper.WriteHeader(code)
} }
} }
...@@ -94,53 +92,8 @@ func (w internalResponseWriter) Write(b []byte) (int, error) { ...@@ -94,53 +92,8 @@ func (w internalResponseWriter) Write(b []byte) (int, error) {
if isInternalRedirect(w) { if isInternalRedirect(w) {
return 0, nil return 0, nil
} }
return w.ResponseWriter.Write(b) return w.ResponseWriterWrapper.Write(b)
}
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (w internalResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, httpserver.NonHijackerError{Underlying: w.ResponseWriter}
}
// Flush implements http.Flusher. It simply wraps the underlying
// ResponseWriter's Flush method if there is one, or panics.
func (w internalResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
} else {
panic(httpserver.NonFlusherError{Underlying: w.ResponseWriter})
}
}
// CloseNotify implements http.CloseNotifier.
// It just inherits the underlying ResponseWriter's CloseNotify method.
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (w internalResponseWriter) CloseNotify() <-chan bool {
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
panic(httpserver.NonCloseNotifierError{Underlying: w.ResponseWriter})
}
// Push implements http.Pusher.
// It just inherits the underlying ResponseWriter's Push method.
// It panics if the underlying ResponseWriter is not a Pusher.
func (w internalResponseWriter) Push(target string, opts *http.PushOptions) error {
if pusher, hasPusher := w.ResponseWriter.(http.Pusher); hasPusher {
return pusher.Push(target, opts)
}
return httpserver.NonPusherError{Underlying: w.ResponseWriter}
} }
// Interface guards // Interface guards
var ( var _ httpserver.HTTPInterfaces = internalResponseWriter{}
_ http.Pusher = internalResponseWriter{}
_ http.Flusher = internalResponseWriter{}
_ http.CloseNotifier = internalResponseWriter{}
_ http.Hijacker = internalResponseWriter{}
)
...@@ -100,7 +100,9 @@ func TestReverseProxy(t *testing.T) { ...@@ -100,7 +100,9 @@ func TestReverseProxy(t *testing.T) {
// Make sure {upstream} placeholder is set // Make sure {upstream} placeholder is set
r.Body = ioutil.NopCloser(strings.NewReader("test")) r.Body = ioutil.NopCloser(strings.NewReader("test"))
rr := httpserver.NewResponseRecorder(testResponseRecorder{httptest.NewRecorder()}) rr := httpserver.NewResponseRecorder(testResponseRecorder{
ResponseWriterWrapper: &httpserver.ResponseWriterWrapper{ResponseWriter: httptest.NewRecorder()},
})
rr.Replacer = httpserver.NewReplacer(r, rr, "-") rr.Replacer = httpserver.NewReplacer(r, rr, "-")
p.ServeHTTP(rr, r) p.ServeHTTP(rr, r)
...@@ -1315,24 +1317,13 @@ func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write ...@@ -1315,24 +1317,13 @@ func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write
// testResponseRecorder wraps `httptest.ResponseRecorder`, // testResponseRecorder wraps `httptest.ResponseRecorder`,
// also implements `http.CloseNotifier`, `http.Hijacker` and `http.Pusher`. // also implements `http.CloseNotifier`, `http.Hijacker` and `http.Pusher`.
type testResponseRecorder struct { type testResponseRecorder struct {
*httptest.ResponseRecorder *httpserver.ResponseWriterWrapper
} }
func (testResponseRecorder) CloseNotify() <-chan bool { return nil } func (testResponseRecorder) CloseNotify() <-chan bool { return nil }
func (t testResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, httpserver.NonHijackerError{Underlying: t}
}
func (t testResponseRecorder) Push(target string, opts *http.PushOptions) error {
return httpserver.NonPusherError{Underlying: t}
}
// Interface guards // Interface guards
var ( var _ httpserver.HTTPInterfaces = testResponseRecorder{}
_ http.Pusher = testResponseRecorder{}
_ http.Flusher = testResponseRecorder{}
_ http.CloseNotifier = testResponseRecorder{}
_ http.Hijacker = testResponseRecorder{}
)
func BenchmarkProxy(b *testing.B) { func BenchmarkProxy(b *testing.B) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment