Commit 16250da3 authored by Matthew Holt's avatar Matthew Holt

errors: Fix low risk race condition at server close

See issue #1371 for more information.
parent 45a0e4cf
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"os" "os"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -30,8 +31,9 @@ type ErrorHandler struct { ...@@ -30,8 +31,9 @@ type ErrorHandler struct {
LogFile string LogFile string
Log *log.Logger Log *log.Logger
LogRoller *httpserver.LogRoller LogRoller *httpserver.LogRoller
Debug bool // if true, errors are written out to client rather than to a log Debug bool // if true, errors are written out to client rather than to a log
file *os.File // a log file to close when done file *os.File // a log file to close when done
fileMu *sync.RWMutex // like with log middleware, os.File can't "safely" be closed in a different goroutine
} }
func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
...@@ -48,7 +50,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er ...@@ -48,7 +50,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
fmt.Fprintln(w, errMsg) fmt.Fprintln(w, errMsg)
return 0, err // returning 0 signals that a response has been written return 0, err // returning 0 signals that a response has been written
} }
h.fileMu.RLock()
h.Log.Println(errMsg) h.Log.Println(errMsg)
h.fileMu.RUnlock()
} }
if status >= 400 { if status >= 400 {
...@@ -69,8 +73,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int ...@@ -69,8 +73,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int
errorPage, err := os.Open(pagePath) errorPage, err := os.Open(pagePath)
if err != nil { if err != nil {
// An additional error handling an error... <insert grumpy cat here> // An additional error handling an error... <insert grumpy cat here>
h.fileMu.RLock()
h.Log.Printf("%s [NOTICE %d %s] could not load error page: %v", h.Log.Printf("%s [NOTICE %d %s] could not load error page: %v",
time.Now().Format(timeFormat), code, r.URL.String(), err) time.Now().Format(timeFormat), code, r.URL.String(), err)
h.fileMu.RUnlock()
httpserver.DefaultErrorFunc(w, r, code) httpserver.DefaultErrorFunc(w, r, code)
return return
} }
...@@ -83,8 +89,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int ...@@ -83,8 +89,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int
if err != nil { if err != nil {
// Epic fail... sigh. // Epic fail... sigh.
h.fileMu.RLock()
h.Log.Printf("%s [NOTICE %d %s] could not respond with %s: %v", h.Log.Printf("%s [NOTICE %d %s] could not respond with %s: %v",
time.Now().Format(timeFormat), code, r.URL.String(), pagePath, err) time.Now().Format(timeFormat), code, r.URL.String(), pagePath, err)
h.fileMu.RUnlock()
httpserver.DefaultErrorFunc(w, r, code) httpserver.DefaultErrorFunc(w, r, code)
} }
...@@ -146,7 +154,9 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) { ...@@ -146,7 +154,9 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) {
httpserver.WriteTextResponse(w, http.StatusInternalServerError, fmt.Sprintf("%s\n\n%s", panicMsg, stack)) httpserver.WriteTextResponse(w, http.StatusInternalServerError, fmt.Sprintf("%s\n\n%s", panicMsg, stack))
} else { } else {
// Currently we don't use the function name, since file:line is more conventional // Currently we don't use the function name, since file:line is more conventional
h.fileMu.RLock()
h.Log.Printf(panicMsg) h.Log.Printf(panicMsg)
h.fileMu.RUnlock()
h.errorPage(w, r, http.StatusInternalServerError) h.errorPage(w, r, http.StatusInternalServerError)
} }
} }
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
...@@ -32,7 +33,8 @@ func TestErrors(t *testing.T) { ...@@ -32,7 +33,8 @@ func TestErrors(t *testing.T) {
http.StatusNotFound: path, http.StatusNotFound: path,
http.StatusForbidden: "not_exist_file", http.StatusForbidden: "not_exist_file",
}, },
Log: log.New(&buf, "", 0), Log: log.New(&buf, "", 0),
fileMu: new(sync.RWMutex),
} }
_, notExistErr := os.Open("not_exist_file") _, notExistErr := os.Open("not_exist_file")
...@@ -121,6 +123,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { ...@@ -121,6 +123,7 @@ func TestVisibleErrorWithPanic(t *testing.T) {
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
panic(panicMsg) panic(panicMsg)
}), }),
fileMu: new(sync.RWMutex),
} }
req, err := http.NewRequest("GET", "/", nil) req, err := http.NewRequest("GET", "/", nil)
...@@ -176,7 +179,8 @@ func TestGenericErrorPage(t *testing.T) { ...@@ -176,7 +179,8 @@ func TestGenericErrorPage(t *testing.T) {
ErrorPages: map[int]string{ ErrorPages: map[int]string{
http.StatusNotFound: notFoundErrorPagePath, http.StatusNotFound: notFoundErrorPagePath,
}, },
Log: log.New(&buf, "", 0), Log: log.New(&buf, "", 0),
fileMu: new(sync.RWMutex),
} }
tests := []struct { tests := []struct {
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync"
"github.com/hashicorp/go-syslog" "github.com/hashicorp/go-syslog"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -64,7 +65,9 @@ func setup(c *caddy.Controller) error { ...@@ -64,7 +65,9 @@ func setup(c *caddy.Controller) error {
// When server stops, close any open log file // When server stops, close any open log file
c.OnShutdown(func() error { c.OnShutdown(func() error {
if handler.file != nil { if handler.file != nil {
handler.fileMu.Lock()
handler.file.Close() handler.file.Close()
handler.fileMu.Unlock()
} }
return nil return nil
}) })
...@@ -81,7 +84,7 @@ func errorsParse(c *caddy.Controller) (*ErrorHandler, error) { ...@@ -81,7 +84,7 @@ func errorsParse(c *caddy.Controller) (*ErrorHandler, error) {
// Very important that we make a pointer because the startup // Very important that we make a pointer because the startup
// function that opens the log file must have access to the // function that opens the log file must have access to the
// same instance of the handler, not a copy. // same instance of the handler, not a copy.
handler := &ErrorHandler{ErrorPages: make(map[int]string)} handler := &ErrorHandler{ErrorPages: make(map[int]string), fileMu: new(sync.RWMutex)}
cfg := httpserver.GetConfig(c) cfg := httpserver.GetConfig(c)
......
...@@ -3,6 +3,7 @@ package errors ...@@ -3,6 +3,7 @@ package errors
import ( import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"sync"
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
...@@ -58,18 +59,22 @@ func TestErrorsParse(t *testing.T) { ...@@ -58,18 +59,22 @@ func TestErrorsParse(t *testing.T) {
}{ }{
{`errors`, false, ErrorHandler{ {`errors`, false, ErrorHandler{
ErrorPages: map[int]string{}, ErrorPages: map[int]string{},
fileMu: new(sync.RWMutex),
}}, }},
{`errors errors.txt`, false, ErrorHandler{ {`errors errors.txt`, false, ErrorHandler{
ErrorPages: map[int]string{}, ErrorPages: map[int]string{},
LogFile: "errors.txt", LogFile: "errors.txt",
fileMu: new(sync.RWMutex),
}}, }},
{`errors visible`, false, ErrorHandler{ {`errors visible`, false, ErrorHandler{
ErrorPages: map[int]string{}, ErrorPages: map[int]string{},
Debug: true, Debug: true,
fileMu: new(sync.RWMutex),
}}, }},
{`errors { log visible }`, false, ErrorHandler{ {`errors { log visible }`, false, ErrorHandler{
ErrorPages: map[int]string{}, ErrorPages: map[int]string{},
Debug: true, Debug: true,
fileMu: new(sync.RWMutex),
}}, }},
{`errors { log errors.txt {`errors { log errors.txt
404 404.html 404 404.html
...@@ -80,6 +85,7 @@ func TestErrorsParse(t *testing.T) { ...@@ -80,6 +85,7 @@ func TestErrorsParse(t *testing.T) {
404: "404.html", 404: "404.html",
500: "500.html", 500: "500.html",
}, },
fileMu: new(sync.RWMutex),
}}, }},
{`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{ {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{
LogFile: "errors.txt", LogFile: "errors.txt",
...@@ -90,6 +96,7 @@ func TestErrorsParse(t *testing.T) { ...@@ -90,6 +96,7 @@ func TestErrorsParse(t *testing.T) {
LocalTime: true, LocalTime: true,
}, },
ErrorPages: map[int]string{}, ErrorPages: map[int]string{},
fileMu: new(sync.RWMutex),
}}, }},
{`errors { log errors.txt { {`errors { log errors.txt {
size 3 size 3
...@@ -110,6 +117,7 @@ func TestErrorsParse(t *testing.T) { ...@@ -110,6 +117,7 @@ func TestErrorsParse(t *testing.T) {
MaxBackups: 5, MaxBackups: 5,
LocalTime: true, LocalTime: true,
}, },
fileMu: new(sync.RWMutex),
}}, }},
{`errors { log errors.txt {`errors { log errors.txt
* generic_error.html * generic_error.html
...@@ -122,6 +130,7 @@ func TestErrorsParse(t *testing.T) { ...@@ -122,6 +130,7 @@ func TestErrorsParse(t *testing.T) {
404: "404.html", 404: "404.html",
503: "503.html", 503: "503.html",
}, },
fileMu: new(sync.RWMutex),
}}, }},
// test absolute file path // test absolute file path
{`errors { {`errors {
...@@ -131,16 +140,17 @@ func TestErrorsParse(t *testing.T) { ...@@ -131,16 +140,17 @@ func TestErrorsParse(t *testing.T) {
ErrorPages: map[int]string{ ErrorPages: map[int]string{
404: testAbs, 404: testAbs,
}, },
fileMu: new(sync.RWMutex),
}}, }},
// Next two test cases is the detection of duplicate status codes // Next two test cases is the detection of duplicate status codes
{`errors { {`errors {
503 503.html 503 503.html
503 503.html 503 503.html
}`, true, ErrorHandler{ErrorPages: map[int]string{}}}, }`, true, ErrorHandler{ErrorPages: map[int]string{}, fileMu: new(sync.RWMutex)}},
{`errors { {`errors {
* generic_error.html * generic_error.html
* generic_error.html * generic_error.html
}`, true, ErrorHandler{ErrorPages: map[int]string{}}}, }`, true, ErrorHandler{ErrorPages: map[int]string{}, fileMu: new(sync.RWMutex)}},
} }
for i, test := range tests { for i, test := range tests {
actualErrorsRule, err := errorsParse(caddy.NewTestController("http", test.inputErrorsRules)) actualErrorsRule, err := errorsParse(caddy.NewTestController("http", test.inputErrorsRules))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment