Commit aae7b695 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http: move RemoteAddr & UsingTLS from ResponseWriter to Request

ResponseWriter.RemoteAddr() string -> Request.RemoteAddr string
ResponseWriter.UsingTLS() bool -> Request.TLS *tls.ConnectionState

R=rsc, bradfitzwork
CC=gburd, golang-dev
https://golang.org/cl/4248075
parent ee23ab16
...@@ -152,7 +152,7 @@ func usage() { ...@@ -152,7 +152,7 @@ func usage() {
func loggingHandler(h http.Handler) http.Handler { func loggingHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
log.Printf("%s\t%s", w.RemoteAddr(), req.URL) log.Printf("%s\t%s", req.RemoteAddr, req.URL)
h.ServeHTTP(w, req) h.ServeHTTP(w, req)
}) })
} }
......
...@@ -74,11 +74,15 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ...@@ -74,11 +74,15 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
"PATH_INFO=" + pathInfo, "PATH_INFO=" + pathInfo,
"SCRIPT_NAME=" + root, "SCRIPT_NAME=" + root,
"SCRIPT_FILENAME=" + h.Path, "SCRIPT_FILENAME=" + h.Path,
"REMOTE_ADDR=" + rw.RemoteAddr(), "REMOTE_ADDR=" + req.RemoteAddr,
"REMOTE_HOST=" + rw.RemoteAddr(), "REMOTE_HOST=" + req.RemoteAddr,
"SERVER_PORT=" + port, "SERVER_PORT=" + port,
} }
if req.TLS != nil {
env = append(env, "HTTPS=on")
}
if len(req.Cookie) > 0 { if len(req.Cookie) > 0 {
b := new(bytes.Buffer) b := new(bytes.Buffer)
for idx, c := range req.Cookie { for idx, c := range req.Cookie {
......
...@@ -37,6 +37,7 @@ func newRequest(httpreq string) *http.Request { ...@@ -37,6 +37,7 @@ func newRequest(httpreq string) *http.Request {
if err != nil { if err != nil {
panic("cgi: bogus http request in test: " + httpreq) panic("cgi: bogus http request in test: " + httpreq)
} }
req.RemoteAddr = "1.2.3.4"
return req return req
} }
......
...@@ -18,8 +18,6 @@ type ResponseRecorder struct { ...@@ -18,8 +18,6 @@ type ResponseRecorder struct {
HeaderMap http.Header // the HTTP response headers HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool Flushed bool
FakeRemoteAddr string // the fake RemoteAddr to return, or "" for DefaultRemoteAddr
FakeUsingTLS bool // whether to return true from the UsingTLS method
} }
// NewRecorder returns an initialized ResponseRecorder. // NewRecorder returns an initialized ResponseRecorder.
...@@ -34,20 +32,6 @@ func NewRecorder() *ResponseRecorder { ...@@ -34,20 +32,6 @@ func NewRecorder() *ResponseRecorder {
// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. // an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
const DefaultRemoteAddr = "1.2.3.4" const DefaultRemoteAddr = "1.2.3.4"
// RemoteAddr returns the value of rw.FakeRemoteAddr, if set, else
// returns DefaultRemoteAddr.
func (rw *ResponseRecorder) RemoteAddr() string {
if rw.FakeRemoteAddr != "" {
return rw.FakeRemoteAddr
}
return DefaultRemoteAddr
}
// UsingTLS returns the fake value in rw.FakeUsingTLS
func (rw *ResponseRecorder) UsingTLS() bool {
return rw.FakeUsingTLS
}
// Header returns the response headers. // Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header { func (rw *ResponseRecorder) Header() http.Header {
return rw.HeaderMap return rw.HeaderMap
......
...@@ -11,6 +11,7 @@ package http ...@@ -11,6 +11,7 @@ package http
import ( import (
"bufio" "bufio"
"crypto/tls"
"container/vector" "container/vector"
"fmt" "fmt"
"io" "io"
...@@ -137,6 +138,22 @@ type Request struct { ...@@ -137,6 +138,22 @@ type Request struct {
// response has multiple trailer lines with the same key, they will be // response has multiple trailer lines with the same key, they will be
// concatenated, delimited by commas. // concatenated, delimited by commas.
Trailer Header Trailer Header
// RemoteAddr allows HTTP servers and other software to record
// the network address that sent the request, usually for
// logging. This field is not filled in by ReadRequest and
// has no defined format. The HTTP server in this package
// sets RemoteAddr to an "IP:port" address before invoking a
// handler.
RemoteAddr string
// TLS allows HTTP servers and other software to record
// information about the TLS connection on which the request
// was received. This field is not filled in by ReadRequest.
// The HTTP server in this package sets the field for
// TLS-enabled connections before invoking a handler;
// otherwise it leaves the field nil.
TLS *tls.ConnectionState
} }
// ProtoAtLeast returns whether the HTTP protocol used // ProtoAtLeast returns whether the HTTP protocol used
......
...@@ -229,6 +229,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { ...@@ -229,6 +229,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
} }
func TestServerTimeouts(t *testing.T) { func TestServerTimeouts(t *testing.T) {
// TODO(bradfitz): convert this to use httptest.Server
l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0}) l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0})
if err != nil { if err != nil {
t.Fatalf("listen error: %v", err) t.Fatalf("listen error: %v", err)
...@@ -406,3 +407,23 @@ func TestServeHTTP10Close(t *testing.T) { ...@@ -406,3 +407,23 @@ func TestServeHTTP10Close(t *testing.T) {
success <- true success <- true
} }
func TestSetsRemoteAddr(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
}))
defer ts.Close()
res, _, err := Get(ts.URL)
if err != nil {
t.Fatalf("Get error: %v", err)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
ip := string(body)
if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
t.Fatalf("Expected local addr; got %q", ip)
}
}
...@@ -48,12 +48,6 @@ type Handler interface { ...@@ -48,12 +48,6 @@ type Handler interface {
// A ResponseWriter interface is used by an HTTP handler to // A ResponseWriter interface is used by an HTTP handler to
// construct an HTTP response. // construct an HTTP response.
type ResponseWriter interface { type ResponseWriter interface {
// RemoteAddr returns the address of the client that sent the current request
RemoteAddr() string
// UsingTLS returns true if the client is connected using TLS
UsingTLS() bool
// Header returns the header map that will be sent by WriteHeader. // Header returns the header map that will be sent by WriteHeader.
// Changing the header after a call to WriteHeader (or Write) has // Changing the header after a call to WriteHeader (or Write) has
// no effect. // no effect.
...@@ -102,7 +96,7 @@ type conn struct { ...@@ -102,7 +96,7 @@ type conn struct {
rwc net.Conn // i/o connection rwc net.Conn // i/o connection
buf *bufio.ReadWriter // buffered rwc buf *bufio.ReadWriter // buffered rwc
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
usingTLS bool // a flag indicating connection over TLS tlsState *tls.ConnectionState // or nil when not using TLS
} }
// A response represents the server side of an HTTP response. // A response represents the server side of an HTTP response.
...@@ -130,10 +124,15 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { ...@@ -130,10 +124,15 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) {
c.remoteAddr = rwc.RemoteAddr().String() c.remoteAddr = rwc.RemoteAddr().String()
c.handler = handler c.handler = handler
c.rwc = rwc c.rwc = rwc
_, c.usingTLS = rwc.(*tls.Conn)
br := bufio.NewReader(rwc) br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc) bw := bufio.NewWriter(rwc)
c.buf = bufio.NewReadWriter(br, bw) c.buf = bufio.NewReadWriter(br, bw)
if tlsConn, ok := rwc.(*tls.Conn); ok {
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
}
return c, nil return c, nil
} }
...@@ -173,6 +172,9 @@ func (c *conn) readRequest() (w *response, err os.Error) { ...@@ -173,6 +172,9 @@ func (c *conn) readRequest() (w *response, err os.Error) {
return nil, err return nil, err
} }
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
w = new(response) w = new(response)
w.conn = c w.conn = c
w.req = req w.req = req
...@@ -187,12 +189,6 @@ func (c *conn) readRequest() (w *response, err os.Error) { ...@@ -187,12 +189,6 @@ func (c *conn) readRequest() (w *response, err os.Error) {
return w, nil return w, nil
} }
func (w *response) UsingTLS() bool {
return w.conn.usingTLS
}
func (w *response) RemoteAddr() string { return w.conn.remoteAddr }
func (w *response) Header() Header { func (w *response) Header() Header {
return w.header return w.header
} }
......
...@@ -514,7 +514,7 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -514,7 +514,7 @@ func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
conn, _, err := w.(http.Hijacker).Hijack() conn, _, err := w.(http.Hijacker).Hijack()
if err != nil { if err != nil {
log.Print("rpc hijacking ", w.RemoteAddr(), ": ", err.String()) log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.String())
return return
} }
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
......
...@@ -98,7 +98,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -98,7 +98,7 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
var location string var location string
if w.UsingTLS() { if req.TLS != nil {
location = "wss://" + req.Host + req.URL.RawPath location = "wss://" + req.Host + req.URL.RawPath
} else { } else {
location = "ws://" + req.Host + req.URL.RawPath location = "ws://" + req.Host + req.URL.RawPath
...@@ -192,7 +192,7 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -192,7 +192,7 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer rwc.Close() defer rwc.Close()
var location string var location string
if w.UsingTLS() { if req.TLS != nil {
location = "wss://" + req.Host + req.URL.RawPath location = "wss://" + req.Host + req.URL.RawPath
} else { } else {
location = "ws://" + req.Host + req.URL.RawPath location = "ws://" + req.Host + req.URL.RawPath
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment