Commit ebe91d11 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: treat HEAD requests like GET requests

A response to a HEAD request is supposed to look the same as a
response to a GET request, just without a body.

HEAD requests are incredibly rare in the wild.

The Go net/http package has so far treated HEAD requests
specially: a Write on our default ResponseWriter returned
ErrBodyNotAllowed, telling handlers that something was wrong.
This was to optimize the fast path for HEAD requests, but:

1) because HEAD requests are incredibly rare, they're not
   worth having a fast path for.

2) Letting the http.Handler handle but do nop Writes is still
   very fast.

3) this forces ugly error handling into the application.
   e.g. https://code.google.com/p/go/source/detail?r=6f596be7a31e
   and related.

4) The net/http package nowadays does Content-Type sniffing,
   but you don't get that for HEAD.

5) The net/http package nowadays does Content-Length counting
   for small (few KB) responses, but not for HEAD.

6) ErrBodyNotAllowed was useless. By the time you received it,
   you had probably already done all your heavy computation
   and I/O to calculate what to write.

So, this change makes HEAD requests like GET requests.

We now count content-length and sniff content-type for HEAD
requests. If you Write, it doesn't return an error.

If you want a fast-path in your code for HEAD, you have to do
it early and set all the response headers yourself. Just like
before. If you choose not to Write in HEAD requests, be sure
to set Content-Length if you know it. We won't write
"Content-Length: 0" because you might've just chosen to not
write (or you don't know your Content-Length in advance).

Fixes #5454

R=golang-dev, dsymonds
CC=golang-dev
https://golang.org/cl/12583043
parent a4ebad79
...@@ -632,22 +632,20 @@ func Test304Responses(t *testing.T) { ...@@ -632,22 +632,20 @@ func Test304Responses(t *testing.T) {
} }
} }
// TestHeadResponses verifies that responses to HEAD requests don't // TestHeadResponses verifies that all MIME type sniffing and Content-Length
// declare that they're chunking in their response headers, aren't // counting of GET requests also happens on HEAD requests.
// allowed to produce output, and don't set a Content-Type since
// the real type of the body data cannot be inferred.
func TestHeadResponses(t *testing.T) { func TestHeadResponses(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("Ignored body")) _, err := w.Write([]byte("<html>"))
if err != ErrBodyNotAllowed { if err != nil {
t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) t.Errorf("ResponseWriter.Write: %v", err)
} }
// Also exercise the ReaderFrom path // Also exercise the ReaderFrom path
_, err = io.Copy(w, strings.NewReader("Ignored body")) _, err = io.Copy(w, strings.NewReader("789a"))
if err != ErrBodyNotAllowed { if err != nil {
t.Errorf("on Copy, expected ErrBodyNotAllowed, got %v", err) t.Errorf("Copy(ResponseWriter, ...): %v", err)
} }
})) }))
defer ts.Close() defer ts.Close()
...@@ -658,9 +656,11 @@ func TestHeadResponses(t *testing.T) { ...@@ -658,9 +656,11 @@ func TestHeadResponses(t *testing.T) {
if len(res.TransferEncoding) > 0 { if len(res.TransferEncoding) > 0 {
t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
} }
ct := res.Header.Get("Content-Type") if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
if ct != "" { t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
t.Errorf("expected no Content-Type; got %s", ct) }
if v := res.ContentLength; v != 10 {
t.Errorf("Content-Length: %d; want 10", v)
} }
body, err := ioutil.ReadAll(res.Body) body, err := ioutil.ReadAll(res.Body)
if err != nil { if err != nil {
......
...@@ -246,6 +246,10 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { ...@@ -246,6 +246,10 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) {
if !cw.wroteHeader { if !cw.wroteHeader {
cw.writeHeader(p) cw.writeHeader(p)
} }
if cw.res.req.Method == "HEAD" {
// Eat writes.
return len(p), nil
}
if cw.chunking { if cw.chunking {
_, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
if err != nil { if err != nil {
...@@ -704,6 +708,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -704,6 +708,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
cw.wroteHeader = true cw.wroteHeader = true
w := cw.res w := cw.res
isHEAD := w.req.Method == "HEAD"
// header is written out to w.conn.buf below. Depending on the // header is written out to w.conn.buf below. Depending on the
// state of the handler, we either own the map or not. If we // state of the handler, we either own the map or not. If we
...@@ -735,7 +740,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -735,7 +740,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// response header and this is our first (and last) write, set // response header and this is our first (and last) write, set
// it, even to zero. This helps HTTP/1.0 clients keep their // it, even to zero. This helps HTTP/1.0 clients keep their
// "keep-alive" connections alive. // "keep-alive" connections alive.
if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { if w.handlerDone && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
w.contentLength = int64(len(p)) w.contentLength = int64(len(p))
setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
} }
...@@ -752,7 +757,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -752,7 +757,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// Check for a explicit (and valid) Content-Length header. // Check for a explicit (and valid) Content-Length header.
hasCL := w.contentLength != -1 hasCL := w.contentLength != -1
if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) {
_, connectionHeaderSet := header["Connection"] _, connectionHeaderSet := header["Connection"]
if !connectionHeaderSet { if !connectionHeaderSet {
setHeader.connection = "keep-alive" setHeader.connection = "keep-alive"
...@@ -793,7 +798,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -793,7 +798,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
} else { } else {
// If no content type, apply sniffing algorithm to body. // If no content type, apply sniffing algorithm to body.
_, haveType := header["Content-Type"] _, haveType := header["Content-Type"]
if !haveType && w.req.Method != "HEAD" { if !haveType {
setHeader.contentType = DetectContentType(p) setHeader.contentType = DetectContentType(p)
} }
} }
...@@ -905,7 +910,7 @@ func (w *response) bodyAllowed() bool { ...@@ -905,7 +910,7 @@ func (w *response) bodyAllowed() bool {
if !w.wroteHeader { if !w.wroteHeader {
panic("") panic("")
} }
return w.status != StatusNotModified && w.req.Method != "HEAD" return w.status != StatusNotModified
} }
// The Life Of A Write is like this: // The Life Of A Write is like this:
...@@ -983,7 +988,7 @@ func (w *response) finishRequest() { ...@@ -983,7 +988,7 @@ func (w *response) finishRequest() {
w.req.MultipartForm.RemoveAll() w.req.MultipartForm.RemoveAll()
} }
if w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written { if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
// Did not write enough. Avoid getting out of sync. // Did not write enough. Avoid getting out of sync.
w.closeAfterReply = true w.closeAfterReply = true
} }
......
...@@ -470,6 +470,7 @@ func TestTransportHeadResponses(t *testing.T) { ...@@ -470,6 +470,7 @@ func TestTransportHeadResponses(t *testing.T) {
res, err := c.Head(ts.URL) res, err := c.Head(ts.URL)
if err != nil { if err != nil {
t.Errorf("error on loop %d: %v", i, err) t.Errorf("error on loop %d: %v", i, err)
continue
} }
if e, g := "123", res.Header.Get("Content-Length"); e != g { if e, g := "123", res.Header.Get("Content-Length"); e != g {
t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
...@@ -477,6 +478,11 @@ func TestTransportHeadResponses(t *testing.T) { ...@@ -477,6 +478,11 @@ func TestTransportHeadResponses(t *testing.T) {
if e, g := int64(123), res.ContentLength; e != g { if e, g := int64(123), res.ContentLength; e != g {
t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
} }
if all, err := ioutil.ReadAll(res.Body); err != nil {
t.Errorf("loop %d: Body ReadAll: %v", i, err)
} else if len(all) != 0 {
t.Errorf("Bogus body %q", all)
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment