Commit c5664cdf authored by Jacob Vosmaer's avatar Jacob Vosmaer

Get rid of upstream.New

parent ef04d680
......@@ -9,7 +9,6 @@ import (
"net/http/httptest"
"regexp"
"testing"
"time"
)
func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) {
......@@ -27,7 +26,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil {
t.Fatal(err)
}
api := upstream.New(helper.URLMustParse(ts.URL), "", "123", time.Second).API
api := (&upstream.Upstream{Backend: helper.URLMustParse(ts.URL), Version: "123"}).API()
response := httptest.NewRecorder()
api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
......@@ -32,15 +32,11 @@ func (p *Proxy) configureReverseProxy() {
}
type RoundTripper struct {
transport http.RoundTripper
}
func NewRoundTripper(transport http.RoundTripper) *RoundTripper {
return &RoundTripper{transport: transport}
Transport http.RoundTripper
}
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.transport.RoundTrip(r)
res, err = rt.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
......
......@@ -27,42 +27,47 @@ const ciAPIPattern = `^/ci/api/`
// We match against URI not containing the relativeUrlRoot:
// see upstream.ServeHTTP
func (u *Upstream) compileRoutes() {
func (u *Upstream) Routes() []route {
u.configureRoutesOnce.Do(u.configureRoutes)
return u.routes
}
func (u *Upstream) configureRoutes() {
u.routes = []route{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API)},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API, u.Proxy)},
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), u.Proxy())},
// Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// Repository Archive API
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API)},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), git.GetArchive(u.API())},
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API, u.Proxy))},
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), u.Proxy()))},
// Explicitly u.Proxy API requests
route{"", regexp.MustCompile(apiPattern), u.Proxy},
route{"", regexp.MustCompile(ciAPIPattern), u.Proxy},
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), u.Proxy()},
route{"", regexp.MustCompile(ciAPIPattern), u.Proxy()},
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
handleServeFile(u.DocumentRoot, u.urlPrefix, CacheExpireMax,
handleServeFile(u.DocumentRoot, u.URLPrefix(), CacheExpireMax,
handleDevelopmentMode(u.DevelopmentMode,
handleDeployPage(u.DocumentRoot,
errorpage.Inject(u.DocumentRoot,
u.Proxy,
u.Proxy(),
),
),
),
......@@ -71,10 +76,10 @@ func (u *Upstream) compileRoutes() {
// Serve static files or forward the requests
route{"", nil,
handleServeFile(u.DocumentRoot, u.urlPrefix, CacheDisabled,
handleServeFile(u.DocumentRoot, u.URLPrefix(), CacheDisabled,
handleDeployPage(u.DocumentRoot,
errorpage.Inject(u.DocumentRoot,
u.Proxy,
u.Proxy(),
),
),
),
......
package upstream
import (
"../proxy"
"net"
"net/http"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
func (u *Upstream) Transport() http.RoundTripper {
u.configureTransportOnce.Do(u.configureTransport)
return u.transport
}
func (u *Upstream) configureTransport() {
t := *DefaultTransport
if u.ResponseHeaderTimeout != 0 {
t.ResponseHeaderTimeout = u.ResponseHeaderTimeout
}
if u.Socket != "" {
t.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", u.Socket)
}
}
u.transport = &proxy.RoundTripper{&t}
}
......@@ -8,61 +8,62 @@ package upstream
import (
"../api"
"../helper"
"../proxy"
"fmt"
"net"
"net/http"
"net/url"
"path"
"strings"
"sync"
"time"
)
var DefaultBackend = helper.URLMustParse("http://localhost:8080")
type Upstream struct {
Version string
API *api.API
Proxy *proxy.Proxy
DocumentRoot string
DevelopmentMode bool
ResponseHeadersTimeout time.Duration
Backend *url.URL
Version string
Socket string
DocumentRoot string
DevelopmentMode bool
ResponseHeaderTimeout time.Duration
_api *api.API
configureAPIOnce sync.Once
_proxy *proxy.Proxy
configureProxyOnce sync.Once
urlPrefix urlPrefix
routes []route
configureURLPrefixOnce sync.Once
routes []route
configureRoutesOnce sync.Once
transport http.RoundTripper
configureTransportOnce sync.Once
}
func New(authBackend *url.URL, authSocket string, version string, responseHeadersTimeout time.Duration) *Upstream {
relativeURLRoot := authBackend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
func (u *Upstream) Proxy() *proxy.Proxy {
u.configureProxyOnce.Do(u.configureProxy)
return u._proxy
}
// Create Proxy Transport
authTransport := http.DefaultTransport
if authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", authSocket)
},
ResponseHeaderTimeout: responseHeadersTimeout,
}
}
proxyTransport := proxy.NewRoundTripper(authTransport)
up := &Upstream{
API: &api.API{
Client: &http.Client{Transport: proxyTransport},
URL: authBackend,
Version: version,
},
Proxy: &proxy.Proxy{URL: authBackend, Transport: proxyTransport, Version: version},
urlPrefix: urlPrefix(relativeURLRoot),
func (u *Upstream) configureProxy() {
u._proxy = &proxy.Proxy{URL: u.Backend, Transport: u.Transport(), Version: u.Version}
}
func (u *Upstream) API() *api.API {
u.configureAPIOnce.Do(u.configureAPI)
return u._api
}
func (u *Upstream) configureAPI() {
u._api = &api.API{
Client: &http.Client{Transport: u.Transport()},
URL: u.Backend,
Version: u.Version,
}
up.compileRoutes()
return up
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
......@@ -83,7 +84,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Check URL Root
URIPath := cleanURIPath(r.URL.Path)
prefix := u.urlPrefix
prefix := u.URLPrefix()
if !prefix.match(URIPath) {
httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
......@@ -92,7 +93,7 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
// Look for a matching Git service
var ro route
foundService := false
for _, ro = range u.routes {
for _, ro = range u.Routes() {
if ro.method != "" && r.Method != ro.method {
continue
}
......
......@@ -14,3 +14,19 @@ func (p urlPrefix) match(path string) bool {
pre := string(p)
return strings.HasPrefix(path, pre) || path+"/" == pre
}
func (u *Upstream) URLPrefix() urlPrefix {
u.configureURLPrefixOnce.Do(u.configureURLPrefix)
return u.urlPrefix
}
func (u *Upstream) configureURLPrefix() {
if u.Backend == nil {
u.Backend = DefaultBackend
}
relativeURLRoot := u.Backend.Path
if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/"
}
u.urlPrefix = urlPrefix(relativeURLRoot)
}
......@@ -23,7 +23,6 @@ import (
_ "net/http/pprof"
"os"
"syscall"
"time"
)
// Current version of GitLab Workhorse
......@@ -33,11 +32,11 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = URLFlag("authBackend", "http://localhost:8080", "Authentication/authorization backend")
var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", upstream.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
func main() {
......@@ -81,9 +80,14 @@ func main() {
}()
}
up := upstream.New(authBackend, *authSocket, Version, *responseHeadersTimeout)
up.DocumentRoot = *documentRoot
up.DevelopmentMode = *developmentMode
up := &upstream.Upstream{
Backend: authBackend,
Socket: *authSocket,
Version: Version,
ResponseHeaderTimeout: *responseHeadersTimeout,
DocumentRoot: *documentRoot,
DevelopmentMode: *developmentMode,
}
log.Fatal(http.Serve(listener, up))
}
......@@ -311,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
}
func startWorkhorseServer(authBackend string) *httptest.Server {
u := upstream.New(helper.URLMustParse(authBackend), "", "123", time.Second)
u := &upstream.Upstream{Backend: helper.URLMustParse(authBackend), Version: "123"}
return httptest.NewServer(u)
}
......
......@@ -16,7 +16,7 @@ import (
)
func newUpstream(url string) *upstream.Upstream {
return upstream.New(helper.URLMustParse(url), "", "123", time.Second)
return &upstream.Upstream{Backend: helper.URLMustParse(url), Version: "123"}
}
func TestProxyRequest(t *testing.T) {
......@@ -48,7 +48,7 @@ func TestProxyRequest(t *testing.T) {
u := newUpstream(ts.URL)
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202)
helper.AssertResponseBody(t, w, "RESPONSE")
......@@ -66,7 +66,7 @@ func TestProxyError(t *testing.T) {
u := newUpstream("http://localhost:655575/")
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
}
......@@ -81,7 +81,7 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err)
}
transport := proxy.NewRoundTripper(
transport := &proxy.RoundTripper{
&http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
......@@ -91,7 +91,7 @@ func TestProxyReadTimeout(t *testing.T) {
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: time.Millisecond,
},
)
}
p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"}
......@@ -116,7 +116,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
u := newUpstream(ts.URL)
w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest)
u.Proxy().ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 503)
helper.AssertResponseBody(t, w, "Request took too long")
}
......@@ -2,7 +2,6 @@ package main
import (
"flag"
"log"
"net/url"
)
......@@ -17,12 +16,8 @@ func (u *urlFlag) Set(s string) error {
return nil
}
func URLFlag(name string, value string, usage string) *url.URL {
u, err := url.Parse(value)
if err != nil {
log.Fatalf("URLFlag: invalid default: %q %v", value, err)
}
f := urlFlag{u}
func URLFlag(name string, value *url.URL, usage string) *url.URL {
f := urlFlag{value}
flag.CommandLine.Var(&f, name, usage)
return f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment