Commit 709b12ff authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix location of StateHijacked and StateActive

1) Move StateHijacked callback earlier, to make it
panic-proof.  A Hijack followed by a panic didn't previously
result in ConnState getting fired for StateHijacked.  Move it
earlier, to the time of hijack.

2) Don't fire StateActive unless any bytes were read off the
wire while waiting for a request. This means we don't
transition from New or Idle to Active if the client
disconnects or times out. This was documented before, but not
implemented properly.

This CL is required for an pending fix for Issue 7264

LGTM=josharian
R=josharian
CC=golang-codereviews
https://golang.org/cl/69860049
parent df8b6378
...@@ -2270,6 +2270,12 @@ func TestServerConnState(t *testing.T) { ...@@ -2270,6 +2270,12 @@ func TestServerConnState(t *testing.T) {
c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
c.Close() c.Close()
}, },
"/hijack-panic": func(w ResponseWriter, r *Request) {
c, _, _ := w.(Hijacker).Hijack()
c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
c.Close()
panic("intentional panic")
},
} }
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
handler[r.URL.Path](w, r) handler[r.URL.Path](w, r)
...@@ -2280,6 +2286,7 @@ func TestServerConnState(t *testing.T) { ...@@ -2280,6 +2286,7 @@ func TestServerConnState(t *testing.T) {
var stateLog = map[int][]ConnState{} var stateLog = map[int][]ConnState{}
var connID = map[net.Conn]int{} var connID = map[net.Conn]int{}
ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
ts.Config.ConnState = func(c net.Conn, state ConnState) { ts.Config.ConnState = func(c net.Conn, state ConnState) {
if c == nil { if c == nil {
t.Error("nil conn seen in state %s", state) t.Error("nil conn seen in state %s", state)
...@@ -2303,11 +2310,56 @@ func TestServerConnState(t *testing.T) { ...@@ -2303,11 +2310,56 @@ func TestServerConnState(t *testing.T) {
mustGet(t, ts.URL+"/", "Connection", "close") mustGet(t, ts.URL+"/", "Connection", "close")
mustGet(t, ts.URL+"/hijack") mustGet(t, ts.URL+"/hijack")
mustGet(t, ts.URL+"/hijack-panic")
// New->Closed
{
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
c.Close()
}
// New->Active->Closed
{
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
t.Fatal(err)
}
c.Close()
}
// New->Idle->Closed
{
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
t.Fatal(err)
}
res, err := ReadResponse(bufio.NewReader(c), nil)
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
t.Fatal(err)
}
c.Close()
}
want := map[int][]ConnState{ want := map[int][]ConnState{
1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, 1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, 2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed},
3: []ConnState{StateNew, StateActive, StateHijacked}, 3: []ConnState{StateNew, StateActive, StateHijacked},
4: []ConnState{StateNew, StateActive, StateHijacked},
5: []ConnState{StateNew, StateClosed},
6: []ConnState{StateNew, StateActive, StateClosed},
7: []ConnState{StateNew, StateActive, StateIdle, StateClosed},
} }
logString := func(m map[int][]ConnState) string { logString := func(m map[int][]ConnState) string {
var b bytes.Buffer var b bytes.Buffer
...@@ -2316,6 +2368,7 @@ func TestServerConnState(t *testing.T) { ...@@ -2316,6 +2368,7 @@ func TestServerConnState(t *testing.T) {
for _, s := range l { for _, s := range l {
fmt.Fprintf(&b, "%s ", s) fmt.Fprintf(&b, "%s ", s)
} }
b.WriteString("\n")
} }
return b.String() return b.String()
} }
......
...@@ -139,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { ...@@ -139,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
buf = c.buf buf = c.buf
c.rwc = nil c.rwc = nil
c.buf = nil c.buf = nil
c.setState(rwc, StateHijacked)
return return
} }
...@@ -497,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int { ...@@ -497,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int {
return DefaultMaxHeaderBytes return DefaultMaxHeaderBytes
} }
func (srv *Server) initialLimitedReaderSize() int64 {
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
// wrapper around io.ReaderCloser which on first read, sends an // wrapper around io.ReaderCloser which on first read, sends an
// HTTP/1.1 100 Continue header // HTTP/1.1 100 Continue header
type expectContinueReader struct { type expectContinueReader struct {
...@@ -567,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -567,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) {
}() }()
} }
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ c.lr.N = c.server.initialLimitedReaderSize()
var req *Request var req *Request
if req, err = ReadRequest(c.buf.Reader); err != nil { if req, err = ReadRequest(c.buf.Reader); err != nil {
if c.lr.N == 0 { if c.lr.N == 0 {
...@@ -1127,10 +1132,10 @@ func (c *conn) serve() { ...@@ -1127,10 +1132,10 @@ func (c *conn) serve() {
for { for {
w, err := c.readRequest() w, err := c.readRequest()
// TODO(bradfitz): could push this StateActive if c.lr.N != c.server.initialLimitedReaderSize() {
// earlier, but in practice header will be all in one // If we read any bytes off the wire, we're active.
// packet/Read: c.setState(c.rwc, StateActive)
c.setState(c.rwc, StateActive) }
if err != nil { if err != nil {
if err == errTooLarge { if err == errTooLarge {
// Their HTTP client may or may not be // Their HTTP client may or may not be
...@@ -1176,7 +1181,6 @@ func (c *conn) serve() { ...@@ -1176,7 +1181,6 @@ func (c *conn) serve() {
// in parallel even if their responses need to be serialized. // in parallel even if their responses need to be serialized.
serverHandler{c.server}.ServeHTTP(w, w.req) serverHandler{c.server}.ServeHTTP(w, w.req)
if c.hijacked() { if c.hijacked() {
c.setState(origConn, StateHijacked)
return return
} }
w.finishRequest() w.finishRequest()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment