Commit 7c4c87c0 authored by Meir Fischer's avatar Meir Fischer Committed by Brad Fitzpatrick

net/http/httptrace: expose request headers for http/1.1

Some headers, which are set or modified by the http library,
are not written to the standard http.Request.Header and are
not included as part of http.Response.Request.Header.

Exposing all headers alleviates this problem.

This is not a complete solution to 19761 since it does not have http/2 support.

Updates #19761

Change-Id: Ie8d4f702f4f671666b120b332378644f094e288b
Reviewed-on: https://go-review.googlesource.com/67430
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 1d303a00
...@@ -6,6 +6,7 @@ package http ...@@ -6,6 +6,7 @@ package http
import ( import (
"io" "io"
"net/http/httptrace"
"net/textproto" "net/textproto"
"sort" "sort"
"strings" "strings"
...@@ -56,7 +57,11 @@ func (h Header) Del(key string) { ...@@ -56,7 +57,11 @@ func (h Header) Del(key string) {
// Write writes a header in wire format. // Write writes a header in wire format.
func (h Header) Write(w io.Writer) error { func (h Header) Write(w io.Writer) error {
return h.WriteSubset(w, nil) return h.write(w, nil)
}
func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
return h.writeSubset(w, nil, trace)
} }
func (h Header) clone() Header { func (h Header) clone() Header {
...@@ -145,11 +150,16 @@ func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *h ...@@ -145,11 +150,16 @@ func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *h
// WriteSubset writes a header in wire format. // WriteSubset writes a header in wire format.
// If exclude is not nil, keys where exclude[key] == true are not written. // If exclude is not nil, keys where exclude[key] == true are not written.
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
return h.writeSubset(w, exclude, nil)
}
func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
ws, ok := w.(writeStringer) ws, ok := w.(writeStringer)
if !ok { if !ok {
ws = stringWriter{w} ws = stringWriter{w}
} }
kvs, sorter := h.sortedKeyValues(exclude) kvs, sorter := h.sortedKeyValues(exclude)
var formattedVals []string
for _, kv := range kvs { for _, kv := range kvs {
for _, v := range kv.values { for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v) v = headerNewlineToSpace.Replace(v)
...@@ -160,6 +170,13 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { ...@@ -160,6 +170,13 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
return err return err
} }
} }
if trace != nil && trace.WroteHeaderField != nil {
formattedVals = append(formattedVals, v)
}
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(kv.key, formattedVals)
formattedVals = nil
} }
} }
headerSorterPool.Put(sorter) headerSorterPool.Put(sorter)
......
...@@ -142,8 +142,12 @@ type ClientTrace struct { ...@@ -142,8 +142,12 @@ type ClientTrace struct {
// failure. // failure.
TLSHandshakeDone func(tls.ConnectionState, error) TLSHandshakeDone func(tls.ConnectionState, error)
// WroteHeaderField is called after the Transport has written
// each request header.
WroteHeaderField func(key string, value []string)
// WroteHeaders is called after the Transport has written // WroteHeaders is called after the Transport has written
// the request headers. // all request headers.
WroteHeaders func() WroteHeaders func()
// Wait100Continue is called if the Request specified // Wait100Continue is called if the Request specified
......
...@@ -555,6 +555,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF ...@@ -555,6 +555,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil { if err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Host", []string{host})
}
// Use the defaultUserAgent unless the Header contains one, which // Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header. // may be blank to not send the header.
...@@ -567,6 +570,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF ...@@ -567,6 +570,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil { if err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("User-Agent", []string{userAgent})
}
} }
// Process Body,ContentLength,Close,Trailer // Process Body,ContentLength,Close,Trailer
...@@ -574,18 +580,18 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF ...@@ -574,18 +580,18 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil { if err != nil {
return err return err
} }
err = tw.WriteHeader(w) err = tw.writeHeader(w, trace)
if err != nil { if err != nil {
return err return err
} }
err = r.Header.WriteSubset(w, reqWriteExcludeHeader) err = r.Header.writeSubset(w, reqWriteExcludeHeader, trace)
if err != nil { if err != nil {
return err return err
} }
if extraHeaders != nil { if extraHeaders != nil {
err = extraHeaders.Write(w) err = extraHeaders.write(w, trace)
if err != nil { if err != nil {
return err return err
} }
...@@ -624,7 +630,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF ...@@ -624,7 +630,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
} }
// Write body and trailer // Write body and trailer
err = tw.WriteBody(w) err = tw.writeBody(w)
if err != nil { if err != nil {
if tw.bodyReadError == err { if tw.bodyReadError == err {
err = requestBodyReadError{err} err = requestBodyReadError{err}
......
...@@ -293,7 +293,7 @@ func (r *Response) Write(w io.Writer) error { ...@@ -293,7 +293,7 @@ func (r *Response) Write(w io.Writer) error {
if err != nil { if err != nil {
return err return err
} }
err = tw.WriteHeader(w) err = tw.writeHeader(w, nil)
if err != nil { if err != nil {
return err return err
} }
...@@ -319,7 +319,7 @@ func (r *Response) Write(w io.Writer) error { ...@@ -319,7 +319,7 @@ func (r *Response) Write(w io.Writer) error {
} }
// Write body and trailer // Write body and trailer
err = tw.WriteBody(w) err = tw.writeBody(w)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -338,7 +338,7 @@ type chunkWriter struct { ...@@ -338,7 +338,7 @@ type chunkWriter struct {
res *response res *response
// header is either nil or a deep clone of res.handlerHeader // header is either nil or a deep clone of res.handlerHeader
// at the time of res.WriteHeader, if res.WriteHeader is // at the time of res.writeHeader, if res.writeHeader is
// called and extra buffering is being done to calculate // called and extra buffering is being done to calculate
// Content-Type and/or Content-Length. // Content-Type and/or Content-Length.
header Header header Header
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http/httptrace"
"net/http/internal" "net/http/internal"
"net/textproto" "net/textproto"
"reflect" "reflect"
...@@ -280,11 +281,14 @@ func (t *transferWriter) shouldSendContentLength() bool { ...@@ -280,11 +281,14 @@ func (t *transferWriter) shouldSendContentLength() bool {
return false return false
} }
func (t *transferWriter) WriteHeader(w io.Writer) error { func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error {
if t.Close && !hasToken(t.Header.get("Connection"), "close") { if t.Close && !hasToken(t.Header.get("Connection"), "close") {
if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Connection", []string{"close"})
}
} }
// Write Content-Length and/or Transfer-Encoding whose values are a // Write Content-Length and/or Transfer-Encoding whose values are a
...@@ -297,10 +301,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error { ...@@ -297,10 +301,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error {
if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil { if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)})
}
} else if chunked(t.TransferEncoding) { } else if chunked(t.TransferEncoding) {
if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil { if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"})
}
} }
// Write Trailer header // Write Trailer header
...@@ -321,13 +331,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error { ...@@ -321,13 +331,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error {
if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil { if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
return err return err
} }
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Trailer", keys)
}
} }
} }
return nil return nil
} }
func (t *transferWriter) WriteBody(w io.Writer) error { func (t *transferWriter) writeBody(w io.Writer) error {
var err error var err error
var ncopy int64 var ncopy int64
......
...@@ -3733,7 +3733,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3733,7 +3733,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
}) })
req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader("some body")) body := "some body"
req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
trace := &httptrace.ClientTrace{ trace := &httptrace.ClientTrace{
GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
...@@ -3748,6 +3750,12 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3748,6 +3750,12 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
} }
logf("ConnectDone: connected to %s %s = %v", network, addr, err) logf("ConnectDone: connected to %s %s = %v", network, addr, err)
}, },
WroteHeaderField: func(key string, value []string) {
logf("WroteHeaderField: %s: %v", key, value)
},
WroteHeaders: func() {
logf("WroteHeaders")
},
Wait100Continue: func() { logf("Wait100Continue") }, Wait100Continue: func() { logf("Wait100Continue") },
Got100Continue: func() { logf("Got100Continue") }, Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) { WroteRequest: func(e httptrace.WroteRequestInfo) {
...@@ -3817,7 +3825,15 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { ...@@ -3817,7 +3825,15 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
wantOnce("tls handshake done") wantOnce("tls handshake done")
} else { } else {
wantOnce("PutIdleConn = <nil>") wantOnce("PutIdleConn = <nil>")
} wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
// WroteHeaderField hook is not yet implemented in h2.)
wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
}
wantOnce("WroteHeaders")
wantOnce("Wait100Continue") wantOnce("Wait100Continue")
wantOnce("Got100Continue") wantOnce("Got100Continue")
wantOnce("WroteRequest: {Err:<nil>}") wantOnce("WroteRequest: {Err:<nil>}")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment