Commit 8314ff28 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)

Merge branch 'strict-roundtripper' into 'master'

Strict roundtripper

Workhorse usually connects to Rails over a Unix socket. This makes it
impossible to accidentally follow redirects to another host. This
change applies the same strictness when connecting to Rails over TCP.


See merge request !56
parents 56af46b9 6ed75e25
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
) )
...@@ -27,7 +28,8 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -27,7 +28,8 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil) parsedURL := helper.URLMustParse(ts.URL)
a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
response := httptest.NewRecorder() response := httptest.NewRecorder()
a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest) a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest)
......
...@@ -20,9 +20,6 @@ type API struct { ...@@ -20,9 +20,6 @@ type API struct {
} }
func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API { func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API {
if roundTripper == nil {
roundTripper = badgateway.NewRoundTripper("", 0)
}
return &API{ return &API{
Client: &http.Client{Transport: roundTripper}, Client: &http.Client{Transport: roundTripper},
URL: myURL, URL: myURL,
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy" "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
...@@ -90,8 +91,10 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h ...@@ -90,8 +91,10 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h
} }
httpRequest.Header.Set("Content-Type", contentType) httpRequest.Header.Set("Content-Type", contentType)
response := httptest.NewRecorder() response := httptest.NewRecorder()
apiClient := api.NewAPI(helper.URLMustParse(ts.URL), "123", nil) parsedURL := helper.URLMustParse(ts.URL)
proxyClient := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil) roundTripper := badgateway.TestRoundTripper(parsedURL)
apiClient := api.NewAPI(parsedURL, "123", roundTripper)
proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper)
UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest) UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest)
return response return response
} }
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url"
"time" "time"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
...@@ -27,18 +28,45 @@ type RoundTripper struct { ...@@ -27,18 +28,45 @@ type RoundTripper struct {
Transport *http.Transport Transport *http.Transport
} }
func NewRoundTripper(socket string, proxyHeadersTimeout time.Duration) *RoundTripper { func TestRoundTripper(backend *url.URL) *RoundTripper {
return NewRoundTripper(backend, "", 0)
}
func NewRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration) *RoundTripper {
tr := *DefaultTransport tr := *DefaultTransport
tr.ResponseHeaderTimeout = proxyHeadersTimeout tr.ResponseHeaderTimeout = proxyHeadersTimeout
if socket != "" { if backend != nil && socket == "" {
address := mustParseAddress(backend.Host, backend.Scheme)
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("tcp", address)
}
} else if socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) { tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", socket) return DefaultDialer.Dial("unix", socket)
} }
} else {
panic("backend is nil and socket is empty")
} }
return &RoundTripper{Transport: &tr} return &RoundTripper{Transport: &tr}
} }
func mustParseAddress(address, scheme string) string {
if scheme == "https" {
panic("TLS is not supported for backend connections")
}
for _, suffix := range []string{"", ":" + scheme} {
address += suffix
if host, port, err := net.SplitHostPort(address); err == nil && host != "" && port != "" {
return host + ":" + port
}
}
panic("could not parse host:port from address and scheme")
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = t.Transport.RoundTrip(r) res, err = t.Transport.RoundTrip(r)
......
package badgateway
import (
"testing"
)
func TestMustParseAddress(t *testing.T) {
successExamples := []struct{ address, scheme, expected string }{
{"1.2.3.4:56", "http", "1.2.3.4:56"},
{"[::1]:23", "http", "::1:23"},
{"4.5.6.7", "http", "4.5.6.7:http"},
}
for _, example := range successExamples {
result := mustParseAddress(example.address, example.scheme)
if example.expected != result {
t.Errorf("expected %q, got %q", example.expected, result)
}
}
panicExamples := []struct{ address, scheme string }{
{"1.2.3.4", ""},
{"1.2.3.4", "https"},
}
for _, panicExample := range panicExamples {
func() {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic for %v but none occurred", panicExample)
}
}()
t.Log(mustParseAddress(panicExample.address, panicExample.scheme))
}()
}
}
...@@ -21,11 +21,7 @@ func NewProxy(myURL *url.URL, version string, roundTripper *badgateway.RoundTrip ...@@ -21,11 +21,7 @@ func NewProxy(myURL *url.URL, version string, roundTripper *badgateway.RoundTrip
u := *myURL // Make a copy of p.URL u := *myURL // Make a copy of p.URL
u.Path = "" u.Path = ""
p.reverseProxy = httputil.NewSingleHostReverseProxy(&u) p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
if roundTripper != nil { p.reverseProxy.Transport = roundTripper
p.reverseProxy.Transport = roundTripper
} else {
p.reverseProxy.Transport = badgateway.NewRoundTripper("", 0)
}
return &p return &p
} }
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"strings" "strings"
"testing" "testing"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy" "gitlab.com/gitlab-org/gitlab-workhorse/internal/proxy"
"gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
...@@ -75,7 +76,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -75,7 +76,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil) handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, nil) HandleFileUploads(response, httpRequest, handler, tempPath, nil)
testhelper.AssertResponseCode(t, response, 202) testhelper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
...@@ -149,7 +150,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -149,7 +150,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
response := httptest.NewRecorder() response := httptest.NewRecorder()
handler := proxy.NewProxy(helper.URLMustParse(ts.URL), "123", nil) handler := newProxy(ts.URL)
HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{}) HandleFileUploads(response, httpRequest, handler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 202) testhelper.AssertResponseCode(t, response, 202)
...@@ -209,3 +210,8 @@ func TestUploadProcessingFile(t *testing.T) { ...@@ -209,3 +210,8 @@ func TestUploadProcessingFile(t *testing.T) {
HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{}) HandleFileUploads(response, httpRequest, nilHandler, tempPath, &testFormProcessor{})
testhelper.AssertResponseCode(t, response, 500) testhelper.AssertResponseCode(t, response, 500)
} }
func newProxy(url string) *proxy.Proxy {
parsedURL := helper.URLMustParse(url)
return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL))
}
...@@ -37,11 +37,11 @@ func NewUpstream(backend *url.URL, socket string, version string, documentRoot s ...@@ -37,11 +37,11 @@ func NewUpstream(backend *url.URL, socket string, version string, documentRoot s
Version: version, Version: version,
DocumentRoot: documentRoot, DocumentRoot: documentRoot,
DevelopmentMode: developmentMode, DevelopmentMode: developmentMode,
RoundTripper: badgateway.NewRoundTripper(socket, proxyHeadersTimeout),
} }
if backend == nil { if backend == nil {
up.Backend = DefaultBackend up.Backend = DefaultBackend
} }
up.RoundTripper = badgateway.NewRoundTripper(up.Backend, socket, proxyHeadersTimeout)
up.configureURLPrefix() up.configureURLPrefix()
up.configureRoutes() up.configureRoutes()
return &up return &up
......
...@@ -21,7 +21,11 @@ import ( ...@@ -21,7 +21,11 @@ import (
const testVersion = "123" const testVersion = "123"
func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy { func newProxy(url string, rt *badgateway.RoundTripper) *proxy.Proxy {
return proxy.NewProxy(helper.URLMustParse(url), testVersion, rt) parsedURL := helper.URLMustParse(url)
if rt == nil {
rt = badgateway.TestRoundTripper(parsedURL)
}
return proxy.NewProxy(parsedURL, testVersion, rt)
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment