Commit c7e3abe4 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Add Proxy type

parent ffa068f5
...@@ -39,7 +39,7 @@ func (api *API) lfsAuthorizeHandler(handleFunc serviceHandleFunc) httpHandleFunc ...@@ -39,7 +39,7 @@ func (api *API) lfsAuthorizeHandler(handleFunc serviceHandleFunc) httpHandleFunc
}, "/authorize") }, "/authorize")
} }
func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *http.Request, a *apiResponse) { func (p *Proxy) handleStoreLfsObject(w http.ResponseWriter, r *http.Request, a *apiResponse) {
file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid) file, err := ioutil.TempFile(a.StoreLFSPath, a.LfsOid)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err)) fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
...@@ -75,5 +75,5 @@ func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *http.Request, ...@@ -75,5 +75,5 @@ func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *http.Request,
r.ContentLength = 0 r.ContentLength = 0
// And proxy the request // And proxy the request
u.proxyRequest(w, r) p.ServeHTTP(w, r)
} }
...@@ -63,12 +63,13 @@ var httpRoutes []httpRoute ...@@ -63,12 +63,13 @@ var httpRoutes []httpRoute
func compileRoutes(u *upstream) { func compileRoutes(u *upstream) {
api := u.API api := u.API
proxy := u.Proxy
httpRoutes = []httpRoute{ httpRoutes = []httpRoute{
// Git Clone // Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), api.repoPreAuthorizeHandler(handleGetInfoRefs)}, httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), api.repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(api.repoPreAuthorizeHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(api.repoPreAuthorizeHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(api.repoPreAuthorizeHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(api.repoPreAuthorizeHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), api.lfsAuthorizeHandler(u.handleStoreLfsObject)}, httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), api.lfsAuthorizeHandler(proxy.handleStoreLfsObject)},
// Repository Archive // Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), api.repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), api.repoPreAuthorizeHandler(handleGetArchive)},
...@@ -85,11 +86,11 @@ func compileRoutes(u *upstream) { ...@@ -85,11 +86,11 @@ func compileRoutes(u *upstream) {
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), api.repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), api.repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API // CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(u.artifactsAuthorizeHandler(u.handleFileUploads))}, httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(api.artifactsAuthorizeHandler(proxy.handleFileUploads))},
// Explicitly proxy API requests // Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), u.proxyRequest}, httpRoute{"", regexp.MustCompile(apiPattern), proxy.ServeHTTP},
httpRoute{"", regexp.MustCompile(ciAPIPattern), u.proxyRequest}, httpRoute{"", regexp.MustCompile(ciAPIPattern), proxy.ServeHTTP},
// Serve assets // Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`), httpRoute{"", regexp.MustCompile(`^/assets/`),
...@@ -97,7 +98,7 @@ func compileRoutes(u *upstream) { ...@@ -97,7 +98,7 @@ func compileRoutes(u *upstream) {
handleDevelopmentMode(developmentMode, handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot, handleDeployPage(documentRoot,
handleRailsError(documentRoot, handleRailsError(documentRoot,
u.proxyRequest, proxy.ServeHTTP,
), ),
), ),
), ),
...@@ -109,7 +110,7 @@ func compileRoutes(u *upstream) { ...@@ -109,7 +110,7 @@ func compileRoutes(u *upstream) {
u.handleServeFile(documentRoot, CacheDisabled, u.handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot, handleDeployPage(documentRoot,
handleRailsError(documentRoot, handleRailsError(documentRoot,
u.proxyRequest, proxy.ServeHTTP,
), ),
), ),
), ),
......
...@@ -51,7 +51,7 @@ func headerClone(h http.Header) http.Header { ...@@ -51,7 +51,7 @@ func headerClone(h http.Header) http.Header {
return h2 return h2
} }
func (u *upstream) proxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request // Clone request
req := *r req := *r
req.Header = headerClone(r.Header) req.Header = headerClone(r.Header)
...@@ -61,5 +61,5 @@ func (u *upstream) proxyRequest(w http.ResponseWriter, r *http.Request) { ...@@ -61,5 +61,5 @@ func (u *upstream) proxyRequest(w http.ResponseWriter, r *http.Request) {
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, &req)
defer rw.Flush() defer rw.Flush()
u.httpProxy.ServeHTTP(&rw, &req) p.ReverseProxy.ServeHTTP(&rw, &req)
} }
...@@ -41,7 +41,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -41,7 +41,7 @@ func TestProxyRequest(t *testing.T) {
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 202) assertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE") assertResponseBody(t, w, "RESPONSE")
...@@ -63,7 +63,7 @@ func TestProxyError(t *testing.T) { ...@@ -63,7 +63,7 @@ func TestProxyError(t *testing.T) {
u := newUpstream("http://localhost:655575/", &transport) u := newUpstream("http://localhost:655575/", &transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575") assertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
...@@ -93,7 +93,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -93,7 +93,7 @@ func TestProxyReadTimeout(t *testing.T) {
u := newUpstream(ts.URL, transport) u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers") assertResponseBody(t, w, "net/http: timeout awaiting response headers")
} }
...@@ -117,7 +117,7 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -117,7 +117,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
u := newUpstream(ts.URL, transport) u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 503) assertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long") assertResponseBody(t, w, "Request took too long")
} }
...@@ -85,7 +85,7 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te ...@@ -85,7 +85,7 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te
return cleanup, nil return cleanup, nil
} }
func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) { func (p *Proxy) handleFileUploads(w http.ResponseWriter, r *http.Request) {
tempPath := r.Header.Get(tempPathHeader) tempPath := r.Header.Get(tempPathHeader)
if tempPath == "" { if tempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty")) fail500(w, errors.New("handleFileUploads: TempPath empty"))
...@@ -101,7 +101,7 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) { ...@@ -101,7 +101,7 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) {
cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath) cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath)
if err != nil { if err != nil {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
u.proxyRequest(w, r) p.ServeHTTP(w, r)
} else { } else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
} }
...@@ -121,5 +121,5 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) { ...@@ -121,5 +121,5 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) {
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
// Proxy the request // Proxy the request
u.proxyRequest(w, r) p.ServeHTTP(w, r)
} }
...@@ -17,7 +17,7 @@ import ( ...@@ -17,7 +17,7 @@ import (
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := &http.Request{} request := &http.Request{}
newUpstream("http://localhost", nil).handleFileUploads(response, request) dummyUpstream.Proxy.handleFileUploads(response, request)
assertResponseCode(t, response, 500) assertResponseCode(t, response, 500)
} }
...@@ -53,7 +53,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, httpRequest) u.Proxy.handleFileUploads(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
...@@ -128,7 +128,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -128,7 +128,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, httpRequest) u.Proxy.handleFileUploads(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
......
...@@ -23,12 +23,16 @@ type API struct { ...@@ -23,12 +23,16 @@ type API struct {
} }
type upstream struct { type upstream struct {
*API API *API
httpProxy *httputil.ReverseProxy Proxy *Proxy
authBackend string authBackend string
relativeURLRoot string relativeURLRoot string
} }
type Proxy struct {
ReverseProxy *httputil.ReverseProxy
}
type apiResponse struct { type apiResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git // GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull' // push' and 'git pull'
...@@ -57,6 +61,15 @@ type apiResponse struct { ...@@ -57,6 +61,15 @@ type apiResponse struct {
TempPath string TempPath string
} }
func newProxy(url *url.URL, transport http.RoundTripper) *Proxy {
// Modify a copy of url
proxyURL := *url
proxyURL.Path = ""
proxy := Proxy{ReverseProxy: httputil.NewSingleHostReverseProxy(&proxyURL)}
proxy.ReverseProxy.Transport = transport
return &proxy
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream { func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
parsedURL, err := url.Parse(authBackend) parsedURL, err := url.Parse(authBackend)
if err != nil { if err != nil {
...@@ -68,17 +81,13 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream ...@@ -68,17 +81,13 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream
relativeURLRoot += "/" relativeURLRoot += "/"
} }
// Modify a copy of parsedURL
proxyURL := *parsedURL
proxyURL.Path = ""
up := &upstream{ up := &upstream{
authBackend: authBackend, authBackend: authBackend,
API: &API{Client: &http.Client{Transport: authTransport}, URL: parsedURL}, API: &API{Client: &http.Client{Transport: authTransport}, URL: parsedURL},
httpProxy: httputil.NewSingleHostReverseProxy(&proxyURL), Proxy: newProxy(parsedURL, authTransport),
relativeURLRoot: relativeURLRoot, relativeURLRoot: relativeURLRoot,
} }
up.httpProxy.Transport = authTransport
return up return up
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment