Commit 0bae36ad authored by Jacob Vosmaer's avatar Jacob Vosmaer

Hard-code backend dialer for TCP too

Workhorse usually connects to Rails over a Unix socket. This makes it
impossible to accidentally follow redirects to another host. This
change applies the same strictness when connecting to Rails over TCP.
parent 7e6fdf11
...@@ -28,7 +28,8 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -28,7 +28,8 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper) parsedURL := helper.URLMustParse(ts.URL)
a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
response := httptest.NewRecorder() response := httptest.NewRecorder()
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest) a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
......
...@@ -91,8 +91,10 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h ...@@ -91,8 +91,10 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h
} }
httpRequest.Header.Set("Content-Type", contentType) httpRequest.Header.Set("Content-Type", contentType)
response := httptest.NewRecorder() response := httptest.NewRecorder()
apiClient := api.NewAPI(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper) parsedURL := helper.URLMustParse(ts.URL)
proxyClient := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper) roundTripper := badgateway.TestRoundTripper(parsedURL)
apiClient := api.NewAPI(parsedURL, "123", roundTripper)
proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest) UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response return response
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url"
"time" "time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
...@@ -23,24 +24,47 @@ var DefaultTransport = &http.Transport{ ...@@ -23,24 +24,47 @@ var DefaultTransport = &http.Transport{
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
} }
var TestRoundTripper = NewRoundTripper("", 0)
type RoundTripper struct { type RoundTripper struct {
Transport *http.Transport Transport *http.Transport
} }
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper { func TestRoundTripper(backend *url.URL) *RoundTripper {
return NewRoundTripper(backend, "", 0)
}
func NewRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" { if backend != nil && socket == "" {
address := mustParseAddress(backend.Host, backend.Scheme)
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("tcp", address)
}
} else if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) { tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket) return DefaultDialer.Dial("unix", socket)
} }
} else {
panic("backend is nil and socket is empty")
} }
return &RoundTripper{Transport: &tr} return &RoundTripper{Transport: &tr}
} }
func mustParseAddress(address, scheme string) string {
if host, port, err := net.SplitHostPort(address); err == nil {
return host + ":" + port
}
address = fmt.Sprintf("%s:%s", address, scheme)
if host, port, err := net.SplitHostPort(address); err == nil {
return host + ":" + port
}
panic("could not parse host/port from addres / scheme")
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r) res, err = t.Transport.RoundTrip(r)
......
...@@ -76,7 +76,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -76,7 +76,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper) handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, nil) HandleFileUploads(response, httpRequest, handler, tempPath, nil)
testhelper.AssertResponseCode(t, response, 202) testhelper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
...@@ -150,7 +150,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", badgateway.TestRoundTripper) handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{}) HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 202) testhelper.AssertResponseCode(t, response, 202)
...@@ -210,3 +210,8 @@ func TestUploadProcessingFile(t *testing.T) { ...@@ -210,3 +210,8 @@ func TestUploadProcessingFile(t *testing.T) {
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{}) HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 500) testhelper.AssertResponseCode(t, response, 500)
} }
func newProxy(url string) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
}
...@@ -37,11 +37,11 @@ func NewUpstream(backend *url.URL, socket string, version string, documentRoot s ...@@ -37,11 +37,11 @@ func NewUpstream(backend *url.URL, socket string, version string, documentRoot s
Version: version, Version: version,
DocumentRoot: documentRoot, DocumentRoot: documentRoot,
DevelopmentMode: developmentMode, DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
} }
if backend == nil { if backend == nil {
up.Backend = DefaultBackend up.Backend = DefaultBackend
} }
up.RoundTripper = badgateway.NewRoundTripper(up.Backend, socket, proxyHeadersTimeout)
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() up.configureRoutes()
return &up return &up
......
...@@ -21,10 +21,11 @@ import ( ...@@ -21,10 +21,11 @@ import (
const testVersion = "123" const testVersion = "123"
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy { func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
if rt == nil { if rt == nil {
rt = badgateway.TestRoundTripper rt = badgateway.TestRoundTripper(parsedURL)
} }
return proxy.NewProxy(helper.URLMustParse(url), testVersion, rt) return proxy.NewProxy(parsedURL, testVersion, rt)
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment