Commit 3b988eb6 authored by Johan Brandhorst's avatar Johan Brandhorst Committed by Brad Fitzpatrick

net/http: use httptest.Server Client in tests

After merging https://go-review.googlesource.com/c/34639/,
it was pointed out to me that a lot of tests under net/http
could use the new functionality to simplify and unify testing.

Using the httptest.Server provided Client removes the need to
call CloseIdleConnections() on all Transports created, as it
is automatically called on the Transport associated with the
client when Server.Close() is called.

Change the transport used by the non-TLS
httptest.Server to a new *http.Transport rather than using
http.DefaultTransport implicitly. The TLS version already
used its own *http.Transport. This change is to prevent
concurrency problems with using DefaultTransport implicitly
across several httptest.Server's.

Add tests to ensure the httptest.Server.Client().Transport
RoundTripper interface is implemented by a *http.Transport,
as is now assumed across large parts of net/http tests.

Change-Id: I9f9d15f59d72893deead5678d314388718c91821
Reviewed-on: https://go-review.googlesource.com/37771
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 2bd6360e
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
...@@ -73,7 +72,7 @@ func TestClient(t *testing.T) { ...@@ -73,7 +72,7 @@ func TestClient(t *testing.T) {
ts := httptest.NewServer(robotsTxtHandler) ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close() defer ts.Close()
c := &Client{Transport: &Transport{DisableKeepAlives: true}} c := ts.Client()
r, err := c.Get(ts.URL) r, err := c.Get(ts.URL)
var b []byte var b []byte
if err == nil { if err == nil {
...@@ -220,10 +219,7 @@ func TestClientRedirects(t *testing.T) { ...@@ -220,10 +219,7 @@ func TestClientRedirects(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
_, err := c.Get(ts.URL) _, err := c.Get(ts.URL)
if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
t.Errorf("with default client Get, expected error %q, got %q", e, g) t.Errorf("with default client Get, expected error %q, got %q", e, g)
...@@ -252,13 +248,10 @@ func TestClientRedirects(t *testing.T) { ...@@ -252,13 +248,10 @@ func TestClientRedirects(t *testing.T) {
var checkErr error var checkErr error
var lastVia []*Request var lastVia []*Request
var lastReq *Request var lastReq *Request
c = &Client{ c.CheckRedirect = func(req *Request, via []*Request) error {
Transport: tr, lastReq = req
CheckRedirect: func(req *Request, via []*Request) error { lastVia = via
lastReq = req return checkErr
lastVia = via
return checkErr
},
} }
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -313,21 +306,16 @@ func TestClientRedirectContext(t *testing.T) { ...@@ -313,21 +306,16 @@ func TestClientRedirectContext(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c := &Client{ c := ts.Client()
Transport: tr, c.CheckRedirect = func(req *Request, via []*Request) error {
CheckRedirect: func(req *Request, via []*Request) error { cancel()
cancel() select {
select { case <-req.Context().Done():
case <-req.Context().Done(): return nil
return nil case <-time.After(5 * time.Second):
case <-time.After(5 * time.Second): return errors.New("redirected request's context never expired after root request canceled")
return errors.New("redirected request's context never expired after root request canceled") }
}
},
} }
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
...@@ -461,11 +449,12 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa ...@@ -461,11 +449,12 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, tt := range table { for _, tt := range table {
content := tt.redirectBody content := tt.redirectBody
req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil }
res, err := DefaultClient.Do(req) res, err := c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -519,17 +508,12 @@ func TestClientRedirectUseResponse(t *testing.T) { ...@@ -519,17 +508,12 @@ func TestClientRedirectUseResponse(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() c.CheckRedirect = func(req *Request, via []*Request) error {
if req.Response == nil {
c := &Client{ t.Error("expected non-nil Request.Response")
Transport: tr, }
CheckRedirect: func(req *Request, via []*Request) error { return ErrUseLastResponse
if req.Response == nil {
t.Error("expected non-nil Request.Response")
}
return ErrUseLastResponse
},
} }
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -558,7 +542,7 @@ func TestClientRedirect308NoLocation(t *testing.T) { ...@@ -558,7 +542,7 @@ func TestClientRedirect308NoLocation(t *testing.T) {
w.WriteHeader(308) w.WriteHeader(308)
})) }))
defer ts.Close() defer ts.Close()
c := &Client{Transport: &Transport{DisableKeepAlives: true}} c := ts.Client()
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -586,7 +570,7 @@ func TestClientRedirect308NoGetBody(t *testing.T) { ...@@ -586,7 +570,7 @@ func TestClientRedirect308NoGetBody(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: &Transport{DisableKeepAlives: true}} c := ts.Client()
req.GetBody = nil // so it can't rewind. req.GetBody = nil // so it can't rewind.
res, err := c.Do(req) res, err := c.Do(req)
if err != nil { if err != nil {
...@@ -678,12 +662,8 @@ func TestRedirectCookiesJar(t *testing.T) { ...@@ -678,12 +662,8 @@ func TestRedirectCookiesJar(t *testing.T) {
var ts *httptest.Server var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler) ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() c.Jar = new(TestJar)
c := &Client{
Transport: tr,
Jar: new(TestJar),
}
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
resp, err := c.Get(ts.URL) resp, err := c.Get(ts.URL)
...@@ -727,13 +707,10 @@ func TestJarCalls(t *testing.T) { ...@@ -727,13 +707,10 @@ func TestJarCalls(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
jar := new(RecordingJar) jar := new(RecordingJar)
c := &Client{ c := ts.Client()
Jar: jar, c.Jar = jar
Transport: &Transport{ c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) {
Dial: func(_ string, _ string) (net.Conn, error) { return net.Dial("tcp", ts.Listener.Addr().String())
return net.Dial("tcp", ts.Listener.Addr().String())
},
},
} }
_, err := c.Get("http://firsthost.fake/") _, err := c.Get("http://firsthost.fake/")
if err != nil { if err != nil {
...@@ -845,7 +822,8 @@ func TestClientWrites(t *testing.T) { ...@@ -845,7 +822,8 @@ func TestClientWrites(t *testing.T) {
} }
return c, err return c, err
} }
c := &Client{Transport: &Transport{Dial: dialer}} c := ts.Client()
c.Transport.(*Transport).Dial = dialer
_, err := c.Get(ts.URL) _, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -878,14 +856,11 @@ func TestClientInsecureTransport(t *testing.T) { ...@@ -878,14 +856,11 @@ func TestClientInsecureTransport(t *testing.T) {
// TODO(bradfitz): add tests for skipping hostname checks too? // TODO(bradfitz): add tests for skipping hostname checks too?
// would require a new cert for testing, and probably // would require a new cert for testing, and probably
// redundant with these tests. // redundant with these tests.
c := ts.Client()
for _, insecure := range []bool{true, false} { for _, insecure := range []bool{true, false} {
tr := &Transport{ c.Transport.(*Transport).TLSClientConfig = &tls.Config{
TLSClientConfig: &tls.Config{ InsecureSkipVerify: insecure,
InsecureSkipVerify: insecure,
},
} }
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if (err == nil) != insecure { if (err == nil) != insecure {
t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
...@@ -919,22 +894,6 @@ func TestClientErrorWithRequestURI(t *testing.T) { ...@@ -919,22 +894,6 @@ func TestClientErrorWithRequestURI(t *testing.T) {
} }
} }
func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
certs := x509.NewCertPool()
for _, c := range ts.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
return &Transport{
TLSClientConfig: &tls.Config{RootCAs: certs},
}
}
func TestClientWithCorrectTLSServerName(t *testing.T) { func TestClientWithCorrectTLSServerName(t *testing.T) {
defer afterTest(t) defer afterTest(t)
...@@ -946,9 +905,8 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { ...@@ -946,9 +905,8 @@ func TestClientWithCorrectTLSServerName(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
trans := newTLSTransport(t, ts) c := ts.Client()
trans.TLSClientConfig.ServerName = serverName c.Transport.(*Transport).TLSClientConfig.ServerName = serverName
c := &Client{Transport: trans}
if _, err := c.Get(ts.URL); err != nil { if _, err := c.Get(ts.URL); err != nil {
t.Fatalf("expected successful TLS connection, got error: %v", err) t.Fatalf("expected successful TLS connection, got error: %v", err)
} }
...@@ -961,9 +919,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { ...@@ -961,9 +919,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) {
errc := make(chanWriter, 10) // but only expecting 1 errc := make(chanWriter, 10) // but only expecting 1
ts.Config.ErrorLog = log.New(errc, "", 0) ts.Config.ErrorLog = log.New(errc, "", 0)
trans := newTLSTransport(t, ts) c := ts.Client()
trans.TLSClientConfig.ServerName = "badserver" c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver"
c := &Client{Transport: trans}
_, err := c.Get(ts.URL) _, err := c.Get(ts.URL)
if err == nil { if err == nil {
t.Fatalf("expected an error") t.Fatalf("expected an error")
...@@ -997,13 +954,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { ...@@ -997,13 +954,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := newTLSTransport(t, ts) c := ts.Client()
tr := c.Transport.(*Transport)
tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
tr.Dial = func(netw, addr string) (net.Conn, error) { tr.Dial = func(netw, addr string) (net.Conn, error) {
return net.Dial(netw, ts.Listener.Addr().String()) return net.Dial(netw, ts.Listener.Addr().String())
} }
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get("https://some-other-host.tld/") res, err := c.Get("https://some-other-host.tld/")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1018,13 +974,12 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { ...@@ -1018,13 +974,12 @@ func TestResponseSetsTLSConnectionState(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := newTLSTransport(t, ts) c := ts.Client()
tr := c.Transport.(*Transport)
tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
tr.Dial = func(netw, addr string) (net.Conn, error) { tr.Dial = func(netw, addr string) (net.Conn, error) {
return net.Dial(netw, ts.Listener.Addr().String()) return net.Dial(netw, ts.Listener.Addr().String())
} }
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get("https://example.com/") res, err := c.Get("https://example.com/")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1119,14 +1074,12 @@ func TestEmptyPasswordAuth(t *testing.T) { ...@@ -1119,14 +1074,12 @@ func TestEmptyPasswordAuth(t *testing.T) {
} }
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
req, err := NewRequest("GET", ts.URL, nil) req, err := NewRequest("GET", ts.URL, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req.URL.User = url.User(gopher) req.URL.User = url.User(gopher)
c := ts.Client()
resp, err := c.Do(req) resp, err := c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1503,21 +1456,17 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { ...@@ -1503,21 +1456,17 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) {
defer ts2.Close() defer ts2.Close()
ts2URL = ts2.URL ts2URL = ts2.URL
tr := &Transport{} c := ts1.Client()
defer tr.CloseIdleConnections() c.CheckRedirect = func(r *Request, via []*Request) error {
c := &Client{ want := Header{
Transport: tr, "User-Agent": []string{ua},
CheckRedirect: func(r *Request, via []*Request) error { "X-Foo": []string{xfoo},
want := Header{ "Referer": []string{ts2URL},
"User-Agent": []string{ua}, }
"X-Foo": []string{xfoo}, if !reflect.DeepEqual(r.Header, want) {
"Referer": []string{ts2URL}, t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
} }
if !reflect.DeepEqual(r.Header, want) { return nil
t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
}
return nil
},
} }
req, _ := NewRequest("GET", ts2.URL, nil) req, _ := NewRequest("GET", ts2.URL, nil)
...@@ -1606,13 +1555,9 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { ...@@ -1606,13 +1555,9 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{}
defer tr.CloseIdleConnections()
jar, _ := cookiejar.New(nil) jar, _ := cookiejar.New(nil)
c := &Client{ c := ts.Client()
Transport: tr, c.Jar = jar
Jar: jar,
}
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
...@@ -1730,9 +1675,7 @@ func TestClientRedirectTypes(t *testing.T) { ...@@ -1730,9 +1675,7 @@ func TestClientRedirectTypes(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
for i, tt := range tests { for i, tt := range tests {
handlerc <- func(w ResponseWriter, r *Request) { handlerc <- func(w ResponseWriter, r *Request) {
w.Header().Set("Location", ts.URL) w.Header().Set("Location", ts.URL)
...@@ -1745,7 +1688,6 @@ func TestClientRedirectTypes(t *testing.T) { ...@@ -1745,7 +1688,6 @@ func TestClientRedirectTypes(t *testing.T) {
continue continue
} }
c := &Client{Transport: tr}
c.CheckRedirect = func(req *Request, via []*Request) error { c.CheckRedirect = func(req *Request, via []*Request) error {
if got, want := req.Method, tt.wantMethod; got != want { if got, want := req.Method, tt.wantMethod; got != want {
return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
...@@ -1799,9 +1741,8 @@ func TestTransportBodyReadError(t *testing.T) { ...@@ -1799,9 +1741,8 @@ func TestTransportBodyReadError(t *testing.T) {
w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err))
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
// Do one initial successful request to create an idle TCP connection // Do one initial successful request to create an idle TCP connection
// for the subsequent request to reuse. (The Transport only retries // for the subsequent request to reuse. (The Transport only retries
......
...@@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) { ...@@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) {
ServeFile(w, r, "testdata/file") ServeFile(w, r, "testdata/file")
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
var err error var err error
...@@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) { ...@@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) {
req.Method = "GET" req.Method = "GET"
// straight GET // straight GET
_, body := getBody(t, "straight get", req) _, body := getBody(t, "straight get", req, c)
if !bytes.Equal(body, file) { if !bytes.Equal(body, file) {
t.Fatalf("body mismatch: got %q, want %q", body, file) t.Fatalf("body mismatch: got %q, want %q", body, file)
} }
...@@ -102,7 +103,7 @@ Cases: ...@@ -102,7 +103,7 @@ Cases:
if rt.r != "" { if rt.r != "" {
req.Header.Set("Range", rt.r) req.Header.Set("Range", rt.r)
} }
resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c)
if resp.StatusCode != rt.code { if resp.StatusCode != rt.code {
t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
} }
...@@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) { ...@@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) {
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
req.Header.Set("If-Modified-Since", lastMod) req.Header.Set("If-Modified-Since", lastMod)
res, err = DefaultClient.Do(req) c := ts.Client()
res, err = c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) { ...@@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) {
// Advance the index.html file's modtime, but not the directory's. // Advance the index.html file's modtime, but not the directory's.
indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
res, err = DefaultClient.Do(req) res, err = c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) { ...@@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) {
for k, v := range tt.reqHeader { for k, v := range tt.reqHeader {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
res, err := DefaultClient.Do(req)
c := ts.Client()
res, err := c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) { ...@@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) {
} }
ts := httptest.NewServer(FileServer(fs)) ts := httptest.NewServer(FileServer(fs))
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, code := range []int{403, 404, 500} { for _, code := range []int{403, 404, 500} {
res, err := DefaultClient.Get(fmt.Sprintf("%s/%d", ts.URL, code)) res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code))
if err != nil { if err != nil {
t.Errorf("Error fetching /%d: %v", code, err) t.Errorf("Error fetching /%d: %v", code, err)
continue continue
...@@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) { ...@@ -1125,8 +1130,8 @@ func TestLinuxSendfile(t *testing.T) {
} }
} }
func getBody(t *testing.T, testName string, req Request) (*Response, []byte) { func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) {
r, err := DefaultClient.Do(&req) r, err := client.Do(&req)
if err != nil { if err != nil {
t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
} }
......
...@@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server { ...@@ -93,7 +93,9 @@ func NewUnstartedServer(handler http.Handler) *Server {
return &Server{ return &Server{
Listener: newLocalListener(), Listener: newLocalListener(),
Config: &http.Server{Handler: handler}, Config: &http.Server{Handler: handler},
client: &http.Client{}, client: &http.Client{
Transport: &http.Transport{},
},
} }
} }
......
...@@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) { ...@@ -121,3 +121,27 @@ func TestServerClient(t *testing.T) {
t.Errorf("got %q, want hello", string(got)) t.Errorf("got %q, want hello", string(got))
} }
} }
// Tests that the Server.Client.Transport interface is implemented
// by a *http.Transport.
func TestServerClientTransportType(t *testing.T) {
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
if _, ok := client.Transport.(*http.Transport); !ok {
t.Errorf("got %T, want *http.Transport", client.Transport)
}
}
// Tests that the TLS Server.Client.Transport interface is implemented
// by a *http.Transport.
func TestTLSServerClientTransportType(t *testing.T) {
ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
client := ts.Client()
if _, ok := client.Transport.(*http.Transport); !ok {
t.Errorf("got %T, want *http.Transport", client.Transport)
}
}
...@@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -79,6 +79,7 @@ func TestReverseProxy(t *testing.T) {
proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
getReq.Host = "some-name" getReq.Host = "some-name"
...@@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -86,7 +87,7 @@ func TestReverseProxy(t *testing.T) {
getReq.Header.Set("Proxy-Connection", "should be deleted") getReq.Header.Set("Proxy-Connection", "should be deleted")
getReq.Header.Set("Upgrade", "foo") getReq.Header.Set("Upgrade", "foo")
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) { ...@@ -126,7 +127,7 @@ func TestReverseProxy(t *testing.T) {
// a response results in a StatusBadGateway. // a response results in a StatusBadGateway.
getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
getReq.Close = true getReq.Close = true
res, err = http.DefaultClient.Do(getReq) res, err = frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { ...@@ -172,7 +173,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
getReq.Header.Set("Upgrade", "original value") getReq.Header.Set("Upgrade", "original value")
getReq.Header.Set(fakeConnectionToken, "should be deleted") getReq.Header.Set(fakeConnectionToken, "should be deleted")
res, err := http.DefaultClient.Do(getReq) res, err := frontend.Client().Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) { ...@@ -220,7 +221,7 @@ func TestXForwardedFor(t *testing.T) {
getReq.Header.Set("Connection", "close") getReq.Header.Set("Connection", "close")
getReq.Header.Set("X-Forwarded-For", prevForwardedFor) getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontend.Client().Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) { ...@@ -259,7 +260,7 @@ func TestReverseProxyQuery(t *testing.T) {
frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("%d. Get: %v", i, err) t.Fatalf("%d. Get: %v", i, err)
} }
...@@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { ...@@ -295,7 +296,7 @@ func TestReverseProxyFlushInterval(t *testing.T) {
req, _ := http.NewRequest("GET", frontend.URL, nil) req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) { ...@@ -349,13 +350,14 @@ func TestReverseProxyCancelation(t *testing.T) {
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
go func() { go func() {
<-reqInFlight <-reqInFlight
http.DefaultTransport.(*http.Transport).CancelRequest(getReq) frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
}() }()
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if res != nil { if res != nil {
t.Errorf("got response %v; want nil", res.Status) t.Errorf("got response %v; want nil", res.Status)
} }
...@@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) { ...@@ -363,7 +365,7 @@ func TestReverseProxyCancelation(t *testing.T) {
// This should be an error like: // This should be an error like:
// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
// use of closed network connection // use of closed network connection
t.Error("DefaultClient.Do() returned nil error; want non-nil error") t.Error("Server.Client().Do() returned nil error; want non-nil error")
} }
} }
...@@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) { ...@@ -428,11 +430,12 @@ func TestUserAgentHeader(t *testing.T) {
proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
frontendClient := frontend.Client()
getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq, _ := http.NewRequest("GET", frontend.URL, nil)
getReq.Header.Set("User-Agent", explicitUA) getReq.Header.Set("User-Agent", explicitUA)
getReq.Close = true getReq.Close = true
res, err := http.DefaultClient.Do(getReq) res, err := frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) { ...@@ -441,7 +444,7 @@ func TestUserAgentHeader(t *testing.T) {
getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
getReq.Header.Set("User-Agent", "") getReq.Header.Set("User-Agent", "")
getReq.Close = true getReq.Close = true
res, err = http.DefaultClient.Do(getReq) res, err = frontendClient.Do(getReq)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { ...@@ -493,7 +496,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) {
req, _ := http.NewRequest("GET", frontend.URL, nil) req, _ := http.NewRequest("GET", frontend.URL, nil)
req.Close = true req.Close = true
res, err := http.DefaultClient.Do(req) res, err := frontend.Client().Do(req)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) { ...@@ -540,7 +543,7 @@ func TestReverseProxy_Post(t *testing.T) {
defer frontend.Close() defer frontend.Close()
postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
res, err := http.DefaultClient.Do(postReq) res, err := frontend.Client().Do(postReq)
if err != nil { if err != nil {
t.Fatalf("Do: %v", err) t.Fatalf("Do: %v", err)
} }
...@@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) { ...@@ -573,7 +576,7 @@ func TestReverseProxy_NilBody(t *testing.T) {
frontend := httptest.NewServer(proxyHandler) frontend := httptest.NewServer(proxyHandler)
defer frontend.Close() defer frontend.Close()
res, err := http.DefaultClient.Get(frontend.URL) res, err := frontend.Client().Get(frontend.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error ...@@ -151,7 +151,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error
} }
return err return err
} }
func closeClient(c *http.Client) {
c.Transport.(*http.Transport).CloseIdleConnections()
}
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) {
// Normal request, without NPN. // Normal request, without NPN.
{ {
tr := newTLSTransport(t, ts) c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) {
// Request to an advertised but unhandled NPN protocol. // Request to an advertised but unhandled NPN protocol.
// Server will hang up. // Server will hang up.
{ {
tr := newTLSTransport(t, ts) certPool := x509.NewCertPool()
tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} certPool.AddCert(ts.Certificate())
tr := &Transport{
TLSClientConfig: &tls.Config{
RootCAs: certPool,
NextProtos: []string{"unhandled-proto"},
},
}
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
c := &Client{Transport: tr} c := &Client{
Transport: tr,
}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err == nil { if err == nil {
defer res.Body.Close() defer res.Body.Close()
...@@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) { ...@@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) {
// Request using the "tls-0.9" protocol, which we register here. // Request using the "tls-0.9" protocol, which we register here.
// It is HTTP/0.9 over TLS. // It is HTTP/0.9 over TLS.
{ {
tlsConfig := newTLSTransport(t, ts).TLSClientConfig c := ts.Client()
tlsConfig := c.Transport.(*Transport).TLSClientConfig
tlsConfig.NextProtos = []string{"tls-0.9"} tlsConfig.NextProtos = []string{"tls-0.9"}
conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
if err != nil { if err != nil {
......
...@@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) { ...@@ -474,9 +474,7 @@ func TestServerTimeouts(t *testing.T) {
defer ts.Close() defer ts.Close()
// Hit the HTTP server successfully. // Hit the HTTP server successfully.
tr := &Transport{DisableKeepAlives: true} // they interfere with this test c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
r, err := c.Get(ts.URL) r, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatalf("http Get #1: %v", err) t.Fatalf("http Get #1: %v", err)
...@@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { ...@@ -548,12 +546,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
ts.StartTLS() ts.StartTLS()
defer ts.Close() defer ts.Close()
tr := newTLSTransport(t, ts) c := ts.Client()
defer tr.CloseIdleConnections() if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil {
if err := ExportHttp2ConfigureTransport(tr); err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: tr}
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
req, err := NewRequest("GET", ts.URL, nil) req, err := NewRequest("GET", ts.URL, nil)
...@@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) { ...@@ -608,9 +604,7 @@ func TestOnlyWriteTimeout(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{DisableKeepAlives: false} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
errc := make(chan error) errc := make(chan error)
go func() { go func() {
...@@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) { ...@@ -671,8 +665,7 @@ func TestIdentityResponse(t *testing.T) {
ts := httptest.NewServer(handler) ts := httptest.NewServer(handler)
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
// Note: this relies on the assumption (which is true) that // Note: this relies on the assumption (which is true) that
// Get sends HTTP/1.1 or greater requests. Otherwise the // Get sends HTTP/1.1 or greater requests. Otherwise the
...@@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { ...@@ -949,9 +942,8 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{DisableKeepAlives: true} c := ts.Client()
defer tr.CloseIdleConnections() c.Timeout = time.Second
c := &Client{Transport: tr, Timeout: time.Second}
fetch := func(num int, response chan<- string) { fetch := func(num int, response chan<- string) {
resp, err := c.Get(ts.URL) resp, err := c.Get(ts.URL)
...@@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) { ...@@ -1022,9 +1014,7 @@ func TestIdentityResponseHeaders(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatalf("Get error: %v", err) t.Fatalf("Get error: %v", err)
...@@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) { ...@@ -1145,12 +1135,7 @@ func TestTLSServer(t *testing.T) {
t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
return return
} }
noVerifyTransport := &Transport{ client := ts.Client()
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := &Client{Transport: noVerifyTransport}
res, err := client.Get(ts.URL) res, err := client.Get(ts.URL)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) { ...@@ -1967,8 +1952,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
var wg sync.WaitGroup var wg sync.WaitGroup
gate := make(chan bool, 10) gate := make(chan bool, 10)
...@@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { ...@@ -2011,8 +1995,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
if testing.Short() { if testing.Short() {
n = 10 n = 10
} }
c := &Client{Transport: new(Transport)}
defer closeClient(c) c := ts.Client()
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
gate <- true gate <- true
wg.Add(1) wg.Add(1)
...@@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { ...@@ -2099,8 +2083,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
// Issue was caused by the timeout handler starting the timer when // Issue was caused by the timeout handler starting the timer when
// was created, not when the request. So wait for more than the timeout // was created, not when the request. So wait for more than the timeout
...@@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ...@@ -2127,8 +2110,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { ...@@ -2364,9 +2346,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) { ...@@ -2411,8 +2391,7 @@ func TestStripPrefix(t *testing.T) {
ts := httptest.NewServer(StripPrefix("/foo", h)) ts := httptest.NewServer(StripPrefix("/foo", h))
defer ts.Close() defer ts.Close()
c := &Client{Transport: new(Transport)} c := ts.Client()
defer closeClient(c)
res, err := c.Get(ts.URL + "/foo/bar") res, err := c.Get(ts.URL + "/foo/bar")
if err != nil { if err != nil {
...@@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) { ...@@ -3654,9 +3633,7 @@ func TestServerConnState(t *testing.T) {
} }
ts.Start() ts.Start()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
mustGet := func(url string, headers ...string) { mustGet := func(url string, headers ...string) {
req, err := NewRequest("GET", url, nil) req, err := NewRequest("GET", url, nil)
...@@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { ...@@ -4491,15 +4468,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) {
b.ResetTimer() b.ResetTimer()
b.SetParallelism(parallelism) b.SetParallelism(parallelism)
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
noVerifyTransport := &Transport{ c := ts.Client()
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
defer noVerifyTransport.CloseIdleConnections()
client := &Client{Transport: noVerifyTransport}
for pb.Next() { for pb.Next() {
res, err := client.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
b.Logf("Get: %v", err) b.Logf("Get: %v", err)
continue continue
...@@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) { ...@@ -4934,10 +4905,7 @@ func TestServerIdleTimeout(t *testing.T) {
ts.Config.IdleTimeout = 2 * time.Second ts.Config.IdleTimeout = 2 * time.Second
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
get := func() string { get := func() string {
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
...@@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { ...@@ -4998,9 +4966,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
get := func() string { return get(t, c, ts.URL) } get := func() string { return get(t, c, ts.URL) }
...@@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { ...@@ -5119,9 +5086,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
ts.Start() ts.Start()
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
......
...@@ -131,11 +131,9 @@ func TestTransportKeepAlives(t *testing.T) { ...@@ -131,11 +131,9 @@ func TestTransportKeepAlives(t *testing.T) {
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, disableKeepAlive := range []bool{false, true} { for _, disableKeepAlive := range []bool{false, true} {
tr := &Transport{DisableKeepAlives: disableKeepAlive} c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
fetch := func(n int) string { fetch := func(n int) string {
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -166,12 +164,11 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { ...@@ -166,12 +164,11 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) {
connSet, testDial := makeTestDial(t) connSet, testDial := makeTestDial(t)
for _, connectionClose := range []bool{false, true} { c := ts.Client()
tr := &Transport{ tr := c.Transport.(*Transport)
Dial: testDial, tr.Dial = testDial
}
c := &Client{Transport: tr}
for _, connectionClose := range []bool{false, true} {
fetch := func(n int) string { fetch := func(n int) string {
req := new(Request) req := new(Request)
var err error var err error
...@@ -217,12 +214,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { ...@@ -217,12 +214,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
connSet, testDial := makeTestDial(t) connSet, testDial := makeTestDial(t)
c := ts.Client()
tr := c.Transport.(*Transport)
tr.Dial = testDial
for _, connectionClose := range []bool{false, true} { for _, connectionClose := range []bool{false, true} {
tr := &Transport{
Dial: testDial,
}
c := &Client{Transport: tr}
fetch := func(n int) string { fetch := func(n int) string {
req := new(Request) req := new(Request)
var err error var err error
...@@ -273,10 +268,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { ...@@ -273,10 +268,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
tr := &Transport{ c := ts.Client()
DisableKeepAlives: true, c.Transport.(*Transport).DisableKeepAlives = true
}
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -291,9 +285,8 @@ func TestTransportIdleCacheKeys(t *testing.T) { ...@@ -291,9 +285,8 @@ func TestTransportIdleCacheKeys(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{DisableKeepAlives: false} tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
...@@ -385,9 +378,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { ...@@ -385,9 +378,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
} }
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := c.Transport.(*Transport)
maxIdleConnsPerHost := 2 maxIdleConnsPerHost := 2
tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConnsPerHost} tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
c := &Client{Transport: tr}
// Start 3 outstanding requests and wait for the server to get them. // Start 3 outstanding requests and wait for the server to get them.
// Their responses will hang until we write to resch, though. // Their responses will hang until we write to resch, though.
...@@ -450,9 +445,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { ...@@ -450,9 +445,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
doReq := func(name string) string { doReq := func(name string) string {
// Do a POST instead of a GET to prevent the Transport's // Do a POST instead of a GET to prevent the Transport's
...@@ -496,9 +490,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { ...@@ -496,9 +490,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
defer afterTest(t) defer afterTest(t)
ts := httptest.NewServer(hostPortHandler) ts := httptest.NewServer(hostPortHandler)
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{}
c := &Client{Transport: tr}
fetch := func(n, retries int) string { fetch := func(n, retries int) string {
condFatalf := func(format string, arg ...interface{}) { condFatalf := func(format string, arg ...interface{}) {
...@@ -564,10 +556,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { ...@@ -564,10 +556,7 @@ func TestStressSurpriseServerCloses(t *testing.T) {
conn.Close() conn.Close()
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{DisableKeepAlives: false}
c := &Client{Transport: tr}
defer tr.CloseIdleConnections()
// Do a bunch of traffic from different goroutines. Send to activityc // Do a bunch of traffic from different goroutines. Send to activityc
// after each request completes, regardless of whether it failed. // after each request completes, regardless of whether it failed.
...@@ -620,9 +609,8 @@ func TestTransportHeadResponses(t *testing.T) { ...@@ -620,9 +609,8 @@ func TestTransportHeadResponses(t *testing.T) {
w.WriteHeader(200) w.WriteHeader(200)
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{DisableKeepAlives: false}
c := &Client{Transport: tr}
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
res, err := c.Head(ts.URL) res, err := c.Head(ts.URL)
if err != nil { if err != nil {
...@@ -656,10 +644,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) { ...@@ -656,10 +644,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) {
w.WriteHeader(200) w.WriteHeader(200)
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{DisableKeepAlives: false}
c := &Client{Transport: tr}
defer tr.CloseIdleConnections()
// Ensure that we wait for the readLoop to complete before // Ensure that we wait for the readLoop to complete before
// calling Head again // calling Head again
...@@ -720,6 +705,7 @@ func TestRoundTripGzip(t *testing.T) { ...@@ -720,6 +705,7 @@ func TestRoundTripGzip(t *testing.T) {
} }
})) }))
defer ts.Close() defer ts.Close()
tr := ts.Client().Transport.(*Transport)
for i, test := range roundTripTests { for i, test := range roundTripTests {
// Test basic request (no accept-encoding) // Test basic request (no accept-encoding)
...@@ -727,7 +713,7 @@ func TestRoundTripGzip(t *testing.T) { ...@@ -727,7 +713,7 @@ func TestRoundTripGzip(t *testing.T) {
if test.accept != "" { if test.accept != "" {
req.Header.Set("Accept-Encoding", test.accept) req.Header.Set("Accept-Encoding", test.accept)
} }
res, err := DefaultTransport.RoundTrip(req) res, err := tr.RoundTrip(req)
var body []byte var body []byte
if test.compressed { if test.compressed {
var r *gzip.Reader var r *gzip.Reader
...@@ -792,10 +778,9 @@ func TestTransportGzip(t *testing.T) { ...@@ -792,10 +778,9 @@ func TestTransportGzip(t *testing.T) {
gz.Close() gz.Close()
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, chunked := range []string{"1", "0"} { for _, chunked := range []string{"1", "0"} {
c := &Client{Transport: &Transport{}}
// First fetch something large, but only read some of it. // First fetch something large, but only read some of it.
res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
if err != nil { if err != nil {
...@@ -845,7 +830,6 @@ func TestTransportGzip(t *testing.T) { ...@@ -845,7 +830,6 @@ func TestTransportGzip(t *testing.T) {
} }
// And a HEAD request too, because they're always weird. // And a HEAD request too, because they're always weird.
c := &Client{Transport: &Transport{}}
res, err := c.Head(ts.URL) res, err := c.Head(ts.URL)
if err != nil { if err != nil {
t.Fatalf("Head: %v", err) t.Fatalf("Head: %v", err)
...@@ -915,11 +899,13 @@ func TestTransportExpect100Continue(t *testing.T) { ...@@ -915,11 +899,13 @@ func TestTransportExpect100Continue(t *testing.T) {
{path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
} }
c := ts.Client()
for i, v := range tests { for i, v := range tests {
tr := &Transport{ExpectContinueTimeout: 2 * time.Second} tr := &Transport{
ExpectContinueTimeout: 2 * time.Second,
}
defer tr.CloseIdleConnections() defer tr.CloseIdleConnections()
c := &Client{Transport: tr} c.Transport = tr
body := bytes.NewReader(v.body) body := bytes.NewReader(v.body)
req, err := NewRequest("PUT", ts.URL+v.path, body) req, err := NewRequest("PUT", ts.URL+v.path, body)
if err != nil { if err != nil {
...@@ -1016,7 +1002,8 @@ func TestSocks5Proxy(t *testing.T) { ...@@ -1016,7 +1002,8 @@ func TestSocks5Proxy(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} c := ts.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
if _, err := c.Head(ts.URL); err != nil { if _, err := c.Head(ts.URL); err != nil {
t.Error(err) t.Error(err)
} }
...@@ -1052,7 +1039,8 @@ func TestTransportProxy(t *testing.T) { ...@@ -1052,7 +1039,8 @@ func TestTransportProxy(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} c := ts.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
if _, err := c.Head(ts.URL); err != nil { if _, err := c.Head(ts.URL); err != nil {
t.Error(err) t.Error(err)
} }
...@@ -1122,9 +1110,7 @@ func TestTransportGzipRecursive(t *testing.T) { ...@@ -1122,9 +1110,7 @@ func TestTransportGzipRecursive(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1152,9 +1138,7 @@ func TestTransportGzipShort(t *testing.T) { ...@@ -1152,9 +1138,7 @@ func TestTransportGzipShort(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1195,9 +1179,8 @@ func TestTransportPersistConnLeak(t *testing.T) { ...@@ -1195,9 +1179,8 @@ func TestTransportPersistConnLeak(t *testing.T) {
w.WriteHeader(204) w.WriteHeader(204)
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{} tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
n0 := runtime.NumGoroutine() n0 := runtime.NumGoroutine()
...@@ -1260,9 +1243,8 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { ...@@ -1260,9 +1243,8 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{} tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
n0 := runtime.NumGoroutine() n0 := runtime.NumGoroutine()
body := []byte("Hello") body := []byte("Hello")
...@@ -1294,8 +1276,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { ...@@ -1294,8 +1276,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
// This used to crash; https://golang.org/issue/3266 // This used to crash; https://golang.org/issue/3266
func TestTransportIdleConnCrash(t *testing.T) { func TestTransportIdleConnCrash(t *testing.T) {
defer afterTest(t) defer afterTest(t)
tr := &Transport{} var tr *Transport
c := &Client{Transport: tr}
unblockCh := make(chan bool, 1) unblockCh := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
...@@ -1303,6 +1284,8 @@ func TestTransportIdleConnCrash(t *testing.T) { ...@@ -1303,6 +1284,8 @@ func TestTransportIdleConnCrash(t *testing.T) {
tr.CloseIdleConnections() tr.CloseIdleConnections()
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr = c.Transport.(*Transport)
didreq := make(chan bool) didreq := make(chan bool)
go func() { go func() {
...@@ -1332,8 +1315,7 @@ func TestIssue3644(t *testing.T) { ...@@ -1332,8 +1315,7 @@ func TestIssue3644(t *testing.T) {
} }
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
c := &Client{Transport: tr}
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -1358,8 +1340,7 @@ func TestIssue3595(t *testing.T) { ...@@ -1358,8 +1340,7 @@ func TestIssue3595(t *testing.T) {
Error(w, deniedMsg, StatusUnauthorized) Error(w, deniedMsg, StatusUnauthorized)
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
c := &Client{Transport: tr}
res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
if err != nil { if err != nil {
t.Errorf("Post: %v", err) t.Errorf("Post: %v", err)
...@@ -1383,8 +1364,8 @@ func TestChunkedNoContent(t *testing.T) { ...@@ -1383,8 +1364,8 @@ func TestChunkedNoContent(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
for _, closeBody := range []bool{true, false} { for _, closeBody := range []bool{true, false} {
c := &Client{Transport: &Transport{}}
const n = 4 const n = 4
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
...@@ -1424,10 +1405,7 @@ func TestTransportConcurrency(t *testing.T) { ...@@ -1424,10 +1405,7 @@ func TestTransportConcurrency(t *testing.T) {
SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
defer SetPendingDialHooks(nil, nil) defer SetPendingDialHooks(nil, nil)
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
reqs := make(chan string) reqs := make(chan string)
defer close(reqs) defer close(reqs)
...@@ -1469,23 +1447,20 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { ...@@ -1469,23 +1447,20 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
io.Copy(w, neverEnding('a')) io.Copy(w, neverEnding('a'))
}) })
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
defer ts.Close()
timeout := 100 * time.Millisecond timeout := 100 * time.Millisecond
client := &Client{ c := ts.Client()
Transport: &Transport{ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
Dial: func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr)
conn, err := net.Dial(n, addr) if err != nil {
if err != nil { return nil, err
return nil, err }
} conn.SetDeadline(time.Now().Add(timeout))
conn.SetDeadline(time.Now().Add(timeout)) if debug {
if debug { conn = NewLoggingConn("client", conn)
conn = NewLoggingConn("client", conn) }
} return conn, nil
return conn, nil
},
DisableKeepAlives: true,
},
} }
getFailed := false getFailed := false
...@@ -1497,7 +1472,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { ...@@ -1497,7 +1472,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
if debug { if debug {
println("run", i+1, "of", nRuns) println("run", i+1, "of", nRuns)
} }
sres, err := client.Get(ts.URL + "/get") sres, err := c.Get(ts.URL + "/get")
if err != nil { if err != nil {
if !getFailed { if !getFailed {
// Make the timeout longer, once. // Make the timeout longer, once.
...@@ -1519,7 +1494,6 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { ...@@ -1519,7 +1494,6 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
if debug { if debug {
println("tests complete; waiting for handlers to finish") println("tests complete; waiting for handlers to finish")
} }
ts.Close()
} }
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
...@@ -1537,21 +1511,17 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ...@@ -1537,21 +1511,17 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
timeout := 100 * time.Millisecond timeout := 100 * time.Millisecond
client := &Client{ c := ts.Client()
Transport: &Transport{ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
Dial: func(n, addr string) (net.Conn, error) { conn, err := net.Dial(n, addr)
conn, err := net.Dial(n, addr) if err != nil {
if err != nil { return nil, err
return nil, err }
} conn.SetDeadline(time.Now().Add(timeout))
conn.SetDeadline(time.Now().Add(timeout)) if debug {
if debug { conn = NewLoggingConn("client", conn)
conn = NewLoggingConn("client", conn) }
} return conn, nil
return conn, nil
},
DisableKeepAlives: true,
},
} }
getFailed := false getFailed := false
...@@ -1563,7 +1533,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ...@@ -1563,7 +1533,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
if debug { if debug {
println("run", i+1, "of", nRuns) println("run", i+1, "of", nRuns)
} }
sres, err := client.Get(ts.URL + "/get") sres, err := c.Get(ts.URL + "/get")
if err != nil { if err != nil {
if !getFailed { if !getFailed {
// Make the timeout longer, once. // Make the timeout longer, once.
...@@ -1577,7 +1547,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ...@@ -1577,7 +1547,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
break break
} }
req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
_, err = client.Do(req) _, err = c.Do(req)
if err == nil { if err == nil {
sres.Body.Close() sres.Body.Close()
t.Errorf("Unexpected successful PUT") t.Errorf("Unexpected successful PUT")
...@@ -1609,11 +1579,8 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { ...@@ -1609,11 +1579,8 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
ts := httptest.NewServer(mux) ts := httptest.NewServer(mux)
defer ts.Close() defer ts.Close()
tr := &Transport{ c := ts.Client()
ResponseHeaderTimeout: 500 * time.Millisecond, c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
tests := []struct { tests := []struct {
path string path string
...@@ -1680,9 +1647,8 @@ func TestTransportCancelRequest(t *testing.T) { ...@@ -1680,9 +1647,8 @@ func TestTransportCancelRequest(t *testing.T) {
defer ts.Close() defer ts.Close()
defer close(unblockc) defer close(unblockc)
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
res, err := c.Do(req) res, err := c.Do(req)
...@@ -1790,9 +1756,8 @@ func TestCancelRequestWithChannel(t *testing.T) { ...@@ -1790,9 +1756,8 @@ func TestCancelRequestWithChannel(t *testing.T) {
defer ts.Close() defer ts.Close()
defer close(unblockc) defer close(unblockc)
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
ch := make(chan struct{}) ch := make(chan struct{})
...@@ -1849,9 +1814,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { ...@@ -1849,9 +1814,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
defer ts.Close() defer ts.Close()
defer close(unblockc) defer close(unblockc)
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
if withCtx { if withCtx {
...@@ -1939,9 +1902,8 @@ func TestTransportCloseResponseBody(t *testing.T) { ...@@ -1939,9 +1902,8 @@ func TestTransportCloseResponseBody(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections() tr := c.Transport.(*Transport)
c := &Client{Transport: tr}
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
defer tr.CancelRequest(req) defer tr.CancelRequest(req)
...@@ -2061,18 +2023,12 @@ func TestTransportSocketLateBinding(t *testing.T) { ...@@ -2061,18 +2023,12 @@ func TestTransportSocketLateBinding(t *testing.T) {
defer ts.Close() defer ts.Close()
dialGate := make(chan bool, 1) dialGate := make(chan bool, 1)
tr := &Transport{ c := ts.Client()
Dial: func(n, addr string) (net.Conn, error) { c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
if <-dialGate { if <-dialGate {
return net.Dial(n, addr) return net.Dial(n, addr)
} }
return nil, errors.New("manually closed") return nil, errors.New("manually closed")
},
DisableKeepAlives: false,
}
defer tr.CloseIdleConnections()
c := &Client{
Transport: tr,
} }
dialGate <- true // only allow one dial dialGate <- true // only allow one dial
...@@ -2326,14 +2282,11 @@ func TestIdleConnChannelLeak(t *testing.T) { ...@@ -2326,14 +2282,11 @@ func TestIdleConnChannelLeak(t *testing.T) {
SetReadLoopBeforeNextReadHook(func() { didRead <- true }) SetReadLoopBeforeNextReadHook(func() { didRead <- true })
defer SetReadLoopBeforeNextReadHook(nil) defer SetReadLoopBeforeNextReadHook(nil)
tr := &Transport{ c := ts.Client()
Dial: func(netw, addr string) (net.Conn, error) { tr := c.Transport.(*Transport)
return net.Dial(netw, ts.Listener.Addr().String()) tr.Dial = func(netw, addr string) (net.Conn, error) {
}, return net.Dial(netw, ts.Listener.Addr().String())
} }
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
// First, without keep-alives. // First, without keep-alives.
for _, disableKeep := range []bool{true, false} { for _, disableKeep := range []bool{true, false} {
...@@ -2376,13 +2329,11 @@ func TestTransportClosesRequestBody(t *testing.T) { ...@@ -2376,13 +2329,11 @@ func TestTransportClosesRequestBody(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
cl := &Client{Transport: tr}
closes := 0 closes := 0
res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -2468,20 +2419,16 @@ func TestTLSServerClosesConnection(t *testing.T) { ...@@ -2468,20 +2419,16 @@ func TestTLSServerClosesConnection(t *testing.T) {
fmt.Fprintf(w, "hello") fmt.Fprintf(w, "hello")
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{ c := ts.Client()
InsecureSkipVerify: true, tr := c.Transport.(*Transport)
},
}
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
var nSuccess = 0 var nSuccess = 0
var errs []error var errs []error
const trials = 20 const trials = 20
for i := 0; i < trials; i++ { for i := 0; i < trials; i++ {
tr.CloseIdleConnections() tr.CloseIdleConnections()
res, err := client.Get(ts.URL + "/keep-alive-then-die") res, err := c.Get(ts.URL + "/keep-alive-then-die")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -2496,7 +2443,7 @@ func TestTLSServerClosesConnection(t *testing.T) { ...@@ -2496,7 +2443,7 @@ func TestTLSServerClosesConnection(t *testing.T) {
// Now try again and see if we successfully // Now try again and see if we successfully
// pick a new connection. // pick a new connection.
res, err = client.Get(ts.URL + "/") res, err = c.Get(ts.URL + "/")
if err != nil { if err != nil {
errs = append(errs, err) errs = append(errs, err)
continue continue
...@@ -2575,22 +2522,20 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { ...@@ -2575,22 +2522,20 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
go io.Copy(ioutil.Discard, conn) go io.Copy(ioutil.Discard, conn)
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
const bodySize = 256 << 10 const bodySize = 256 << 10
finalBit := make(byteFromChanReader, 1) finalBit := make(byteFromChanReader, 1)
req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
req.ContentLength = bodySize req.ContentLength = bodySize
res, err := client.Do(req) res, err := c.Do(req)
if err := wantBody(res, err, "foo"); err != nil { if err := wantBody(res, err, "foo"); err != nil {
t.Errorf("POST response: %v", err) t.Errorf("POST response: %v", err)
} }
donec := make(chan bool) donec := make(chan bool)
go func() { go func() {
defer close(donec) defer close(donec)
res, err = client.Get(ts.URL) res, err = c.Get(ts.URL)
if err := wantBody(res, err, "bar"); err != nil { if err := wantBody(res, err, "bar"); err != nil {
t.Errorf("GET response: %v", err) t.Errorf("GET response: %v", err)
return return
...@@ -2622,10 +2567,9 @@ func TestTransportIssue10457(t *testing.T) { ...@@ -2622,10 +2567,9 @@ func TestTransportIssue10457(t *testing.T) {
conn.Close() conn.Close()
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{} c := ts.Client()
defer tr.CloseIdleConnections()
cl := &Client{Transport: tr} res, err := c.Get(ts.URL)
res, err := cl.Get(ts.URL)
if err != nil { if err != nil {
t.Fatalf("Get: %v", err) t.Fatalf("Get: %v", err)
} }
...@@ -2686,29 +2630,26 @@ func TestRetryIdempotentRequestsOnError(t *testing.T) { ...@@ -2686,29 +2630,26 @@ func TestRetryIdempotentRequestsOnError(t *testing.T) {
defer ts.Close() defer ts.Close()
var writeNumAtomic int32 var writeNumAtomic int32
tr := &Transport{ c := ts.Client()
Dial: func(network, addr string) (net.Conn, error) { c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
logf("Dial") logf("Dial")
c, err := net.Dial(network, ts.Listener.Addr().String()) c, err := net.Dial(network, ts.Listener.Addr().String())
if err != nil { if err != nil {
logf("Dial error: %v", err) logf("Dial error: %v", err)
return nil, err return nil, err
} }
return &writerFuncConn{ return &writerFuncConn{
Conn: c, Conn: c,
write: func(p []byte) (n int, err error) { write: func(p []byte) (n int, err error) {
if atomic.AddInt32(&writeNumAtomic, 1) == 2 { if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
logf("intentional write failure") logf("intentional write failure")
return 0, errors.New("second write fails") return 0, errors.New("second write fails")
} }
logf("Write(%q)", p) logf("Write(%q)", p)
return c.Write(p) return c.Write(p)
}, },
}, nil }, nil
},
} }
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
SetRoundTripRetried(func() { SetRoundTripRetried(func() {
logf("Retried.") logf("Retried.")
...@@ -2752,6 +2693,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { ...@@ -2752,6 +2693,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
readBody <- err readBody <- err
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
fakeErr := errors.New("fake error") fakeErr := errors.New("fake error")
didClose := make(chan bool, 1) didClose := make(chan bool, 1)
req, _ := NewRequest("POST", ts.URL, struct { req, _ := NewRequest("POST", ts.URL, struct {
...@@ -2767,7 +2709,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { ...@@ -2767,7 +2709,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
return nil return nil
}), }),
}) })
res, err := DefaultClient.Do(req) res, err := c.Do(req)
if res != nil { if res != nil {
defer res.Body.Close() defer res.Body.Close()
} }
...@@ -2801,23 +2743,19 @@ func TestTransportDialTLS(t *testing.T) { ...@@ -2801,23 +2743,19 @@ func TestTransportDialTLS(t *testing.T) {
mu.Unlock() mu.Unlock()
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{ c := ts.Client()
DialTLS: func(netw, addr string) (net.Conn, error) { c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
mu.Lock() mu.Lock()
didDial = true didDial = true
mu.Unlock() mu.Unlock()
c, err := tls.Dial(netw, addr, &tls.Config{ c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
InsecureSkipVerify: true, if err != nil {
}) return nil, err
if err != nil { }
return nil, err return c, c.Handshake()
}
return c, c.Handshake()
},
} }
defer tr.CloseIdleConnections()
client := &Client{Transport: tr} res, err := c.Get(ts.URL)
res, err := client.Get(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -2899,10 +2837,11 @@ func TestTransportRangeAndGzip(t *testing.T) { ...@@ -2899,10 +2837,11 @@ func TestTransportRangeAndGzip(t *testing.T) {
reqc <- r reqc <- r
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
req, _ := NewRequest("GET", ts.URL, nil) req, _ := NewRequest("GET", ts.URL, nil)
req.Header.Set("Range", "bytes=7-11") req.Header.Set("Range", "bytes=7-11")
res, err := DefaultClient.Do(req) res, err := c.Do(req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -2931,9 +2870,7 @@ func TestTransportResponseCancelRace(t *testing.T) { ...@@ -2931,9 +2870,7 @@ func TestTransportResponseCancelRace(t *testing.T) {
w.Write(b[:]) w.Write(b[:])
})) }))
defer ts.Close() defer ts.Close()
tr := ts.Client().Transport.(*Transport)
tr := &Transport{}
defer tr.CloseIdleConnections()
req, err := NewRequest("GET", ts.URL, nil) req, err := NewRequest("GET", ts.URL, nil)
if err != nil { if err != nil {
...@@ -2967,9 +2904,7 @@ func TestTransportDialCancelRace(t *testing.T) { ...@@ -2967,9 +2904,7 @@ func TestTransportDialCancelRace(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close() defer ts.Close()
tr := ts.Client().Transport.(*Transport)
tr := &Transport{}
defer tr.CloseIdleConnections()
req, err := NewRequest("GET", ts.URL, nil) req, err := NewRequest("GET", ts.URL, nil)
if err != nil { if err != nil {
...@@ -3096,6 +3031,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { ...@@ -3096,6 +3031,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) {
w.WriteHeader(StatusOK) w.WriteHeader(StatusOK)
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
fail := 0 fail := 0
count := 100 count := 100
...@@ -3105,10 +3041,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { ...@@ -3105,10 +3041,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tr := new(Transport) resp, err := c.Do(req)
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
resp, err := client.Do(req)
if err != nil { if err != nil {
fail++ fail++
t.Logf("%d = %#v", i, err) t.Logf("%d = %#v", i, err)
...@@ -3321,10 +3254,8 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { ...@@ -3321,10 +3254,8 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
w.Write(rgz) // arbitrary gzip response w.Write(rgz) // arbitrary gzip response
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
res, err := c.Get(ts.URL) res, err := c.Get(ts.URL)
if err != nil { if err != nil {
...@@ -3353,12 +3284,9 @@ func TestTransportResponseHeaderLength(t *testing.T) { ...@@ -3353,12 +3284,9 @@ func TestTransportResponseHeaderLength(t *testing.T) {
} }
})) }))
defer ts.Close() defer ts.Close()
c := ts.Client()
c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
tr := &Transport{
MaxResponseHeaderBytes: 512 << 10,
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
if res, err := c.Get(ts.URL); err != nil { if res, err := c.Get(ts.URL); err != nil {
t.Fatal(err) t.Fatal(err)
} else { } else {
...@@ -3619,8 +3547,8 @@ func TestTransportRejectsAlphaPort(t *testing.T) { ...@@ -3619,8 +3547,8 @@ func TestTransportRejectsAlphaPort(t *testing.T) {
// connections. The http2 test is done in TestTransportEventTrace_h2 // connections. The http2 test is done in TestTransportEventTrace_h2
func TestTLSHandshakeTrace(t *testing.T) { func TestTLSHandshakeTrace(t *testing.T) {
defer afterTest(t) defer afterTest(t)
s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer s.Close() defer ts.Close()
var mu sync.Mutex var mu sync.Mutex
var start, done bool var start, done bool
...@@ -3640,10 +3568,8 @@ func TestTLSHandshakeTrace(t *testing.T) { ...@@ -3640,10 +3568,8 @@ func TestTLSHandshakeTrace(t *testing.T) {
}, },
} }
tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} c := ts.Client()
defer tr.CloseIdleConnections() req, err := NewRequest("GET", ts.URL, nil)
c := &Client{Transport: tr}
req, err := NewRequest("GET", s.URL, nil)
if err != nil { if err != nil {
t.Fatal("Unable to construct test request:", err) t.Fatal("Unable to construct test request:", err)
} }
...@@ -3670,16 +3596,14 @@ func TestTransportMaxIdleConns(t *testing.T) { ...@@ -3670,16 +3596,14 @@ func TestTransportMaxIdleConns(t *testing.T) {
// No body for convenience. // No body for convenience.
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{ c := ts.Client()
MaxIdleConns: 4, tr := c.Transport.(*Transport)
} tr.MaxIdleConns = 4
defer tr.CloseIdleConnections()
ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c := &Client{Transport: tr}
ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
}) })
...@@ -3975,17 +3899,16 @@ func TestTransportProxyConnectHeader(t *testing.T) { ...@@ -3975,17 +3899,16 @@ func TestTransportProxyConnectHeader(t *testing.T) {
c.Close() c.Close()
})) }))
defer ts.Close() defer ts.Close()
tr := &Transport{
ProxyConnectHeader: Header{ c := ts.Client()
"User-Agent": {"foo"}, c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
"Other": {"bar"}, return url.Parse(ts.URL)
},
Proxy: func(r *Request) (*url.URL, error) {
return url.Parse(ts.URL)
},
} }
defer tr.CloseIdleConnections() c.Transport.(*Transport).ProxyConnectHeader = Header{
c := &Client{Transport: tr} "User-Agent": {"foo"},
"Other": {"bar"},
}
res, err := c.Get("https://dummy.tld/") // https to force a CONNECT res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
if err == nil { if err == nil {
res.Body.Close() res.Body.Close()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment