Commit 1faa8869 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: set the Request context for incoming server requests

Updates #13021
Updates #15224

Change-Id: Ia3cd608bb887fcfd8d81b035fa57bd5eb8edf09b
Reviewed-on: https://go-review.googlesource.com/21810Reviewed-by: default avatarAndrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarEmmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent bd724976
...@@ -266,9 +266,13 @@ type Request struct { ...@@ -266,9 +266,13 @@ type Request struct {
// //
// The returned context is always non-nil; it defaults to the // The returned context is always non-nil; it defaults to the
// background context. // background context.
//
// For outgoing client requests, the context controls cancelation.
//
// For incoming server requests, the context is canceled when either
// the client's connection closes, or when the ServeHTTP method
// returns.
func (r *Request) Context() context.Context { func (r *Request) Context() context.Context {
// TODO(bradfitz): document above what Context means for server and client
// requests, once implemented.
if r.ctx != nil { if r.ctx != nil {
return r.ctx return r.ctx
} }
......
...@@ -9,6 +9,7 @@ package http_test ...@@ -9,6 +9,7 @@ package http_test
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
...@@ -3989,6 +3990,72 @@ func TestServerValidatesHeaders(t *testing.T) { ...@@ -3989,6 +3990,72 @@ func TestServerValidatesHeaders(t *testing.T) {
} }
} }
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
defer afterTest(t)
ctxc := make(chan context.Context, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context()
select {
case <-ctx.Done():
t.Error("should not be Done in ServeHTTP")
default:
}
ctxc <- ctx
}))
defer ts.Close()
res, err := Get(ts.URL)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
ctx := <-ctxc
select {
case <-ctx.Done():
default:
t.Error("context should be done after ServeHTTP completes")
}
}
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
// Currently the context is not canceled when the connection
// is closed because we're not reading from the connection
// until after ServeHTTP for the previous handler is done.
// Until the server code is modified to always be in a read
// (Issue 15224), this test doesn't work yet.
t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224")
defer afterTest(t)
inHandler := make(chan struct{})
handlerDone := make(chan struct{})
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
close(inHandler)
select {
case <-r.Context().Done():
case <-time.After(3 * time.Second):
t.Errorf("timeout waiting for context to be done")
}
close(handlerDone)
}))
defer ts.Close()
c, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer c.Close()
io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
select {
case <-inHandler:
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting to see ServeHTTP get called")
}
c.Close() // this should trigger the context being done
select {
case <-handlerDone:
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting to see ServeHTTP exit")
}
}
func BenchmarkClientServer(b *testing.B) { func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.StopTimer() b.StopTimer()
......
...@@ -9,6 +9,7 @@ package http ...@@ -9,6 +9,7 @@ package http
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
...@@ -312,6 +313,7 @@ type response struct { ...@@ -312,6 +313,7 @@ type response struct {
conn *conn conn *conn
req *Request // request for this response req *Request // request for this response
reqBody io.ReadCloser reqBody io.ReadCloser
cancelCtx context.CancelFunc // when ServeHTTP exits
wroteHeader bool // reply header has been (logically) written wroteHeader bool // reply header has been (logically) written
wroteContinue bool // 100 Continue response was written wroteContinue bool // 100 Continue response was written
wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive"
...@@ -686,7 +688,7 @@ func appendTime(b []byte, t time.Time) []byte { ...@@ -686,7 +688,7 @@ func appendTime(b []byte, t time.Time) []byte {
var errTooLarge = errors.New("http: request too large") var errTooLarge = errors.New("http: request too large")
// Read next request from connection. // Read next request from connection.
func (c *conn) readRequest() (w *response, err error) { func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
if c.hijacked() { if c.hijacked() {
return nil, ErrHijacked return nil, ErrHijacked
} }
...@@ -715,6 +717,10 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -715,6 +717,10 @@ func (c *conn) readRequest() (w *response, err error) {
} }
return nil, err return nil, err
} }
ctx, cancelCtx := context.WithCancel(ctx)
req.ctx = ctx
c.lastMethod = req.Method c.lastMethod = req.Method
c.r.setInfiniteReadLimit() c.r.setInfiniteReadLimit()
...@@ -749,6 +755,7 @@ func (c *conn) readRequest() (w *response, err error) { ...@@ -749,6 +755,7 @@ func (c *conn) readRequest() (w *response, err error) {
w = &response{ w = &response{
conn: c, conn: c,
cancelCtx: cancelCtx,
req: req, req: req,
reqBody: req.Body, reqBody: req.Body,
handlerHeader: make(Header), handlerHeader: make(Header),
...@@ -1432,12 +1439,20 @@ func (c *conn) serve() { ...@@ -1432,12 +1439,20 @@ func (c *conn) serve() {
} }
} }
// HTTP/1.x from here on.
c.r = &connReader{r: c.rwc} c.r = &connReader{r: c.rwc}
c.bufr = newBufioReader(c.r) c.bufr = newBufioReader(c.r)
c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
// TODO: allow changing base context? can't imagine concrete
// use cases yet.
baseCtx := context.Background()
ctx, cancelCtx := context.WithCancel(baseCtx)
defer cancelCtx()
for { for {
w, err := c.readRequest() w, err := c.readRequest(ctx)
if c.r.remain != c.server.initialReadLimitSize() { if c.r.remain != c.server.initialReadLimitSize() {
// If we read any bytes off the wire, we're active. // If we read any bytes off the wire, we're active.
c.setState(c.rwc, StateActive) c.setState(c.rwc, StateActive)
...@@ -1485,6 +1500,7 @@ func (c *conn) serve() { ...@@ -1485,6 +1500,7 @@ func (c *conn) serve() {
// [*] Not strictly true: HTTP pipelining. We could let them all process // [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized. // in parallel even if their responses need to be serialized.
serverHandler{c.server}.ServeHTTP(w, w.req) serverHandler{c.server}.ServeHTTP(w, w.req)
w.cancelCtx()
if c.hijacked() { if c.hijacked() {
return return
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment