Commit b8867389 authored by Kamil Trzcinski

Refactor middleware

- Use gitRequest to store all processed request data
- Drop some headers when creating upstream request
parent 74eb3e7c
PREFIX=/usr/local PREFIX=/usr/local
VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S) VERSION=$(shell git describe)-$(shell date -u +%Y%m%d.%H%M%S)
gitlab-workhorse: main.go upstream.go archive.go git-http.go helpers.go xsendfile.go gitlab-workhorse: main.go upstream.go archive.go git-http.go helpers.go xsendfile.go authorization.go
go build -ldflags "-X main.Version ${VERSION}" -o gitlab-workhorse go build -ldflags "-X main.Version ${VERSION}" -o gitlab-workhorse
install: gitlab-workhorse install: gitlab-workhorse
......
...@@ -16,13 +16,13 @@ import ( ...@@ -16,13 +16,13 @@ import (
"time" "time"
) )
func handleGetArchive(w http.ResponseWriter, r *gitRequest, format string) { func handleGetArchive(w http.ResponseWriter, r *gitRequest) {
archiveFilename := path.Base(r.ArchivePath) archiveFilename := path.Base(r.ArchivePath)
if cachedArchive, err := os.Open(r.ArchivePath); err == nil { if cachedArchive, err := os.Open(r.ArchivePath); err == nil {
defer cachedArchive.Close() defer cachedArchive.Close()
log.Printf("Serving cached file %q", r.ArchivePath) log.Printf("Serving cached file %q", r.ArchivePath)
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, r.rpc, archiveFilename)
// Even if somebody deleted the cachedArchive from disk since we opened // Even if somebody deleted the cachedArchive from disk since we opened
// the file, Unix file semantics guarantee we can still read from the // the file, Unix file semantics guarantee we can still read from the
// open file in this process. // open file in this process.
...@@ -41,7 +41,7 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest, format string) { ...@@ -41,7 +41,7 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest, format string) {
defer tempFile.Close() defer tempFile.Close()
defer os.Remove(tempFile.Name()) defer os.Remove(tempFile.Name())
compressCmd, archiveFormat := parseArchiveFormat(format) compressCmd, archiveFormat := parseArchiveFormat(r.rpc)
archiveCmd := gitCommand("", "git", "--git-dir="+r.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+r.ArchivePrefix+"/", r.CommitId) archiveCmd := gitCommand("", "git", "--git-dir="+r.RepoPath, "archive", "--format="+archiveFormat, "--prefix="+r.ArchivePrefix+"/", r.CommitId)
archiveStdout, err := archiveCmd.StdoutPipe() archiveStdout, err := archiveCmd.StdoutPipe()
...@@ -82,7 +82,7 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest, format string) { ...@@ -82,7 +82,7 @@ func handleGetArchive(w http.ResponseWriter, r *gitRequest, format string) {
archiveReader := io.TeeReader(stdout, tempFile) archiveReader := io.TeeReader(stdout, tempFile)
// Start writing the response // Start writing the response
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, r.rpc, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil { if _, err := io.Copy(w, archiveReader); err != nil {
logContext("handleGetArchive read from subprocess", err) logContext("handleGetArchive read from subprocess", err)
......
package main
import (
"encoding/json"
"errors"
"io"
"net/http"
"strings"
)
// preAuthorizeHandler returns a serviceHandleFunc that consults the auth
// backend before delegating to handleFunc.
//
// For every request it builds a body-less upstream request (with suffix
// appended to the request URI) and sends it through the upstream's HTTP
// client. A non-200 answer is relayed to the client as-is and handleFunc is
// never called. On a 200 answer the JSON body is decoded into the request's
// embedded authorizationResponse and handleFunc is invoked with the enriched
// request.
func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc {
	return func(w http.ResponseWriter, r *gitRequest) {
		upstreamReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix)
		if err != nil {
			fail500(w, "newUpstreamRequest", err)
			return
		}

		resp, err := r.u.httpClient.Do(upstreamReq)
		if err != nil {
			fail500(w, "doAuthRequest", err)
			return
		}
		defer resp.Body.Close()

		if resp.StatusCode != 200 {
			// The backend refused the request; the client may simply need to
			// supply HTTP Basic credentials. Relay the backend's headers,
			// status and body so the client sees the 'WWW-Authenticate' hint.
			for name, values := range resp.Header {
				target := name
				if name == "Www-Authenticate" {
					// Some clients look headers up case-sensitively, so emit
					// the spelling they expect instead of Go's canonical form.
					target = "WWW-Authenticate"
				}
				w.Header()[target] = values
			}
			w.WriteHeader(resp.StatusCode)
			// The status line is already committed, so a failed copy cannot be
			// reported to the client; the result is intentionally not checked.
			io.Copy(w, resp.Body)
			return
		}

		// The backend accepted the request; its response body carries the
		// extra request metadata the handlers rely on.
		decoder := json.NewDecoder(resp.Body)
		if err := decoder.Decode(&r.authorizationResponse); err != nil {
			fail500(w, "decode authorization response", err)
			return
		}
		// Release the connection right away rather than leaving it in
		// CLOSE_WAIT until the deferred Close runs after handleFunc returns.
		resp.Body.Close()

		// Per RFC4559, Negotiate (Kerberos) authentication may need to send a
		// WWW-Authenticate header back even when the request succeeded.
		for name, values := range resp.Header {
			// Header names compare case-insensitively (RFC7230).
			if !strings.EqualFold(name, "WWW-Authenticate") {
				continue
			}
			w.Header()[name] = values
		}

		handleFunc(w, r)
	}
}
// repoPreAuthorizeHandler wraps handleFunc with pre-authorization (empty URL
// suffix) plus repository sanity checks. After the auth backend has answered,
// the request must carry both GL_ID and RepoPath, otherwise it is failed via
// fail500; a RepoPath rejected by looksLikeRepo yields a 404. Only requests
// passing both checks reach handleFunc.
func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc {
	verifyRepo := func(w http.ResponseWriter, r *gitRequest) {
		switch {
		case r.GL_ID == "" || r.RepoPath == "":
			// The authorization response lacked the fields we depend on.
			fail500(w, "repoPreAuthorizeHandler", errors.New("missing authorization response"))
		case !looksLikeRepo(r.RepoPath):
			http.Error(w, "Not Found", 404)
		default:
			handleFunc(w, r)
		}
	}
	return preAuthorizeHandler(verifyRepo, "")
}
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"strings" "strings"
) )
func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest, _ string) { func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) {
rpc := r.URL.Query().Get("service") rpc := r.URL.Query().Get("service")
if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") {
// The 'dumb' Git HTTP protocol is not supported // The 'dumb' Git HTTP protocol is not supported
...@@ -56,7 +56,7 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest, _ string) { ...@@ -56,7 +56,7 @@ func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest, _ string) {
} }
} }
func handlePostRPC(w http.ResponseWriter, r *gitRequest, rpc string) { func handlePostRPC(w http.ResponseWriter, r *gitRequest) {
var body io.ReadCloser var body io.ReadCloser
var err error var err error
...@@ -73,7 +73,7 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest, rpc string) { ...@@ -73,7 +73,7 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest, rpc string) {
defer body.Close() defer body.Close()
// Prepare our Git subprocess // Prepare our Git subprocess
cmd := gitCommand(r.GL_ID, "git", subCommand(rpc), "--stateless-rpc", r.RepoPath) cmd := gitCommand(r.GL_ID, "git", subCommand(r.rpc), "--stateless-rpc", r.RepoPath)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
fail500(w, "handlePostRPC", err) fail500(w, "handlePostRPC", err)
...@@ -108,7 +108,7 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest, rpc string) { ...@@ -108,7 +108,7 @@ func handlePostRPC(w http.ResponseWriter, r *gitRequest, rpc string) {
body.Close() body.Close()
// Start writing the response // Start writing the response
w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", rpc)) w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", r.rpc))
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
......
...@@ -336,7 +336,8 @@ func testAuthServer(code int, body string) *httptest.Server { ...@@ -336,7 +336,8 @@ func testAuthServer(code int, body string) *httptest.Server {
} }
func startServerOrFail(t *testing.T, ts *httptest.Server) *exec.Cmd { func startServerOrFail(t *testing.T, ts *httptest.Server) *exec.Cmd {
cmd := exec.Command("go", "run", "main.go", "upstream.go", "archive.go", "git-http.go", "helpers.go", "xsendfile.go", fmt.Sprintf("-authBackend=%s", ts.URL), fmt.Sprintf("-listenAddr=%s", servAddr)) cmd := exec.Command("go", "run", "main.go", "upstream.go", "archive.go", "git-http.go", "helpers.go", "xsendfile.go",
"authorization.go", fmt.Sprintf("-authBackend=%s", ts.URL), fmt.Sprintf("-listenAddr=%s", servAddr))
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
......
...@@ -7,16 +7,16 @@ In this file we handle request routing and interaction with the authBackend. ...@@ -7,16 +7,16 @@ In this file we handle request routing and interaction with the authBackend.
package main package main
import ( import (
"encoding/json"
"io" "io"
"log" "log"
"net/http" "net/http"
"os" "os"
"path" "path"
"regexp" "regexp"
"strings"
) )
type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest)
type upstream struct { type upstream struct {
httpClient *http.Client httpClient *http.Client
authBackend string authBackend string
...@@ -25,15 +25,11 @@ type upstream struct { ...@@ -25,15 +25,11 @@ type upstream struct {
type gitService struct { type gitService struct {
method string method string
regex *regexp.Regexp regex *regexp.Regexp
middlewareFunc func(u *upstream, w http.ResponseWriter, r *http.Request, handleFunc func(w http.ResponseWriter, r *gitRequest, rpc string), rpc string) handleFunc serviceHandleFunc
handleFunc func(w http.ResponseWriter, r *gitRequest, rpc string)
rpc string rpc string
} }
// A gitReqest is an *http.Request decorated with attributes returned by the type authorizationResponse struct {
// GitLab Rails application.
type gitRequest struct {
*http.Request
// GL_ID is an environment variable used by gitlab-shell hooks during 'git // GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull' // push' and 'git pull'
GL_ID string GL_ID string
...@@ -51,17 +47,26 @@ type gitRequest struct { ...@@ -51,17 +47,26 @@ type gitRequest struct {
CommitId string CommitId string
} }
// A gitRequest is an *http.Request decorated with attributes returned by the
// GitLab Rails application.
type gitRequest struct {
*http.Request
authorizationResponse
u *upstream
rpc string
}
// Routing table // Routing table
var gitServices = [...]gitService{ var gitServices = [...]gitService{
gitService{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuth, handleGetInfoRefs, ""}, gitService{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs), ""},
gitService{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuth, handlePostRPC, "git-upload-pack"}, gitService{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(handlePostRPC), "git-upload-pack"},
gitService{"POST", regexp.MustCompile(`/git-receive-pack\z`), repoPreAuth, handlePostRPC, "git-receive-pack"}, gitService{"POST", regexp.MustCompile(`/git-receive-pack\z`), repoPreAuthorizeHandler(handlePostRPC), "git-receive-pack"},
gitService{"GET", regexp.MustCompile(`/repository/archive\z`), repoPreAuth, handleGetArchive, "tar.gz"}, gitService{"GET", regexp.MustCompile(`/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive), "tar.gz"},
gitService{"GET", regexp.MustCompile(`/repository/archive.zip\z`), repoPreAuth, handleGetArchive, "zip"}, gitService{"GET", regexp.MustCompile(`/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive), "zip"},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar\z`), repoPreAuth, handleGetArchive, "tar"}, gitService{"GET", regexp.MustCompile(`/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive), "tar"},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar.gz\z`), repoPreAuth, handleGetArchive, "tar.gz"}, gitService{"GET", regexp.MustCompile(`/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive), "tar.gz"},
gitService{"GET", regexp.MustCompile(`/repository/archive.tar.bz2\z`), repoPreAuth, handleGetArchive, "tar.bz2"}, gitService{"GET", regexp.MustCompile(`/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive), "tar.bz2"},
gitService{"GET", regexp.MustCompile(`/uploads/`), xSendFile, nil, ""}, gitService{"GET", regexp.MustCompile(`/uploads/`), handleSendFile, ""},
} }
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream { func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
...@@ -88,68 +93,13 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -88,68 +93,13 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
g.middlewareFunc(u, w, r, g.handleFunc, g.rpc) request := gitRequest{
} Request: r,
u: u,
func repoPreAuth(u *upstream, w http.ResponseWriter, r *http.Request, handleFunc func(w http.ResponseWriter, r *gitRequest, rpc string), rpc string) { rpc: g.rpc,
authReq, err := u.newUpstreamRequest(r)
if err != nil {
fail500(w, "newUpstreamRequest", err)
return
} }
authResponse, err := u.httpClient.Do(authReq) g.handleFunc(w, &request)
if err != nil {
fail500(w, "doAuthRequest", err)
return
}
defer authResponse.Body.Close()
if authResponse.StatusCode != 200 {
// The Git request is not allowed by the backend. Maybe the
// client needs to send HTTP Basic credentials. Forward the
// response from the auth backend to our client. This includes
// the 'WWW-Authenticate' header that acts as a hint that
// Basic auth credentials are needed.
for k, v := range authResponse.Header {
// Accomodate broken clients that do case-sensitive header lookup
if k == "Www-Authenticate" {
w.Header()["WWW-Authenticate"] = v
} else {
w.Header()[k] = v
}
}
w.WriteHeader(authResponse.StatusCode)
io.Copy(w, authResponse.Body)
return
}
// The auth backend validated the client request and told us additional
// request metadata. We must extract this information from the auth
// response body.
gitReq := &gitRequest{Request: r}
if err := json.NewDecoder(authResponse.Body).Decode(gitReq); err != nil {
fail500(w, "decode JSON GL_ID", err)
return
}
// Don't hog a TCP connection in CLOSE_WAIT, we can already close it now
authResponse.Body.Close()
// Negotiate authentication (Kerberos) may need to return a WWW-Authenticate
// header to the client even in case of success as per RFC4559.
for k, v := range authResponse.Header {
// Case-insensitive comparison as per RFC7230
if strings.EqualFold(k, "WWW-Authenticate") {
w.Header()[k] = v
}
}
if !looksLikeRepo(gitReq.RepoPath) {
http.Error(w, "Not Found", 404)
return
}
handleFunc(w, gitReq, rpc)
} }
func looksLikeRepo(p string) bool { func looksLikeRepo(p string) bool {
...@@ -162,9 +112,9 @@ func looksLikeRepo(p string) bool { ...@@ -162,9 +112,9 @@ func looksLikeRepo(p string) bool {
return true return true
} }
func (u *upstream) newUpstreamRequest(r *http.Request) (*http.Request, error) { func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := u.authBackend + r.URL.RequestURI() url := u.authBackend + r.URL.RequestURI() + suffix
authReq, err := http.NewRequest(r.Method, url, nil) authReq, err := http.NewRequest(r.Method, url, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -173,6 +123,16 @@ func (u *upstream) newUpstreamRequest(r *http.Request) (*http.Request, error) { ...@@ -173,6 +123,16 @@ func (u *upstream) newUpstreamRequest(r *http.Request) (*http.Request, error) {
for k, v := range r.Header { for k, v := range r.Header {
authReq.Header[k] = v authReq.Header[k] = v
} }
// Clean some headers when issuing a new request without body
if body == nil {
authReq.Header.Del("Content-Type")
authReq.Header.Del("Content-Encoding")
authReq.Header.Del("Content-Length")
authReq.Header.Del("Accept-Encoding")
authReq.Header.Del("Transfer-Encoding")
}
// Also forward the Host header, which is excluded from the Header map by the http library. // Also forward the Host header, which is excluded from the Header map by the http library.
// This allows the Host header received by the backend to be consistent with other // This allows the Host header received by the backend to be consistent with other
// requests not going through gitlab-workhorse. // requests not going through gitlab-workhorse.
......
...@@ -13,15 +13,16 @@ import ( ...@@ -13,15 +13,16 @@ import (
"os" "os"
) )
func xSendFile(u *upstream, w http.ResponseWriter, r *http.Request, _ func(http.ResponseWriter, *gitRequest, string), _ string) { func handleSendFile(w http.ResponseWriter, r *gitRequest) {
upRequest, err := u.newUpstreamRequest(r) upRequest, err := r.u.newUpstreamRequest(r.Request, r.Body, "")
if err != nil { if err != nil {
fail500(w, "newUpstreamRequest", err) fail500(w, "newUpstreamRequest", err)
return return
} }
upRequest.Header.Set("X-Sendfile-Type", "X-Sendfile") upRequest.Header.Set("X-Sendfile-Type", "X-Sendfile")
upResponse, err := u.httpClient.Do(upRequest) upResponse, err := r.u.httpClient.Do(upRequest)
r.Body.Close()
if err != nil { if err != nil {
fail500(w, "do upstream request", err) fail500(w, "do upstream request", err)
return return
...@@ -63,5 +64,5 @@ func xSendFile(u *upstream, w http.ResponseWriter, r *http.Request, _ func(http. ...@@ -63,5 +64,5 @@ func xSendFile(u *upstream, w http.ResponseWriter, r *http.Request, _ func(http.
fail500(w, "xSendFile get mtime", err) fail500(w, "xSendFile get mtime", err)
return return
} }
http.ServeContent(w, r, "", fi.ModTime(), content) http.ServeContent(w, r.Request, "", fi.ModTime(), content)
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment