Commit b4037dcf authored by Jacob Vosmaer's avatar Jacob Vosmaer

Stop passing gitRequests when not needed

parent 0592ff1c
package main package main
func (u *upstream) artifactsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { import (
return u.preAuthorizeHandler(handleFunc, "/authorize") "net/http"
)
func (u *upstream) artifactsAuthorizeHandler(h handleFunc) handleFunc {
return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
req := r.Request
req.Header.Set("Gitlab-Workhorse-Temp-Path", r.TempPath)
h(w, req)
}, "/authorize")
} }
...@@ -51,9 +51,9 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st ...@@ -51,9 +51,9 @@ func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix st
return authReq, nil return authReq, nil
} }
func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { func (u *upstream) preAuthorizeHandler(h serviceHandleFunc, suffix string) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
authReq, err := u.newUpstreamRequest(r.Request, nil, suffix) authReq, err := u.newUpstreamRequest(r, nil, suffix)
if err != nil { if err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err)) fail500(w, fmt.Errorf("preAuthorizeHandler: newUpstreamRequest: %v", err))
return return
...@@ -85,10 +85,11 @@ func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix stri ...@@ -85,10 +85,11 @@ func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix stri
return return
} }
g := &gitRequest{Request: r}
// The auth backend validated the client request and told us additional // The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth // request metadata. We must extract this information from the auth
// response body. // response body.
if err := json.NewDecoder(authResponse.Body).Decode(&r.authorizationResponse); err != nil { if err := json.NewDecoder(authResponse.Body).Decode(&g.authorizationResponse); err != nil {
fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err)) fail500(w, fmt.Errorf("preAuthorizeHandler: decode authorization response: %v", err))
return return
} }
...@@ -104,6 +105,6 @@ func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix stri ...@@ -104,6 +105,6 @@ func (u *upstream) preAuthorizeHandler(handleFunc serviceHandleFunc, suffix stri
} }
} }
handleFunc(w, r) h(w, g)
} }
} }
...@@ -24,12 +24,9 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut ...@@ -24,12 +24,9 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, aut
t.Fatal(err) t.Fatal(err)
} }
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
request := gitRequest{
Request: httpRequest,
}
response := httptest.NewRecorder() response := httptest.NewRecorder()
u.preAuthorizeHandler(okHandler, suffix)(response, &request) u.preAuthorizeHandler(okHandler, suffix)(response, httpRequest)
assertResponseCode(t, response, expectedCode) assertResponseCode(t, response, expectedCode)
return response return response
} }
......
...@@ -6,8 +6,8 @@ import ( ...@@ -6,8 +6,8 @@ import (
"path/filepath" "path/filepath"
) )
func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { func handleDeployPage(documentRoot *string, handler handleFunc) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
deployPage := filepath.Join(*documentRoot, "index.html") deployPage := filepath.Join(*documentRoot, "index.html")
data, err := ioutil.ReadFile(deployPage) data, err := ioutil.ReadFile(deployPage)
if err != nil { if err != nil {
......
...@@ -19,7 +19,7 @@ func TestIfNoDeployPageExist(t *testing.T) { ...@@ -19,7 +19,7 @@ func TestIfNoDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { handleDeployPage(&dir, func(w http.ResponseWriter, r *http.Request) {
executed = true executed = true
})(w, nil) })(w, nil)
if !executed { if !executed {
...@@ -40,7 +40,7 @@ func TestIfDeployPageExist(t *testing.T) { ...@@ -40,7 +40,7 @@ func TestIfDeployPageExist(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { handleDeployPage(&dir, func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, nil) })(w, nil)
if executed { if executed {
......
...@@ -2,10 +2,10 @@ package main ...@@ -2,10 +2,10 @@ package main
import "net/http" import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc { func handleDevelopmentMode(developmentMode *bool, handler handleFunc) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
if !*developmentMode { if !*developmentMode {
http.NotFound(w, r.Request) http.NotFound(w, r)
return return
} }
......
...@@ -13,9 +13,9 @@ func TestDevelopmentModeEnabled(t *testing.T) { ...@@ -13,9 +13,9 @@ func TestDevelopmentModeEnabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { handleDevelopmentMode(&developmentMode, func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })(w, r)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -28,9 +28,9 @@ func TestDevelopmentModeDisabled(t *testing.T) { ...@@ -28,9 +28,9 @@ func TestDevelopmentModeDisabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { handleDevelopmentMode(&developmentMode, func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, &gitRequest{Request: r}) })(w, r)
if executed { if executed {
t.Error("The handler should not get executed") t.Error("The handler should not get executed")
} }
......
...@@ -59,8 +59,8 @@ func (s *errorPageResponseWriter) Flush() { ...@@ -59,8 +59,8 @@ func (s *errorPageResponseWriter) Flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
func handleRailsError(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { func handleRailsError(documentRoot *string, handler handleFunc) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{ rw := errorPageResponseWriter{
rw: w, rw: w,
path: documentRoot, path: documentRoot,
......
...@@ -22,7 +22,7 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -22,7 +22,7 @@ func TestIfErrorPageIsPresented(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) { handleRailsError(&dir, func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, "Not Found") fmt.Fprint(w, "Not Found")
})(w, nil) })(w, nil)
...@@ -42,7 +42,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
errorResponse := "ERROR" errorResponse := "ERROR"
handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) { handleRailsError(&dir, func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404) w.WriteHeader(404)
fmt.Fprint(w, errorResponse) fmt.Fprint(w, errorResponse)
})(w, nil) })(w, nil)
......
...@@ -26,7 +26,7 @@ func looksLikeRepo(p string) bool { ...@@ -26,7 +26,7 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func (u *upstream) repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func (u *upstream) repoPreAuthorizeHandler(handleFunc serviceHandleFunc) handleFunc {
return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.RepoPath == "" { if r.RepoPath == "" {
fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty"))
......
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"net/http" "net/http"
) )
func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func contentEncodingHandler(h handleFunc) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
var body io.ReadCloser var body io.ReadCloser
var err error var err error
...@@ -32,6 +32,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc { ...@@ -32,6 +32,6 @@ func contentEncodingHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
r.Body = body r.Body = body
r.Header.Del("Content-Encoding") r.Header.Del("Content-Encoding")
handleFunc(w, r) h(w, r)
} }
} }
...@@ -27,15 +27,14 @@ func TestGzipEncoding(t *testing.T) { ...@@ -27,15 +27,14 @@ func TestGzipEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "gzip") req.Header.Set("Content-Encoding", "gzip")
request := gitRequest{Request: req} contentEncodingHandler(func(w http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if _, ok := r.Body.(*gzip.Reader); !ok { if _, ok := r.Body.(*gzip.Reader); !ok {
t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body)) t.Fatal("Expected gzip reader for body, but it's:", reflect.TypeOf(r.Body))
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })(resp, req)
assertResponseCode(t, resp, 200) assertResponseCode(t, resp, 200)
} }
...@@ -52,15 +51,14 @@ func TestNoEncoding(t *testing.T) { ...@@ -52,15 +51,14 @@ func TestNoEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "") req.Header.Set("Content-Encoding", "")
request := gitRequest{Request: req} contentEncodingHandler(func(_ http.ResponseWriter, r *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.Body != body { if r.Body != body {
t.Fatal("Expected the same body") t.Fatal("Expected the same body")
} }
if r.Header.Get("Content-Encoding") != "" { if r.Header.Get("Content-Encoding") != "" {
t.Fatal("Content-Encoding should be deleted") t.Fatal("Content-Encoding should be deleted")
} }
})(resp, &request) })(resp, req)
assertResponseCode(t, resp, 200) assertResponseCode(t, resp, 200)
} }
...@@ -74,10 +72,9 @@ func TestInvalidEncoding(t *testing.T) { ...@@ -74,10 +72,9 @@ func TestInvalidEncoding(t *testing.T) {
} }
req.Header.Set("Content-Encoding", "application/unknown") req.Header.Set("Content-Encoding", "application/unknown")
request := gitRequest{Request: req} contentEncodingHandler(func(_ http.ResponseWriter, _ *http.Request) {
contentEncodingHandler(func(w http.ResponseWriter, r *gitRequest) {
t.Fatal("it shouldn't be executed") t.Fatal("it shouldn't be executed")
})(resp, &request) })(resp, req)
assertResponseCode(t, resp, 500) assertResponseCode(t, resp, 500)
} }
...@@ -17,7 +17,7 @@ import ( ...@@ -17,7 +17,7 @@ import (
"path/filepath" "path/filepath"
) )
func (u *upstream) lfsAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { func (u *upstream) lfsAuthorizeHandler(handleFunc serviceHandleFunc) handleFunc {
return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { return u.preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) {
if r.StoreLFSPath == "" { if r.StoreLFSPath == "" {
...@@ -75,5 +75,5 @@ func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) { ...@@ -75,5 +75,5 @@ func (u *upstream) handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) {
r.ContentLength = 0 r.ContentLength = 0
// And proxy the request // And proxy the request
u.proxyRequest(w, r) u.proxyRequest(w, r.Request)
} }
...@@ -43,9 +43,11 @@ var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets ...@@ -43,9 +43,11 @@ var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets
type httpRoute struct { type httpRoute struct {
method string method string
regex *regexp.Regexp regex *regexp.Regexp
handleFunc serviceHandleFunc handleFunc handleFunc
} }
type handleFunc func(http.ResponseWriter, *http.Request)
const projectPattern = `^/[^/]+/[^/]+/` const projectPattern = `^/[^/]+/[^/]+/`
const gitProjectPattern = `^/[^/]+/[^/]+\.git/` const gitProjectPattern = `^/[^/]+/[^/]+\.git/`
...@@ -63,8 +65,8 @@ func compileRoutes(u *upstream) { ...@@ -63,8 +65,8 @@ func compileRoutes(u *upstream) {
httpRoutes = []httpRoute{ httpRoutes = []httpRoute{
// Git Clone // Git Clone
httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), u.repoPreAuthorizeHandler(handleGetInfoRefs)}, httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), u.repoPreAuthorizeHandler(handleGetInfoRefs)},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), u.repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(u.repoPreAuthorizeHandler(handlePostRPC))},
httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), u.repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(u.repoPreAuthorizeHandler(handlePostRPC))},
httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), u.lfsAuthorizeHandler(u.handleStoreLfsObject)}, httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), u.lfsAuthorizeHandler(u.handleStoreLfsObject)},
// Repository Archive // Repository Archive
...@@ -82,7 +84,7 @@ func compileRoutes(u *upstream) { ...@@ -82,7 +84,7 @@ func compileRoutes(u *upstream) {
httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), u.repoPreAuthorizeHandler(handleGetArchive)}, httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), u.repoPreAuthorizeHandler(handleGetArchive)},
// CI Artifacts API // CI Artifacts API
httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), u.artifactsAuthorizeHandler(contentEncodingHandler(u.handleFileUploads))}, httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(u.artifactsAuthorizeHandler(u.handleFileUploads))},
// Explicitly proxy API requests // Explicitly proxy API requests
httpRoute{"", regexp.MustCompile(apiPattern), u.proxyRequest}, httpRoute{"", regexp.MustCompile(apiPattern), u.proxyRequest},
......
...@@ -51,15 +51,15 @@ func headerClone(h http.Header) http.Header { ...@@ -51,15 +51,15 @@ func headerClone(h http.Header) http.Header {
return h2 return h2
} }
func (u *upstream) proxyRequest(w http.ResponseWriter, r *gitRequest) { func (u *upstream) proxyRequest(w http.ResponseWriter, r *http.Request) {
// Clone request // Clone request
req := *r.Request req := r
req.Header = headerClone(r.Header) req.Header = headerClone(r.Header)
// Set Workhorse version // Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version) req.Header.Set("Gitlab-Workhorse", Version)
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, req)
defer rw.Flush() defer rw.Flush()
u.httpProxy.ServeHTTP(&rw, &req) u.httpProxy.ServeHTTP(&rw, req)
} }
...@@ -39,13 +39,9 @@ func TestProxyRequest(t *testing.T) { ...@@ -39,13 +39,9 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
request := gitRequest{
Request: httpRequest,
}
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, &request) u.proxyRequest(w, httpRequest)
assertResponseCode(t, w, 202) assertResponseCode(t, w, 202)
assertResponseBody(t, w, "RESPONSE") assertResponseBody(t, w, "RESPONSE")
...@@ -65,13 +61,9 @@ func TestProxyError(t *testing.T) { ...@@ -65,13 +61,9 @@ func TestProxyError(t *testing.T) {
transport: http.DefaultTransport, transport: http.DefaultTransport,
} }
request := gitRequest{
Request: httpRequest,
}
u := newUpstream("http://localhost:655575/", &transport) u := newUpstream("http://localhost:655575/", &transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, &request) u.proxyRequest(w, httpRequest)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "dial tcp: invalid port 655575") assertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
...@@ -98,13 +90,10 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -98,13 +90,10 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
} }
request := gitRequest{
Request: httpRequest,
}
u := newUpstream(ts.URL, transport) u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, &request) u.proxyRequest(w, httpRequest)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
assertResponseBody(t, w, "net/http: timeout awaiting response headers") assertResponseBody(t, w, "net/http: timeout awaiting response headers")
} }
...@@ -125,13 +114,10 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -125,13 +114,10 @@ func TestProxyHandlerTimeout(t *testing.T) {
transport: http.DefaultTransport, transport: http.DefaultTransport,
} }
request := gitRequest{
Request: httpRequest,
}
u := newUpstream(ts.URL, transport) u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.proxyRequest(w, &request) u.proxyRequest(w, httpRequest)
assertResponseCode(t, w, 503) assertResponseCode(t, w, 503)
assertResponseBody(t, w, "Request took too long") assertResponseBody(t, w, "Request took too long")
} }
...@@ -16,8 +16,8 @@ const ( ...@@ -16,8 +16,8 @@ const (
CacheExpireMax CacheExpireMax
) )
func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc { func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler handleFunc) handleFunc {
return func(w http.ResponseWriter, r *gitRequest) { return func(w http.ResponseWriter, r *http.Request) {
file := filepath.Join(*documentRoot, u.relativeURIPath(cleanURIPath(r.URL.Path))) file := filepath.Join(*documentRoot, u.relativeURIPath(cleanURIPath(r.URL.Path)))
// The filepath.Join does Clean traversing directories up // The filepath.Join does Clean traversing directories up
...@@ -50,7 +50,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou ...@@ -50,7 +50,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
if notFoundHandler != nil { if notFoundHandler != nil {
notFoundHandler(w, r) notFoundHandler(w, r)
} else { } else {
http.NotFound(w, r.Request) http.NotFound(w, r)
} }
return return
} }
...@@ -65,6 +65,6 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou ...@@ -65,6 +65,6 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
} }
log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI) log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI)
http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content) http.ServeContent(w, r, filepath.Base(file), fi.ModTime(), content)
} }
} }
...@@ -14,12 +14,9 @@ import ( ...@@ -14,12 +14,9 @@ import (
func TestServingNonExistingFile(t *testing.T) { func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, request) newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, httpRequest)
assertResponseCode(t, w, 404) assertResponseCode(t, w, 404)
} }
...@@ -31,38 +28,28 @@ func TestServingDirectory(t *testing.T) { ...@@ -31,38 +28,28 @@ func TestServingDirectory(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, request) newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, httpRequest)
assertResponseCode(t, w, 404) assertResponseCode(t, w, 404)
} }
func TestServingMalformedUri(t *testing.T) { func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil) httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
request := &gitRequest{
Request: httpRequest,
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, request) newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, httpRequest)
assertResponseCode(t, w, 404) assertResponseCode(t, w, 404)
} }
func TestExecutingHandlerWhenNoFileFound(t *testing.T) { func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
}
executed := false executed := false
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) { newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, func(_ http.ResponseWriter, r *http.Request) {
executed = (r == request) executed = (r == httpRequest)
})(nil, request) })(nil, httpRequest)
if !executed { if !executed {
t.Error("The handler should get executed") t.Error("The handler should get executed")
} }
...@@ -76,15 +63,12 @@ func TestServingTheActualFile(t *testing.T) { ...@@ -76,15 +63,12 @@ func TestServingTheActualFile(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
}
fileContent := "STATIC" fileContent := "STATIC"
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, request) newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, httpRequest)
assertResponseCode(t, w, 200) assertResponseCode(t, w, 200)
if w.Body.String() != fileContent { if w.Body.String() != fileContent {
t.Error("We should serve the file: ", w.Body.String()) t.Error("We should serve the file: ", w.Body.String())
...@@ -99,9 +83,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -99,9 +83,6 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
request := &gitRequest{
Request: httpRequest,
}
if enableGzip { if enableGzip {
httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") httpRequest.Header.Set("Accept-Encoding", "gzip, deflate")
...@@ -118,7 +99,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { ...@@ -118,7 +99,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) {
ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, request) newUpstream("http://localhost", nil).handleServeFile(&dir, CacheDisabled, nil)(w, httpRequest)
assertResponseCode(t, w, 200) assertResponseCode(t, w, 200)
if enableGzip { if enableGzip {
assertResponseHeader(t, w, "Content-Encoding", "gzip") assertResponseHeader(t, w, "Content-Encoding", "gzip")
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"os" "os"
) )
func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cleanup func(), err error) { func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string) (cleanup func(), err error) {
// Create multipart reader // Create multipart reader
reader, err := r.MultipartReader() reader, err := r.MultipartReader()
if err != nil { if err != nil {
...@@ -47,12 +47,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -47,12 +47,12 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
// Copy form field // Copy form field
if filename := p.FileName(); filename != "" { if filename := p.FileName(); filename != "" {
// Create temporary directory where the uploaded file will be stored // Create temporary directory where the uploaded file will be stored
if err := os.MkdirAll(r.TempPath, 0700); err != nil { if err := os.MkdirAll(tempPath, 0700); err != nil {
return cleanup, err return cleanup, err
} }
// Create temporary file in path returned by Authorization filter // Create temporary file in path returned by Authorization filter
file, err := ioutil.TempFile(r.TempPath, "upload_") file, err := ioutil.TempFile(tempPath, "upload_")
if err != nil { if err != nil {
return cleanup, err return cleanup, err
} }
...@@ -83,8 +83,9 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle ...@@ -83,8 +83,9 @@ func rewriteFormFilesFromMultipart(r *gitRequest, writer *multipart.Writer) (cle
return cleanup, nil return cleanup, nil
} }
func (u *upstream) handleFileUploads(w http.ResponseWriter, r *gitRequest) { func (u *upstream) handleFileUploads(w http.ResponseWriter, r *http.Request) {
if r.TempPath == "" { tempPath := r.Header.Get("Gitlab-Workhorse-Temp-Path")
if tempPath == "" {
fail500(w, errors.New("handleFileUploads: TempPath empty")) fail500(w, errors.New("handleFileUploads: TempPath empty"))
return return
} }
...@@ -94,7 +95,7 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *gitRequest) { ...@@ -94,7 +95,7 @@ func (u *upstream) handleFileUploads(w http.ResponseWriter, r *gitRequest) {
defer writer.Close() defer writer.Close()
// Rewrite multipart form data // Rewrite multipart form data
cleanup, err := rewriteFormFilesFromMultipart(r, writer) cleanup, err := rewriteFormFilesFromMultipart(r, writer, tempPath)
if err != nil { if err != nil {
if err == http.ErrNotMultipart { if err == http.ErrNotMultipart {
u.proxyRequest(w, r) u.proxyRequest(w, r)
......
...@@ -16,12 +16,8 @@ import ( ...@@ -16,12 +16,8 @@ import (
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{ request := &http.Request{}
authorizationResponse: authorizationResponse{ newUpstream("http://localhost", nil).handleFileUploads(response, request)
TempPath: "",
},
}
newUpstream("http://localhost", nil).handleFileUploads(response, &request)
assertResponseCode(t, response, 500) assertResponseCode(t, response, 500)
} }
...@@ -53,15 +49,11 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -53,15 +49,11 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest, httpRequest.Header.Set("Gitlab-Workhorse-Temp-Path", tempPath)
authorizationResponse: authorizationResponse{
TempPath: tempPath,
},
}
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, &request) u.handleFileUploads(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
...@@ -132,17 +124,11 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -132,17 +124,11 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Body = ioutil.NopCloser(&buffer) httpRequest.Body = ioutil.NopCloser(&buffer)
httpRequest.ContentLength = int64(buffer.Len()) httpRequest.ContentLength = int64(buffer.Len())
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set("Gitlab-Workhorse-Temp-Path", tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := gitRequest{
Request: httpRequest,
authorizationResponse: authorizationResponse{
TempPath: tempPath,
},
}
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, nil)
u.handleFileUploads(response, &request) u.handleFileUploads(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
......
...@@ -132,9 +132,5 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -132,9 +132,5 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
return return
} }
request := gitRequest{ g.handleFunc(&w, r)
Request: r,
}
g.handleFunc(&w, &request)
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment