Commit 131dd305 authored by Catalin Irimie's avatar Catalin Irimie

Force Host header rewrite in Workhorse for Geo proxying

By default, the httputil reverse proxy does not change the Host header,
resulting in the upstream server receiving the Host header of the
current request, instead of the "expected" target (url) host.

Related to: https://github.com/golang/go/issues/28168

Changelog: changed
EE: true
parent 59b0cda6
...@@ -19,6 +19,7 @@ type Proxy struct { ...@@ -19,6 +19,7 @@ type Proxy struct {
reverseProxy *httputil.ReverseProxy reverseProxy *httputil.ReverseProxy
AllowResponseBuffering bool AllowResponseBuffering bool
customHeaders map[string]string customHeaders map[string]string
forceTargetHostHeader bool
} }
func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) {
...@@ -27,6 +28,12 @@ func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) { ...@@ -27,6 +28,12 @@ func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) {
} }
} }
func WithForcedTargetHostHeader() func(*Proxy) {
return func(proxy *Proxy) {
proxy.forceTargetHostHeader = true
}
}
func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy { func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy {
p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)} p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)}
...@@ -43,6 +50,17 @@ func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, op ...@@ -43,6 +50,17 @@ func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, op
option(&p) option(&p)
} }
if p.forceTargetHostHeader {
// because of https://github.com/golang/go/issues/28168, the
// upstream won't receive the expected Host header unless this
// is forced in the Director func here
previousDirector := p.reverseProxy.Director
p.reverseProxy.Director = func(request *http.Request) {
previousDirector(request)
request.Host = request.URL.Host
}
}
return &p return &p
} }
......
...@@ -243,6 +243,7 @@ func (u *upstream) updateGeoProxyFields(geoProxyURL *url.URL) { ...@@ -243,6 +243,7 @@ func (u *upstream) updateGeoProxyFields(geoProxyURL *url.URL) {
u.Version, u.Version,
geoProxyRoundTripper, geoProxyRoundTripper,
proxypkg.WithCustomHeaders(geoProxyWorkhorseHeaders), proxypkg.WithCustomHeaders(geoProxyWorkhorseHeaders),
proxypkg.WithForcedTargetHostHeader(),
) )
u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream) u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream)
u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy()) u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy())
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"regexp" "regexp"
"testing" "testing"
"time" "time"
...@@ -31,10 +32,15 @@ func newProxy(url string, rt http.RoundTripper, opts ...func(*proxy.Proxy)) *pro ...@@ -31,10 +32,15 @@ func newProxy(url string, rt http.RoundTripper, opts ...func(*proxy.Proxy)) *pro
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { inboundURL, err := url.Parse("https://explicitly.set.host/url/path")
require.NoError(t, err, "parse inbound url")
urlRegexp := regexp.MustCompile(fmt.Sprintf(`%s\z`, inboundURL.Path))
ts := testhelper.TestServerWithHandler(urlRegexp, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "POST", r.Method, "method") require.Equal(t, "POST", r.Method, "method")
require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header") require.Equal(t, "test", r.Header.Get("Custom-Header"), "custom header")
require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header") require.Equal(t, testVersion, r.Header.Get("Gitlab-Workhorse"), "version header")
require.Equal(t, inboundURL.Host, r.Host, "sent host header")
require.Regexp( require.Regexp(
t, t,
...@@ -52,7 +58,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -52,7 +58,7 @@ func TestProxyRequest(t *testing.T) {
fmt.Fprint(w, "RESPONSE") fmt.Fprint(w, "RESPONSE")
}) })
httpRequest, err := http.NewRequest("POST", ts.URL+"/url/path", bytes.NewBufferString("REQUEST")) httpRequest, err := http.NewRequest("POST", inboundURL.String(), bytes.NewBufferString("REQUEST"))
require.NoError(t, err) require.NoError(t, err)
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
...@@ -64,6 +70,30 @@ func TestProxyRequest(t *testing.T) { ...@@ -64,6 +70,30 @@ func TestProxyRequest(t *testing.T) {
require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header") require.Equal(t, "test", w.Header().Get("Custom-Response-Header"), "custom response header")
} }
func TestProxyWithForcedTargetHostHeader(t *testing.T) {
var tsUrl *url.URL
inboundURL, err := url.Parse("https://explicitly.set.host/url/path")
require.NoError(t, err, "parse upstream url")
urlRegexp := regexp.MustCompile(fmt.Sprintf(`%s\z`, inboundURL.Path))
ts := testhelper.TestServerWithHandler(urlRegexp, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, tsUrl.Host, r.Host, "upstream host header")
_, err := w.Write([]byte(`ok`))
require.NoError(t, err, "write ok response")
})
tsUrl, err = url.Parse(ts.URL)
require.NoError(t, err, "parse testserver URL")
httpRequest, err := http.NewRequest("POST", inboundURL.String(), nil)
require.NoError(t, err)
w := httptest.NewRecorder()
testProxy := newProxy(ts.URL, nil, proxy.WithForcedTargetHostHeader())
testProxy.ServeHTTP(w, httpRequest)
testhelper.RequireResponseBody(t, w, "ok")
}
func TestProxyWithCustomHeaders(t *testing.T) { func TestProxyWithCustomHeaders(t *testing.T) {
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := testhelper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "value", r.Header.Get("Custom-Header"), "custom proxy header") require.Equal(t, "value", r.Header.Get("Custom-Header"), "custom proxy header")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment