Commit a7666718 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Remove upstream from gitRequest

parent b0525d6c
package main package main
func artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func (u *upstream) artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(handleFunc, "/authorize") return u.preAuthorizeHandler(handleFunc, "/authorize")
} }
...@@ -51,15 +51,15 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st ...@@ -51,15 +51,15 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st
return authReq, nil return authReq, nil
} }
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *gitRequest) {
authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix) authReq, err := u.newUpstreamRequest(r.Request, nil, suffix)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err)) fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return return
} }
authResponse, err := r.u.httpClient.Do(authReq) authResponse, err := u.httpClient.Do(authReq)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err)) fail500(w, fmt.Errorf("preAuthorizeHandler: do %v: %v", authReq.URL.Path, err))
return return
......
...@@ -23,14 +23,13 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut ...@@ -23,14 +23,13 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
u := newUpstream(ts.URL, nil)
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, nil),
} }
response := httptest.NewRecorder() response := httptest.NewRecorder()
preAuthorizeHandler(okHandler, suffix)(response, &request) u.preAuthorizeHandler(okHandler, suffix)(response, &request)
assertResponseCode(t, response, expectedCode) assertResponseCode(t, response, expectedCode)
return response return response
} }
......
...@@ -26,8 +26,8 @@ func looksLikeRepo(p string) bool { ...@@ -26,8 +26,8 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func (u *upstream) repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.RepoPath == "" { if r.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
return return
......
...@@ -17,8 +17,8 @@ import ( ...@@ -17,8 +17,8 @@ import (
"path/filepath" "path/filepath"
) )
func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func (u *upstream) lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.StoreLFSPath == "" { if r.StoreLFSPath == "" {
fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty")) fail500(w, errors.New("lfsAuthorizeHandler: StoreLFSPath empty"))
...@@ -39,7 +39,7 @@ func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -39,7 +39,7 @@ func lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
}, "/authorize") }, "/authorize")
} }
func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) { func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
file, err := ioutil.TempFile(r.StoreLFSPath, r.LfsOid) file, err := ioutil.TempFile(r.StoreLFSPath, r.LfsOid)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err)) fail500(w, fmt.Errorf("handleStoreLfsObject: create tempfile: %v", err))
...@@ -75,5 +75,5 @@ func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) { ...@@ -75,5 +75,5 @@ func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
r.ContentLength = 0 r.ContentLength = 0
// And proxy the request // And proxy the request
proxyRequest(w, r) u.proxyRequest(w, r)
} }
...@@ -59,34 +59,34 @@ const ciAPIPattern = `^/ci/api/` ...@@ -59,34 +59,34 @@ const ciAPIPattern = `^/ci/api/`
// see upstream.ServeHTTP // see upstream.ServeHTTP
var httpRoutes []httpRoute var httpRoutes []httpRoute
func compileRoutes() { func compileRoutes(u *upstream) {
httpRoutes = []httpRoute{ httpRoutes = []httpRoute{
// Git Clone // Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)}, httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), u.repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), u.repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), u.repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)}, httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), u.lfsAuthorizeHandler(u.handleStoreLfsObject)},
// Repository Archive // Repository Archive
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
// Repository Archive API // Repository Archive API
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API // CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))}, httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), u.artifactsAuthorizeHandler(contentEncodingHandler(u.handleFileUploads))},
// Explicitly proxy API requests // Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest}, httpRoute{"", regexp.MustCompile(apiPattern), u.proxyRequest},
httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest}, httpRoute{"", regexp.MustCompile(ciAPIPattern), u.proxyRequest},
// Serve assets // Serve assets
httpRoute{"", regexp.MustCompile(`^/assets/`), httpRoute{"", regexp.MustCompile(`^/assets/`),
...@@ -94,7 +94,7 @@ func compileRoutes() { ...@@ -94,7 +94,7 @@ func compileRoutes() {
handleDevelopmentMode(developmentMode, handleDevelopmentMode(developmentMode,
handleDeployPage(documentRoot, handleDeployPage(documentRoot,
handleRailsError(documentRoot, handleRailsError(documentRoot,
proxyRequest, u.proxyRequest,
), ),
), ),
), ),
...@@ -106,7 +106,7 @@ func compileRoutes() { ...@@ -106,7 +106,7 @@ func compileRoutes() {
handleServeFile(documentRoot, CacheDisabled, handleServeFile(documentRoot, CacheDisabled,
handleDeployPage(documentRoot, handleDeployPage(documentRoot,
handleRailsError(documentRoot, handleRailsError(documentRoot,
proxyRequest, u.proxyRequest,
), ),
), ),
), ),
...@@ -173,6 +173,6 @@ func main() { ...@@ -173,6 +173,6 @@ func main() {
} }
upstream := newUpstream(*authBackend, proxyTransport) upstream := newUpstream(*authBackend, proxyTransport)
compileRoutes() compileRoutes(upstream)
log.Fatal(http.Serve(listener, upstream)) log.Fatal(http.Serve(listener, upstream))
} }
...@@ -326,8 +326,9 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -326,8 +326,9 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
compileRoutes() u := newUpstream(authBackend, nil)
return httptest.NewServer(newUpstream(authBackend, nil)) compileRoutes(u)
return httptest.NewServer(u)
} }
func runOrFail(t *testing.T, cmd *exec.Cmd) { func runOrFail(t *testing.T, cmd *exec.Cmd) {
......
...@@ -51,7 +51,7 @@ func headerClone(h http.Header) http.Header { ...@@ -51,7 +51,7 @@ func headerClone(h http.Header) http.Header {
return h2 return h2
} }
func proxyRequest(w http.ResponseWriter, r *gitRequest) { func (u *upstream) proxyRequest(w http.ResponseWriter, r *gitRequest) {
// Clone request // Clone request
req := *r.Request req := *r.Request
req.Header = headerClone(r.Header) req.Header = headerClone(r.Header)
...@@ -61,5 +61,5 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) { ...@@ -61,5 +61,5 @@ func proxyRequest(w http.ResponseWriter, r *gitRequest) {
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, &req)
defer rw.Flush() defer rw.Flush()
r.u.httpProxy.ServeHTTP(&rw, &req) u.httpProxy.ServeHTTP(&rw, &req)
} }
...@@ -41,11 +41,11 @@ func TestProxyRequest(t *testing.T) { ...@@ -41,11 +41,11 @@ func TestProxyRequest(t *testing.T) {
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, nil),
} }
u := newUpstream(ts.URL, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) u.proxyRequest(w, &request)
assertResponseCode(t, w, 202) assertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE") assertResponseBody(t, w, "RESPONSE")
...@@ -67,11 +67,11 @@ func TestProxyError(t *testing.T) { ...@@ -67,11 +67,11 @@ func TestProxyError(t *testing.T) {
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream("http://localhost:655575/", &transport),
} }
u := newUpstream("http://localhost:655575/", &transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) u.proxyRequest(w, &request)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575") assertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
...@@ -100,11 +100,11 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -100,11 +100,11 @@ func TestProxyReadTimeout(t *testing.T) {
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, transport),
} }
u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) u.proxyRequest(w, &request)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers") assertResponseBody(t, w, "net/http: timeout awaiting response headers")
} }
...@@ -127,11 +127,11 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -127,11 +127,11 @@ func TestProxyHandlerTimeout(t *testing.T) {
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, transport),
} }
u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxyRequest(w, &request) u.proxyRequest(w, &request)
assertResponseCode(t, w, 503) assertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long") assertResponseBody(t, w, "Request took too long")
} }
...@@ -83,7 +83,7 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -83,7 +83,7 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
return cleanup, nil return cleanup, nil
} }
func handleFileUploads(w http.ResponseWriter, r *gitRequest) { func (u *upstream) handleFileUploads(w http.ResponseWriter, r *gitRequest) {
if r.TempPath == "" { if r.TempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty")) fail500(w, errors.New("handleFileUploads: TempPath empty"))
return return
...@@ -97,7 +97,7 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -97,7 +97,7 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
cleanup, err := rewriteFormFilesFromMultipart(r, writer) cleanup, err := rewriteFormFilesFromMultipart(r, writer)
if err != nil { if err != nil {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
proxyRequest(w, r) u.proxyRequest(w, r)
} else { } else {
fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err)) fail500(w, fmt.Errorf("handleFileUploads: extract files from multipart: %v", err))
} }
...@@ -117,5 +117,5 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -117,5 +117,5 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) {
r.Header.Set("Content-Type", writer.FormDataContentType()) r.Header.Set("Content-Type", writer.FormDataContentType())
// Proxy the request // Proxy the request
proxyRequest(w, r) u.proxyRequest(w, r)
} }
...@@ -21,7 +21,7 @@ func TestUploadTempPathRequirement(t *testing.T) { ...@@ -21,7 +21,7 @@ func TestUploadTempPathRequirement(t *testing.T) {
TempPath: "", TempPath: "",
}, },
} }
handleFileUploads(response, &request) newUpstream("http://localhost", nil).handleFileUploads(response, &request)
assertResponseCode(t, response, 500) assertResponseCode(t, response, 500)
} }
...@@ -55,12 +55,13 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -55,12 +55,13 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, nil),
authorizationResponse: authorizationResponse{ authorizationResponse: authorizationResponse{
TempPath: tempPath, TempPath: tempPath,
}, },
} }
handleFileUploads(response, &request) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, &request)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
...@@ -135,12 +136,13 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -135,12 +136,13 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{ request := gitRequest{
Request: httpRequest, Request: httpRequest,
u: newUpstream(ts.URL, nil),
authorizationResponse: authorizationResponse{ authorizationResponse: authorizationResponse{
TempPath: tempPath, TempPath: tempPath,
}, },
} }
handleFileUploads(response, &request) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, &request)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
......
...@@ -57,7 +57,6 @@ type authorizationResponse struct { ...@@ -57,7 +57,6 @@ type authorizationResponse struct {
type gitRequest struct { type gitRequest struct {
*http.Request *http.Request
authorizationResponse authorizationResponse
u *upstream
// This field contains the URL.Path stripped from RelativeUrlRoot // This field contains the URL.Path stripped from RelativeUrlRoot
relativeURIPath string relativeURIPath string
...@@ -140,7 +139,6 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -140,7 +139,6 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
request := gitRequest{ request := gitRequest{
Request: r, Request: r,
relativeURIPath: relativeURIPath, relativeURIPath: relativeURIPath,
u: u,
} }
g.handleFunc(&w, &request) g.handleFunc(&w, &request)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment