Commit 8e9a6da2 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Add badgateway.RoundTripper type

parent d89b378e
package badgateway
import (
"../helper"
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
type RoundTripper struct {
Socket string
ResponseHeaderTimeout time.Duration
Transport *http.Transport
configureRoundTripperOnce sync.Once
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
t.configureRoundTripperOnce.Do(t.configureRoundTripper)
res, err = t.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
// is that the Rails app is not responding, in which case users
// and administrators expect to see a 502 error. To show 502s
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
Status: http.StatusText(http.StatusBadGateway),
Request: r,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Proto: r.Proto,
Header: make(http.Header),
Trailer: make(http.Header),
Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())),
}
res.Header.Set("Content-Type", "text/plain")
err = nil
}
return
}
func (t *RoundTripper) configureRoundTripper() {
if t.Transport != nil {
return
}
tr := *DefaultTransport
if t.ResponseHeaderTimeout != 0 {
tr.ResponseHeaderTimeout = t.ResponseHeaderTimeout
}
if t.Socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", t.Socket)
}
}
t.Transport = &tr
}
package proxy package proxy
import ( import (
"../helper" "../badgateway"
"bytes"
"fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
...@@ -14,55 +11,23 @@ import ( ...@@ -14,55 +11,23 @@ import (
type Proxy struct { type Proxy struct {
URL *url.URL URL *url.URL
Version string Version string
Transport http.RoundTripper RoundTripper *badgateway.RoundTripper
_reverseProxy *httputil.ReverseProxy _reverseProxy *httputil.ReverseProxy
configureReverseProxyOnce sync.Once configureReverseProxyOnce sync.Once
} }
func (p *Proxy) reverseProxy() *httputil.ReverseProxy { func (p *Proxy) reverseProxy() *httputil.ReverseProxy {
p.configureReverseProxyOnce.Do(p.configureReverseProxy) p.configureReverseProxyOnce.Do(func() {
return p._reverseProxy
}
func (p *Proxy) configureReverseProxy() {
u := *p.URL // Make a copy of p.URL u := *p.URL // Make a copy of p.URL
u.Path = "" u.Path = ""
p._reverseProxy = httputil.NewSingleHostReverseProxy(&u) p._reverseProxy = httputil.NewSingleHostReverseProxy(&u)
p._reverseProxy.Transport = p.Transport if p.RoundTripper != nil {
} p._reverseProxy.Transport = p.RoundTripper
} else {
type RoundTripper struct { p._reverseProxy.Transport = &badgateway.RoundTripper{}
Transport http.RoundTripper
}
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
// is that the Rails app is not responding, in which case users
// and administrators expect to see a 502 error. To show 502s
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
Status: http.StatusText(http.StatusBadGateway),
Request: r,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Proto: r.Proto,
Header: make(http.Header),
Trailer: make(http.Header),
Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())),
}
res.Header.Set("Content-Type", "text/plain")
err = nil
} }
return })
return p._reverseProxy
} }
func HeaderClone(h http.Header) http.Header { func HeaderClone(h http.Header) http.Header {
......
...@@ -3,6 +3,7 @@ package upstream ...@@ -3,6 +3,7 @@ package upstream
import ( import (
"../git" "../git"
"../lfs" "../lfs"
pr "../proxy"
"../staticpages" "../staticpages"
"../upload" "../upload"
"net/http" "net/http"
...@@ -34,12 +35,13 @@ func (u *Upstream) Routes() []route { ...@@ -34,12 +35,13 @@ func (u *Upstream) Routes() []route {
func (u *Upstream) configureRoutes() { func (u *Upstream) configureRoutes() {
static := &staticpages.Static{u.DocumentRoot} static := &staticpages.Static{u.DocumentRoot}
proxy := &pr.Proxy{URL: u.Backend, Version: u.Version, RoundTripper: u.RoundTripper()}
u.routes = []route{ u.routes = []route{
// Git Clone // Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())}, route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))}, route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), u.Proxy())}, route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), proxy)},
// Repository Archive // Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())}, route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())},
...@@ -56,17 +58,17 @@ func (u *Upstream) configureRoutes() { ...@@ -56,17 +58,17 @@ func (u *Upstream) configureRoutes() {
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())}, route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// CI Artifacts API // CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), u.Proxy()))}, route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), proxy))},
// Explicitly proxy API requests // Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), u.Proxy()}, route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), u.Proxy()}, route{"", regexp.MustCompile(ciAPIPattern), proxy},
// Serve assets // Serve assets
route{"", regexp.MustCompile(`^/assets/`), route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix(), staticpages.CacheExpireMax, static.ServeExisting(u.URLPrefix(), staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode, NotFoundUnless(u.DevelopmentMode,
u.Proxy(), proxy,
), ),
), ),
}, },
...@@ -76,7 +78,7 @@ func (u *Upstream) configureRoutes() { ...@@ -76,7 +78,7 @@ func (u *Upstream) configureRoutes() {
static.ServeExisting(u.URLPrefix(), staticpages.CacheDisabled, static.ServeExisting(u.URLPrefix(), staticpages.CacheDisabled,
static.DeployPage( static.DeployPage(
static.ErrorPages( static.ErrorPages(
u.Proxy(), proxy,
), ),
), ),
), ),
......
package upstream
import (
"../proxy"
"net"
"net/http"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
func (u *Upstream) Transport() http.RoundTripper {
u.configureTransportOnce.Do(u.configureTransport)
return u.transport
}
func (u *Upstream) configureTransport() {
t := *DefaultTransport
if u.ResponseHeaderTimeout != 0 {
t.ResponseHeaderTimeout = u.ResponseHeaderTimeout
}
if u.Socket != "" {
t.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", u.Socket)
}
}
u.transport = &proxy.RoundTripper{&t}
}
...@@ -8,6 +8,7 @@ package upstream ...@@ -8,6 +8,7 @@ package upstream
import ( import (
"../api" "../api"
"../badgateway"
"../helper" "../helper"
"../proxy" "../proxy"
"../staticpages" "../staticpages"
...@@ -42,22 +43,13 @@ type Upstream struct { ...@@ -42,22 +43,13 @@ type Upstream struct {
routes []route routes []route
configureRoutesOnce sync.Once configureRoutesOnce sync.Once
transport http.RoundTripper roundtripper *badgateway.RoundTripper
configureTransportOnce sync.Once configureRoundTripperOnce sync.Once
_static *staticpages.Static _static *staticpages.Static
configureStaticOnce sync.Once configureStaticOnce sync.Once
} }
func (u *Upstream) Proxy() *proxy.Proxy {
u.configureProxyOnce.Do(u.configureProxy)
return u._proxy
}
func (u *Upstream) configureProxy() {
u._proxy = &proxy.Proxy{URL: u.Backend, Transport: u.Transport(), Version: u.Version}
}
func (u *Upstream) API() *api.API { func (u *Upstream) API() *api.API {
u.configureAPIOnce.Do(u.configureAPI) u.configureAPIOnce.Do(u.configureAPI)
return u._api return u._api
...@@ -65,7 +57,7 @@ func (u *Upstream) API() *api.API { ...@@ -65,7 +57,7 @@ func (u *Upstream) API() *api.API {
func (u *Upstream) configureAPI() { func (u *Upstream) configureAPI() {
u._api = &api.API{ u._api = &api.API{
Client: &http.Client{Transport: u.Transport()}, Client: &http.Client{Transport: u.RoundTripper()},
URL: u.Backend, URL: u.Backend,
Version: u.Version, Version: u.Version,
} }
...@@ -87,12 +79,15 @@ func (u *Upstream) configureURLPrefix() { ...@@ -87,12 +79,15 @@ func (u *Upstream) configureURLPrefix() {
u.urlPrefix = urlprefix.Prefix(relativeURLRoot) u.urlPrefix = urlprefix.Prefix(relativeURLRoot)
} }
// func (u *Upstream) Static() *static.Static { func (u *Upstream) RoundTripper() *badgateway.RoundTripper {
// u.configureStaticOnce.Do(func() { u.configureRoundTripperOnce.Do(func() {
// u._static = &static.Static{u.DocumentRoot} u.roundtripper = &badgateway.RoundTripper{
// }) Socket: u.Socket,
// return u._static ResponseHeaderTimeout: u.ResponseHeaderTimeout,
// } }
})
return u.roundtripper
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := newLoggingResponseWriter(ow) w := newLoggingResponseWriter(ow)
......
...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type. ...@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main package main
import ( import (
"./internal/badgateway"
"./internal/upstream" "./internal/upstream"
"flag" "flag"
"fmt" "fmt"
...@@ -36,7 +37,7 @@ var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authenticatio ...@@ -36,7 +37,7 @@ var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authenticatio
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", upstream.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request") var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", badgateway.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
func main() { func main() {
......
package main package main
import ( import (
"./internal/badgateway"
"./internal/helper" "./internal/helper"
"./internal/proxy" "./internal/proxy"
"./internal/upstream"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -15,8 +15,8 @@ import ( ...@@ -15,8 +15,8 @@ import (
"time" "time"
) )
func newUpstream(url string) *upstream.Upstream { func newProxy(url string) *proxy.Proxy {
return &upstream.Upstream{Backend: helper.URLMustParse(url), Version: "123"} return &proxy.Proxy{URL: helper.URLMustParse(url), Version: "123"}
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
...@@ -46,9 +46,8 @@ func TestProxyRequest(t *testing.T) { ...@@ -46,9 +46,8 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest) newProxy(ts.URL).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202) helper.AssertResponseCode(t, w, 202)
helper.AssertResponseBody(t, w, "RESPONSE") helper.AssertResponseBody(t, w, "RESPONSE")
...@@ -64,9 +63,8 @@ func TestProxyError(t *testing.T) { ...@@ -64,9 +63,8 @@ func TestProxyError(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream("http://localhost:655575/")
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest) newProxy("http://localhost:655575/").ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575") helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
} }
...@@ -81,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -81,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxy.RoundTripper{ rt := &badgateway.RoundTripper{
&http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
...@@ -93,8 +91,8 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -93,8 +91,8 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
} }
p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"} p := newProxy(ts.URL)
p.RoundTripper = rt
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ServeHTTP(w, httpRequest) p.ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502) helper.AssertResponseCode(t, w, 502)
...@@ -113,10 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -113,10 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
u := newUpstream(ts.URL)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest) newProxy(ts.URL).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 503) helper.AssertResponseCode(t, w, 503)
helper.AssertResponseBody(t, w, "Request took too long") helper.AssertResponseBody(t, w, "Request took too long")
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment