Commit c37ad7f6 authored by Matthew Holt's avatar Matthew Holt

Only write error message/page if body not already written (fixes #567)

Based on work started in, and replaces, #614
parent 737c7c43
...@@ -43,7 +43,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er ...@@ -43,7 +43,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
} }
if status >= 400 { if status >= 400 {
h.errorPage(w, r, status) if w.Header().Get("Content-Length") == "" {
h.errorPage(w, r, status)
}
return 0, err return 0, err
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"testing" "testing"
...@@ -78,6 +79,13 @@ func TestErrors(t *testing.T) { ...@@ -78,6 +79,13 @@ func TestErrors(t *testing.T) {
expectedLog: "", expectedLog: "",
expectedErr: nil, expectedErr: nil,
}, },
{
next: genErrorHandler(http.StatusNotFound, nil, "normal"),
expectedCode: 0,
expectedBody: "normal",
expectedLog: "",
expectedErr: nil,
},
{ {
next: genErrorHandler(http.StatusForbidden, nil, ""), next: genErrorHandler(http.StatusForbidden, nil, ""),
expectedCode: 0, expectedCode: 0,
...@@ -158,6 +166,9 @@ func TestVisibleErrorWithPanic(t *testing.T) { ...@@ -158,6 +166,9 @@ func TestVisibleErrorWithPanic(t *testing.T) {
func genErrorHandler(status int, err error, body string) middleware.Handler { func genErrorHandler(status int, err error, body string) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if len(body) > 0 {
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
}
fmt.Fprint(w, body) fmt.Fprint(w, body)
return status, err return status, err
}) })
......
...@@ -107,7 +107,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -107,7 +107,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
var responseBody io.Reader = resp.Body var responseBody io.Reader = resp.Body
if r.Header.Get("Content-Length") == "" { if resp.Header.Get("Content-Length") == "" {
// If the upstream app didn't set a Content-Length (shame on them), // If the upstream app didn't set a Content-Length (shame on them),
// we need to do it to prevent error messages being appended to // we need to do it to prevent error messages being appended to
// an already-written response, and other problematic behavior. // an already-written response, and other problematic behavior.
...@@ -137,6 +137,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) ...@@ -137,6 +137,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n")) err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
} }
// Normally we should only return a status >= 400 if no response
// body is written yet, however, upstream apps don't know about
// this contract and we still want the correct code logged, so error
// handling code in our stack needs to check Content-Length before
// writing an error message... oh well.
return resp.StatusCode, err return resp.StatusCode, err
} }
} }
......
...@@ -26,7 +26,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -26,7 +26,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// The error must be handled here so the log entry will record the response size. // The error must be handled here so the log entry will record the response size.
if l.ErrorFunc != nil { if l.ErrorFunc != nil {
l.ErrorFunc(responseRecorder, r, status) l.ErrorFunc(responseRecorder, r, status)
} else { } else if responseRecorder.Header().Get("Content-Length") == "" { // ensure no body written since proxy backends may write an error page
// Default failover error handler // Default failover error handler
responseRecorder.WriteHeader(status) responseRecorder.WriteHeader(status)
fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status)) fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))
......
...@@ -319,7 +319,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -319,7 +319,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, _ := vh.stack.ServeHTTP(w, r) status, _ := vh.stack.ServeHTTP(w, r)
// Fallback error response in case error handling wasn't chained in // Fallback error response in case error handling wasn't chained in
if status >= 400 { if status >= 400 && w.Header().Get("Content-Length") == "" {
DefaultErrorFunc(w, r, status) DefaultErrorFunc(w, r, status)
} }
} else { } else {
...@@ -417,36 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) { ...@@ -417,36 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File() return ln.TCPListener.File()
} }
// copied from net/http/transport.go
/*
TODO - remove - not necessary?
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}*/
// ShutdownCallbacks executes all the shutdown callbacks // ShutdownCallbacks executes all the shutdown callbacks
// for all the virtualhosts in servers, and returns all the // for all the virtualhosts in servers, and returns all the
// errors generated during their execution. In other words, // errors generated during their execution. In other words,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment