Commit 54c65cb0 authored by Matthew Holt's avatar Matthew Holt

templates: Properly propagate response status code (fixes #1841)

Benchmarks with wrk showed no noticeable performance impact
parent 22b835b9
...@@ -195,6 +195,28 @@ func (rb *ResponseBuffer) ReadFrom(src io.Reader) (int64, error) { ...@@ -195,6 +195,28 @@ func (rb *ResponseBuffer) ReadFrom(src io.Reader) (int64, error) {
return rb.Buffer.ReadFrom(src) return rb.Buffer.ReadFrom(src)
} }
// StatusCodeWriter returns an http.ResponseWriter that always
// writes the status code stored in rb from when a response
// was buffered to it.
func (rb *ResponseBuffer) StatusCodeWriter(w http.ResponseWriter) http.ResponseWriter {
return forcedStatusCodeWriter{w, rb}
}
// forcedStatusCodeWriter is used to force a status code when
// writing the header. It uses the status code saved on rb.
// This is useful if passing a http.ResponseWriter into
// http.ServeContent because ServeContent hard-codes 2xx status
// codes. If we buffered the response, we force that status code
// instead.
type forcedStatusCodeWriter struct {
http.ResponseWriter
rb *ResponseBuffer
}
func (fscw forcedStatusCodeWriter) WriteHeader(int) {
fscw.ResponseWriter.WriteHeader(fscw.rb.status)
}
// respBufPool is used for io.CopyBuffer when ResponseBuffer // respBufPool is used for io.CopyBuffer when ResponseBuffer
// is configured to stream a response. // is configured to stream a response.
var respBufPool = &sync.Pool{ var respBufPool = &sync.Pool{
......
...@@ -40,7 +40,7 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error ...@@ -40,7 +40,7 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
if reqExt == "" { if reqExt == "" {
// request has no extension, so check response Content-Type // request has no extension, so check response Content-Type
ct := mime.TypeByExtension(ext) ct := mime.TypeByExtension(ext)
if strings.Contains(header.Get("Content-Type"), ct) { if ct != "" && strings.Contains(header.Get("Content-Type"), ct) {
return true return true
} }
} else if reqExt == ext { } else if reqExt == ext {
...@@ -96,13 +96,14 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error ...@@ -96,13 +96,14 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
// set the actual content length now that the template was executed // set the actual content length now that the template was executed
w.Header().Set("Content-Length", strconv.Itoa(buf.Len())) w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
// get the modification time in preparation to ServeContent // get the modification time in preparation for http.ServeContent
modTime, _ := time.Parse(http.TimeFormat, w.Header().Get("Last-Modified")) modTime, _ := time.Parse(http.TimeFormat, w.Header().Get("Last-Modified"))
// at last, write the rendered template to the response // at last, write the rendered template to the response; make sure to use
http.ServeContent(w, r, templateName, modTime, bytes.NewReader(buf.Bytes())) // use the proper status code, since ServeContent hard-codes 2xx codes...
http.ServeContent(rb.StatusCodeWriter(w), r, templateName, modTime, bytes.NewReader(buf.Bytes()))
return http.StatusOK, nil return 0, nil
} }
return t.Next.ServeHTTP(w, r) return t.Next.ServeHTTP(w, r)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment