Commit 04d16195 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Move everything into internal/upstream

parent a827a033
...@@ -3,11 +3,13 @@ package main ...@@ -3,11 +3,13 @@ package main
import ( import (
"./internal/api" "./internal/api"
"./internal/helper" "./internal/helper"
"./internal/upstream"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"regexp" "regexp"
"testing" "testing"
"time"
) )
func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) { func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
...@@ -25,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -25,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
api := newUpstream(ts.URL, "").API api := upstream.New(ts.URL, "", "123", time.Second).API
response := httptest.NewRecorder() response := httptest.NewRecorder()
api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest) api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
/*
Miscellaneous helpers: logging, errors, subprocesses
*/
package main
import (
"net/http"
"path"
)
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
...@@ -2,7 +2,10 @@ package helper ...@@ -2,7 +2,10 @@ package helper
import ( import (
"net/http/httptest" "net/http/httptest"
"net/http"
"testing" "testing"
"regexp"
"log"
) )
func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) { func AssertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expectedCode int) {
...@@ -22,3 +25,21 @@ func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, hea ...@@ -22,3 +25,21 @@ func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, hea
t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header)) t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header))
} }
} }
func TestServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
\ No newline at end of file
package main package upstream
import ( import (
"./internal/api" "../api"
"net/http" "net/http"
) )
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path/filepath" "path/filepath"
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
......
package main package upstream
import "net/http" import "net/http"
func handleDevelopmentMode(developmentMode *bool, handler http.HandlerFunc) http.HandlerFunc { func handleDevelopmentMode(developmentMode bool, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !*developmentMode { if !developmentMode {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
...@@ -14,7 +14,7 @@ func TestDevelopmentModeEnabled(t *testing.T) { ...@@ -14,7 +14,7 @@ func TestDevelopmentModeEnabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(_ http.ResponseWriter, _ *http.Request) { handleDevelopmentMode(developmentMode, func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, r) })(w, r)
if !executed { if !executed {
...@@ -29,7 +29,7 @@ func TestDevelopmentModeDisabled(t *testing.T) { ...@@ -29,7 +29,7 @@ func TestDevelopmentModeDisabled(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
executed := false executed := false
handleDevelopmentMode(&developmentMode, func(_ http.ResponseWriter, _ *http.Request) { handleDevelopmentMode(developmentMode, func(_ http.ResponseWriter, _ *http.Request) {
executed = true executed = true
})(w, r) })(w, r)
if executed { if executed {
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
......
package main package upstream
import ( import (
"fmt" "fmt"
......
package main package upstream
import ( import (
"./internal/errorpage" "../errorpage"
"./internal/git" "../git"
"./internal/lfs" "../lfs"
"net/http" "net/http"
"regexp" "regexp"
) )
...@@ -27,7 +27,7 @@ const ciAPIPattern = `^/ci/api/` ...@@ -27,7 +27,7 @@ const ciAPIPattern = `^/ci/api/`
// see upstream.ServeHTTP // see upstream.ServeHTTP
var routes []route var routes []route
func (u *upstream) compileRoutes() { func (u *Upstream) compileRoutes() {
u.routes = []route{ u.routes = []route{
// Git Clone // Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API)}, route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API)},
...@@ -59,7 +59,7 @@ func (u *upstream) compileRoutes() { ...@@ -59,7 +59,7 @@ func (u *upstream) compileRoutes() {
// Serve assets // Serve assets
route{"", regexp.MustCompile(`^/assets/`), route{"", regexp.MustCompile(`^/assets/`),
handleServeFile(u.DocumentRoot, u.urlPrefix, CacheExpireMax, handleServeFile(u.DocumentRoot, u.urlPrefix, CacheExpireMax,
handleDevelopmentMode(developmentMode, handleDevelopmentMode(u.DevelopmentMode,
handleDeployPage(u.DocumentRoot, handleDeployPage(u.DocumentRoot,
errorpage.Inject(u.DocumentRoot, errorpage.Inject(u.DocumentRoot,
u.Proxy, u.Proxy,
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"log" "log"
"net/http" "net/http"
"os" "os"
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io/ioutil" "io/ioutil"
...@@ -12,8 +12,6 @@ import ( ...@@ -12,8 +12,6 @@ import (
"testing" "testing"
) )
var dummyUpstream = newUpstream("http://localhost", "")
func TestServingNonExistingFile(t *testing.T) { func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil) httpRequest, _ := http.NewRequest("GET", "/file", nil)
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
......
package main package upstream
import ( import (
"./internal/helper" "../helper"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -13,17 +13,20 @@ import ( ...@@ -13,17 +13,20 @@ import (
"regexp" "regexp"
"strings" "strings"
"testing" "testing"
"time"
) )
var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
func TestUploadTempPathRequirement(t *testing.T) { func TestUploadTempPathRequirement(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
request := &http.Request{} request := &http.Request{}
handleFileUploads(dummyUpstream.Proxy).ServeHTTP(response, request) handleFileUploads(nilHandler).ServeHTTP(response, request)
helper.AssertResponseCode(t, response, 500) helper.AssertResponseCode(t, response, 500)
} }
func TestUploadHandlerForwardingRawData(t *testing.T) { func TestUploadHandlerForwardingRawData(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" { if r.Method != "PATCH" {
t.Fatal("Expected PATCH request") t.Fatal("Expected PATCH request")
} }
...@@ -52,9 +55,8 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -52,9 +55,8 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
u := newUpstream(ts.URL, "")
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest) handleFileUploads(New(ts.URL, "", "123", time.Second).Proxy).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202) helper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
...@@ -70,7 +72,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -70,7 +72,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
} }
defer os.RemoveAll(tempPath) defer os.RemoveAll(tempPath)
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" { if r.Method != "PUT" {
t.Fatal("Expected PUT request") t.Fatal("Expected PUT request")
} }
...@@ -127,7 +129,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -127,7 +129,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
u := newUpstream(ts.URL, "") u := New(ts.URL, "", "123", time.Second)
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest) handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202) helper.AssertResponseCode(t, response, 202)
......
...@@ -4,29 +4,33 @@ The upstream type implements http.Handler. ...@@ -4,29 +4,33 @@ The upstream type implements http.Handler.
In this file we handle request routing and interaction with the authBackend. In this file we handle request routing and interaction with the authBackend.
*/ */
package main package upstream
import ( import (
"./internal/api" "../api"
"./internal/proxy" "../proxy"
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"path"
"time" "time"
) )
type upstream struct { type Upstream struct {
Version string
API *api.API API *api.API
Proxy *proxy.Proxy Proxy *proxy.Proxy
DocumentRoot string DocumentRoot string
DevelopmentMode bool
ResponseHeadersTimeout time.Duration
urlPrefix urlPrefix urlPrefix urlPrefix
routes []route routes []route
} }
func newUpstream(authBackend string, authSocket string) *upstream { func New(authBackend string, authSocket string, version string, responseHeadersTimeout time.Duration) *Upstream {
parsedURL, err := url.Parse(authBackend) parsedURL, err := url.Parse(authBackend)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
...@@ -49,25 +53,25 @@ func newUpstream(authBackend string, authSocket string) *upstream { ...@@ -49,25 +53,25 @@ func newUpstream(authBackend string, authSocket string) *upstream {
Dial: func(_, _ string) (net.Conn, error) { Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", authSocket) return dialer.Dial("unix", authSocket)
}, },
ResponseHeaderTimeout: *responseHeadersTimeout, ResponseHeaderTimeout: responseHeadersTimeout,
} }
} }
proxyTransport := proxy.NewRoundTripper(authTransport) proxyTransport := proxy.NewRoundTripper(authTransport)
up := &upstream{ up := &Upstream{
API: &api.API{ API: &api.API{
Client: &http.Client{Transport: proxyTransport}, Client: &http.Client{Transport: proxyTransport},
URL: parsedURL, URL: parsedURL,
Version: Version, Version: version,
}, },
Proxy: proxy.NewProxy(parsedURL, proxyTransport, Version), Proxy: proxy.NewProxy(parsedURL, proxyTransport, version),
urlPrefix: urlPrefix(relativeURLRoot), urlPrefix: urlPrefix(relativeURLRoot),
} }
up.compileRoutes() up.compileRoutes()
return up return up
} }
func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := newLoggingResponseWriter(ow) w := newLoggingResponseWriter(ow)
defer w.Log(r) defer w.Log(r)
...@@ -113,3 +117,30 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { ...@@ -113,3 +117,30 @@ func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
ro.handler.ServeHTTP(&w, r) ro.handler.ServeHTTP(&w, r)
} }
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
// Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
\ No newline at end of file
package main package upstream
import ( import (
"strings" "strings"
......
...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type. ...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main package main
import ( import (
"./internal/upstream"
"flag" "flag"
"fmt" "fmt"
"log" "log"
...@@ -80,7 +81,9 @@ func main() { ...@@ -80,7 +81,9 @@ func main() {
}() }()
} }
upstream := newUpstream(*authBackend, *authSocket) up := upstream.New(*authBackend, *authSocket, Version, *responseHeadersTimeout)
upstream.DocumentRoot = *documentRoot up.DocumentRoot = *documentRoot
log.Fatal(http.Serve(listener, upstream)) up.DevelopmentMode = *developmentMode
log.Fatal(http.Serve(listener, up))
} }
...@@ -2,6 +2,8 @@ package main ...@@ -2,6 +2,8 @@ package main
import ( import (
"./internal/api" "./internal/api"
"./internal/helper"
"./internal/upstream"
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
...@@ -283,26 +285,8 @@ func newBranch() string { ...@@ -283,26 +285,8 @@ func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano()) return fmt.Sprintf("branch-%d", time.Now().UnixNano())
} }
func testServerWithHandler(url *regexp.Regexp, handler http.HandlerFunc) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if url != nil && !url.MatchString(r.URL.Path) {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(404)
return
}
if version := r.Header.Get("Gitlab-Workhorse"); version == "" {
log.Println("UPSTREAM", r.Method, r.URL, "DENY")
w.WriteHeader(403)
return
}
handler(w, r)
}))
}
func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server { func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Server {
return testServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { return helper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
// Write pure string // Write pure string
if data, ok := body.(string); ok { if data, ok := body.(string); ok {
log.Println("UPSTREAM", r.Method, r.URL, code) log.Println("UPSTREAM", r.Method, r.URL, code)
...@@ -327,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -327,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
u := newUpstream(authBackend, "") u := upstream.New(authBackend, "", "123", time.Second)
return httptest.NewServer(u) return httptest.NewServer(u)
} }
......
...@@ -3,6 +3,7 @@ package main ...@@ -3,6 +3,7 @@ package main
import ( import (
"./internal/helper" "./internal/helper"
"./internal/proxy" "./internal/proxy"
"./internal/upstream"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -15,8 +16,12 @@ import ( ...@@ -15,8 +16,12 @@ import (
"time" "time"
) )
func newUpstream(url string) *upstream.Upstream {
return upstream.New(url, "", "123", time.Second)
}
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
ts := testServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(regexp.MustCompile(`/url/path\z`), func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
t.Fatal("Expected POST request") t.Fatal("Expected POST request")
} }
...@@ -42,7 +47,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -42,7 +47,7 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream(ts.URL, "") u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202) helper.AssertResponseCode(t, w, 202)
...@@ -60,7 +65,7 @@ func TestProxyError(t *testing.T) { ...@@ -60,7 +65,7 @@ func TestProxyError(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream("http://localhost:655575/", "") u := newUpstream("http://localhost:655575/")
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
...@@ -68,7 +73,7 @@ func TestProxyError(t *testing.T) { ...@@ -68,7 +73,7 @@ func TestProxyError(t *testing.T) {
} }
func TestProxyReadTimeout(t *testing.T) { func TestProxyReadTimeout(t *testing.T) {
ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) { ts := helper.TestServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Minute) time.Sleep(time.Minute)
}) })
...@@ -89,7 +94,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -89,7 +94,7 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
) )
u := newUpstream(ts.URL, "") u := newUpstream(ts.URL)
url, err := url.Parse(ts.URL) url, err := url.Parse(ts.URL)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -103,7 +108,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -103,7 +108,7 @@ func TestProxyReadTimeout(t *testing.T) {
} }
func TestProxyHandlerTimeout(t *testing.T) { func TestProxyHandlerTimeout(t *testing.T) {
ts := testServerWithHandler(nil, ts := helper.TestServerWithHandler(nil,
http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second) time.Sleep(time.Second)
}), time.Millisecond, "Request took too long").ServeHTTP, }), time.Millisecond, "Request took too long").ServeHTTP,
...@@ -114,7 +119,7 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -114,7 +119,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
u := newUpstream(ts.URL, "") u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment