Commit 4b0bc7c3 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: relax recently-updated rules and behavior of CloseNotifier

The CloseNotifier implementation and documentation was
substantially changed in https://golang.org/cl/17750 but it was a bit
too aggressive.

Issue #13666 highlighted that in addition to breaking external
projects, even the standard library (httputil.ReverseProxy) didn't
obey the new rules about not using CloseNotifier until the
Request.Body is fully consumed.

So, instead of fixing httputil.ReverseProxy, dial back the rules a
bit. It's now okay to call CloseNotify before consuming the request
body. The docs now say CloseNotifier may wait to fire before the
request body is fully consumed, but doesn't say that the behavior is
undefined anymore. Instead, we just wait until the request body is
consumed and start watching for EOF from the client then.

This CL also adds a test to ReverseProxy (using a POST request) that
would've caught this earlier.

Fixes #13666

Change-Id: Ib4e8c29c4bfbe7511f591cf9ffcda23a0f0b1269
Reviewed-on: https://go-review.googlesource.com/18144Reviewed-by: default avatarRuss Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
parent 66f1f89d
...@@ -8,6 +8,7 @@ package httputil ...@@ -8,6 +8,7 @@ package httputil
import ( import (
"bufio" "bufio"
"bytes"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -104,7 +105,6 @@ func TestReverseProxy(t *testing.T) { ...@@ -104,7 +105,6 @@ func TestReverseProxy(t *testing.T) {
if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
} }
} }
func TestXForwardedFor(t *testing.T) { func TestXForwardedFor(t *testing.T) {
...@@ -384,3 +384,43 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { ...@@ -384,3 +384,43 @@ func TestReverseProxyGetPutBuffer(t *testing.T) {
t.Errorf("Log events = %q; want %q", log, wantLog) t.Errorf("Log events = %q; want %q", log, wantLog)
} }
} }
func TestReverseProxy_Post(t *testing.T) {
const backendResponse = "I am the backend"
const backendStatus = 200
var requestBody = bytes.Repeat([]byte("a"), 1<<20)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slurp, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error("Backend body read = %v", err)
}
if len(slurp) != len(requestBody) {
t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
}
if !bytes.Equal(slurp, requestBody) {
t.Error("Backend read wrong request body.") // 1MB; omitting details
}
w.Write([]byte(backendResponse))
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxyHandler := NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
res, err := http.DefaultClient.Do(postReq)
if err != nil {
t.Fatalf("Do: %v", err)
}
if g, e := res.StatusCode, backendStatus; g != e {
t.Errorf("got res.StatusCode %d; expected %d", g, e)
}
bodyBytes, _ := ioutil.ReadAll(res.Body)
if g, e := string(bodyBytes), backendResponse; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
}
...@@ -692,7 +692,13 @@ func putTextprotoReader(r *textproto.Reader) { ...@@ -692,7 +692,13 @@ func putTextprotoReader(r *textproto.Reader) {
} }
// ReadRequest reads and parses an incoming request from b. // ReadRequest reads and parses an incoming request from b.
func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) } func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) }
// Constants for readRequest's deleteHostHeader parameter.
const (
deleteHostHeader = true
keepHostHeader = false
)
func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) { func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
tp := newTextprotoReader(b) tp := newTextprotoReader(b)
......
...@@ -2482,6 +2482,59 @@ func TestHijackAfterCloseNotifier(t *testing.T) { ...@@ -2482,6 +2482,59 @@ func TestHijackAfterCloseNotifier(t *testing.T) {
} }
} }
func TestHijackBeforeRequestBodyRead(t *testing.T) {
defer afterTest(t)
var requestBody = bytes.Repeat([]byte("a"), 1<<20)
bodyOkay := make(chan bool, 1)
gotCloseNotify := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(bodyOkay) // caller will read false if nothing else
reqBody := r.Body
r.Body = nil // to test that server.go doesn't use this value.
gone := w.(CloseNotifier).CloseNotify()
slurp, err := ioutil.ReadAll(reqBody)
if err != nil {
t.Error("Body read: %v", err)
return
}
if len(slurp) != len(requestBody) {
t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
return
}
if !bytes.Equal(slurp, requestBody) {
t.Error("Backend read wrong request body.") // 1MB; omitting details
return
}
bodyOkay <- true
select {
case <-gone:
gotCloseNotify <- true
case <-time.After(5 * time.Second):
gotCloseNotify <- false
}
}))
defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
len(requestBody), requestBody)
if !<-bodyOkay {
// already failed.
return
}
conn.Close()
if !<-gotCloseNotify {
t.Error("timeout waiting for CloseNotify")
}
}
func TestOptions(t *testing.T) { func TestOptions(t *testing.T) {
uric := make(chan string, 2) // only expect 1, but leave space for 2 uric := make(chan string, 2) // only expect 1, but leave space for 2
mux := NewServeMux() mux := NewServeMux()
......
...@@ -126,7 +126,7 @@ type CloseNotifier interface { ...@@ -126,7 +126,7 @@ type CloseNotifier interface {
// single value (true) when the client connection has gone // single value (true) when the client connection has gone
// away. // away.
// //
// CloseNotify is undefined before Request.Body has been // CloseNotify may wait to notify until Request.Body has been
// fully read. // fully read.
// //
// After the Handler has returned, there is no guarantee // After the Handler has returned, there is no guarantee
...@@ -135,7 +135,7 @@ type CloseNotifier interface { ...@@ -135,7 +135,7 @@ type CloseNotifier interface {
// If the protocol is HTTP/1.1 and CloseNotify is called while // If the protocol is HTTP/1.1 and CloseNotify is called while
// processing an idempotent request (such a GET) while // processing an idempotent request (such a GET) while
// HTTP/1.1 pipelining is in use, the arrival of a subsequent // HTTP/1.1 pipelining is in use, the arrival of a subsequent
// pipelined request will cause a value to be sent on the // pipelined request may cause a value to be sent on the
// returned channel. In practice HTTP/1.1 pipelining is not // returned channel. In practice HTTP/1.1 pipelining is not
// enabled in browsers and not seen often in the wild. If this // enabled in browsers and not seen often in the wild. If this
// is a problem, use HTTP/2 or only use CloseNotify on methods // is a problem, use HTTP/2 or only use CloseNotify on methods
...@@ -353,7 +353,9 @@ type response struct { ...@@ -353,7 +353,9 @@ type response struct {
dateBuf [len(TimeFormat)]byte dateBuf [len(TimeFormat)]byte
clenBuf [10]byte clenBuf [10]byte
closeNotifyCh <-chan bool // guarded by conn.mu // closeNotifyCh is non-nil once CloseNotify is called.
// Guarded by conn.mu
closeNotifyCh <-chan bool
} }
// declareTrailer is called for each Trailer header when the // declareTrailer is called for each Trailer header when the
...@@ -693,7 +695,7 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -693,7 +695,7 @@ func (c *conn) readRequest() (w *response, err error) {
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek)) c.bufr.Discard(numLeadingCRorLF(peek))
} }
req, err := readRequest(c.bufr, false) req, err := readRequest(c.bufr, keepHostHeader)
c.mu.Unlock() c.mu.Unlock()
if err != nil { if err != nil {
if c.r.hitReadLimit() { if c.r.hitReadLimit() {
...@@ -986,7 +988,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -986,7 +988,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
} }
if discard { if discard {
_, err := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1)
switch err { switch err {
case nil: case nil:
// There must be even more data left over. // There must be even more data left over.
...@@ -995,7 +997,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { ...@@ -995,7 +997,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// Body was already consumed and closed. // Body was already consumed and closed.
case io.EOF: case io.EOF:
// The remaining body was just consumed, close it. // The remaining body was just consumed, close it.
err = w.req.Body.Close() err = w.reqBody.Close()
if err != nil { if err != nil {
w.closeAfterReply = true w.closeAfterReply = true
} }
...@@ -1540,14 +1542,57 @@ func (w *response) CloseNotify() <-chan bool { ...@@ -1540,14 +1542,57 @@ func (w *response) CloseNotify() <-chan bool {
var once sync.Once var once sync.Once
notify := func() { once.Do(func() { ch <- true }) } notify := func() { once.Do(func() { ch <- true }) }
if requestBodyRemains(w.reqBody) {
// They're still consuming the request body, so we
// shouldn't notify yet.
registerOnHitEOF(w.reqBody, func() {
c.mu.Lock()
defer c.mu.Unlock()
startCloseNotifyBackgroundRead(c, notify)
})
} else {
startCloseNotifyBackgroundRead(c, notify)
}
return ch
}
// c.mu must be held.
func startCloseNotifyBackgroundRead(c *conn, notify func()) {
if c.bufr.Buffered() > 0 { if c.bufr.Buffered() > 0 {
// A pipelined request or unread request body data is available // They've consumed the request body, so anything
// unread. Per the CloseNotifier docs, fire immediately. // remaining is a pipelined request, which we
// document as firing on.
notify() notify()
} else { } else {
c.r.startBackgroundRead(notify) c.r.startBackgroundRead(notify)
} }
return ch }
func registerOnHitEOF(rc io.ReadCloser, fn func()) {
switch v := rc.(type) {
case *expectContinueReader:
registerOnHitEOF(v.readCloser, fn)
case *body:
v.registerOnHitEOF(fn)
default:
panic("unexpected type " + fmt.Sprintf("%T", rc))
}
}
// requestBodyRemains reports whether future calls to Read
// on rc might yield more data.
func requestBodyRemains(rc io.ReadCloser) bool {
if rc == eofReader {
return false
}
switch v := rc.(type) {
case *expectContinueReader:
return requestBodyRemains(v.readCloser)
case *body:
return v.bodyRemains()
default:
panic("unexpected type " + fmt.Sprintf("%T", rc))
}
} }
// The HandlerFunc type is an adapter to allow the use of // The HandlerFunc type is an adapter to allow the use of
......
...@@ -621,10 +621,11 @@ type body struct { ...@@ -621,10 +621,11 @@ type body struct {
closing bool // is the connection to be closed after reading body? closing bool // is the connection to be closed after reading body?
doEarlyClose bool // whether Close should stop early doEarlyClose bool // whether Close should stop early
mu sync.Mutex // guards closed, and calls to Read and Close mu sync.Mutex // guards following, and calls to Read and Close
sawEOF bool sawEOF bool
closed bool closed bool
earlyClose bool // Close called and we didn't read to the end of src earlyClose bool // Close called and we didn't read to the end of src
onHitEOF func() // if non-nil, func to call when EOF is Read
} }
// ErrBodyReadAfterClose is returned when reading a Request or Response // ErrBodyReadAfterClose is returned when reading a Request or Response
...@@ -684,6 +685,10 @@ func (b *body) readLocked(p []byte) (n int, err error) { ...@@ -684,6 +685,10 @@ func (b *body) readLocked(p []byte) (n int, err error) {
} }
} }
if b.sawEOF && b.onHitEOF != nil {
b.onHitEOF()
}
return n, err return n, err
} }
...@@ -818,6 +823,20 @@ func (b *body) didEarlyClose() bool { ...@@ -818,6 +823,20 @@ func (b *body) didEarlyClose() bool {
return b.earlyClose return b.earlyClose
} }
// bodyRemains reports whether future Read calls might
// yield data.
func (b *body) bodyRemains() bool {
b.mu.Lock()
defer b.mu.Unlock()
return !b.sawEOF
}
func (b *body) registerOnHitEOF(fn func()) {
b.mu.Lock()
defer b.mu.Unlock()
b.onHitEOF = fn
}
// bodyLocked is a io.Reader reading from a *body when its mutex is // bodyLocked is a io.Reader reading from a *body when its mutex is
// already held. // already held.
type bodyLocked struct { type bodyLocked struct {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment