Commit 8e9a6da2 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Add badgateway.RoundTripper type

parent d89b378e
package badgateway
import (
"../helper"
"bytes"
"fmt"
"io/ioutil"
"net"
"net/http"
"sync"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
type RoundTripper struct {
Socket string
ResponseHeaderTimeout time.Duration
Transport *http.Transport
configureRoundTripperOnce sync.Once
}
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
t.configureRoundTripperOnce.Do(t.configureRoundTripper)
res, err = t.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
// is that the Rails app is not responding, in which case users
// and administrators expect to see a 502 error. To show 502s
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
Status: http.StatusText(http.StatusBadGateway),
Request: r,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Proto: r.Proto,
Header: make(http.Header),
Trailer: make(http.Header),
Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())),
}
res.Header.Set("Content-Type", "text/plain")
err = nil
}
return
}
func (t *RoundTripper) configureRoundTripper() {
if t.Transport != nil {
return
}
tr := *DefaultTransport
if t.ResponseHeaderTimeout != 0 {
tr.ResponseHeaderTimeout = t.ResponseHeaderTimeout
}
if t.Socket != "" {
tr.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", t.Socket)
}
}
t.Transport = &tr
}
package proxy
import (
"../helper"
"bytes"
"fmt"
"io/ioutil"
"../badgateway"
"net/http"
"net/http/httputil"
"net/url"
......@@ -14,55 +11,23 @@ import (
type Proxy struct {
URL *url.URL
Version string
Transport http.RoundTripper
RoundTripper *badgateway.RoundTripper
_reverseProxy *httputil.ReverseProxy
configureReverseProxyOnce sync.Once
}
func (p *Proxy) reverseProxy() *httputil.ReverseProxy {
p.configureReverseProxyOnce.Do(p.configureReverseProxy)
return p._reverseProxy
}
func (p *Proxy) configureReverseProxy() {
u := *p.URL // Make a copy of p.URL
u.Path = ""
p._reverseProxy = httputil.NewSingleHostReverseProxy(&u)
p._reverseProxy.Transport = p.Transport
}
type RoundTripper struct {
Transport http.RoundTripper
}
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.Transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error
// is that the Rails app is not responding, in which case users
// and administrators expect to see a 502 error. To show 502s
// instead of 500s we catch the RoundTrip error here and inject a
// 502 response.
if err != nil {
helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{
StatusCode: http.StatusBadGateway,
Status: http.StatusText(http.StatusBadGateway),
Request: r,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Proto: r.Proto,
Header: make(http.Header),
Trailer: make(http.Header),
Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())),
p.configureReverseProxyOnce.Do(func() {
u := *p.URL // Make a copy of p.URL
u.Path = ""
p._reverseProxy = httputil.NewSingleHostReverseProxy(&u)
if p.RoundTripper != nil {
p._reverseProxy.Transport = p.RoundTripper
} else {
p._reverseProxy.Transport = &badgateway.RoundTripper{}
}
res.Header.Set("Content-Type", "text/plain")
err = nil
}
return
})
return p._reverseProxy
}
func HeaderClone(h http.Header) http.Header {
......
......@@ -3,6 +3,7 @@ package upstream
import (
"../git"
"../lfs"
pr "../proxy"
"../staticpages"
"../upload"
"net/http"
......@@ -34,12 +35,13 @@ func (u *Upstream) Routes() []route {
func (u *Upstream) configureRoutes() {
static := &staticpages.Static{u.DocumentRoot}
proxy := &pr.Proxy{URL: u.Backend, Version: u.Version, RoundTripper: u.RoundTripper()}
u.routes = []route{
// Git Clone
route{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), git.GetInfoRefs(u.API())},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), contentEncodingHandler(git.PostRPC(u.API()))},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), u.Proxy())},
route{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfs.PutStore(u.API(), proxy)},
// Repository Archive
route{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), git.GetArchive(u.API())},
......@@ -56,17 +58,17 @@ func (u *Upstream) configureRoutes() {
route{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), git.GetArchive(u.API())},
// CI Artifacts API
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), u.Proxy()))},
route{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), contentEncodingHandler(upload.Artifacts(u.API(), proxy))},
// Explicitly proxy API requests
route{"", regexp.MustCompile(apiPattern), u.Proxy()},
route{"", regexp.MustCompile(ciAPIPattern), u.Proxy()},
route{"", regexp.MustCompile(apiPattern), proxy},
route{"", regexp.MustCompile(ciAPIPattern), proxy},
// Serve assets
route{"", regexp.MustCompile(`^/assets/`),
static.ServeExisting(u.URLPrefix(), staticpages.CacheExpireMax,
NotFoundUnless(u.DevelopmentMode,
u.Proxy(),
proxy,
),
),
},
......@@ -76,7 +78,7 @@ func (u *Upstream) configureRoutes() {
static.ServeExisting(u.URLPrefix(), staticpages.CacheDisabled,
static.DeployPage(
static.ErrorPages(
u.Proxy(),
proxy,
),
),
),
......
package upstream
import (
"../proxy"
"net"
"net/http"
"time"
)
// Values from http.DefaultTransport
var DefaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var DefaultTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment, // from http.DefaultTransport
Dial: DefaultDialer.Dial, // from http.DefaultTransport
ResponseHeaderTimeout: time.Minute, // custom
TLSHandshakeTimeout: 10 * time.Second, // from http.DefaultTransport
}
func (u *Upstream) Transport() http.RoundTripper {
u.configureTransportOnce.Do(u.configureTransport)
return u.transport
}
func (u *Upstream) configureTransport() {
t := *DefaultTransport
if u.ResponseHeaderTimeout != 0 {
t.ResponseHeaderTimeout = u.ResponseHeaderTimeout
}
if u.Socket != "" {
t.Dial = func(_, _ string) (net.Conn, error) {
return DefaultDialer.Dial("unix", u.Socket)
}
}
u.transport = &proxy.RoundTripper{&t}
}
......@@ -8,6 +8,7 @@ package upstream
import (
"../api"
"../badgateway"
"../helper"
"../proxy"
"../staticpages"
......@@ -42,22 +43,13 @@ type Upstream struct {
routes []route
configureRoutesOnce sync.Once
transport http.RoundTripper
configureTransportOnce sync.Once
roundtripper *badgateway.RoundTripper
configureRoundTripperOnce sync.Once
_static *staticpages.Static
configureStaticOnce sync.Once
}
func (u *Upstream) Proxy() *proxy.Proxy {
u.configureProxyOnce.Do(u.configureProxy)
return u._proxy
}
func (u *Upstream) configureProxy() {
u._proxy = &proxy.Proxy{URL: u.Backend, Transport: u.Transport(), Version: u.Version}
}
func (u *Upstream) API() *api.API {
u.configureAPIOnce.Do(u.configureAPI)
return u._api
......@@ -65,7 +57,7 @@ func (u *Upstream) API() *api.API {
func (u *Upstream) configureAPI() {
u._api = &api.API{
Client: &http.Client{Transport: u.Transport()},
Client: &http.Client{Transport: u.RoundTripper()},
URL: u.Backend,
Version: u.Version,
}
......@@ -87,12 +79,15 @@ func (u *Upstream) configureURLPrefix() {
u.urlPrefix = urlprefix.Prefix(relativeURLRoot)
}
// func (u *Upstream) Static() *static.Static {
// u.configureStaticOnce.Do(func() {
// u._static = &static.Static{u.DocumentRoot}
// })
// return u._static
// }
func (u *Upstream) RoundTripper() *badgateway.RoundTripper {
u.configureRoundTripperOnce.Do(func() {
u.roundtripper = &badgateway.RoundTripper{
Socket: u.Socket,
ResponseHeaderTimeout: u.ResponseHeaderTimeout,
}
})
return u.roundtripper
}
func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) {
w := newLoggingResponseWriter(ow)
......
......@@ -14,6 +14,7 @@ In this file we start the web server and hand off to the upstream type.
package main
import (
"./internal/badgateway"
"./internal/upstream"
"flag"
"fmt"
......@@ -36,7 +37,7 @@ var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authenticatio
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", upstream.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request")
var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", badgateway.DefaultTransport.ResponseHeaderTimeout, "How long to wait for response headers when proxying the request")
var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app")
func main() {
......
package main
import (
"./internal/badgateway"
"./internal/helper"
"./internal/proxy"
"./internal/upstream"
"bytes"
"fmt"
"io"
......@@ -15,8 +15,8 @@ import (
"time"
)
func newUpstream(url string) *upstream.Upstream {
return &upstream.Upstream{Backend: helper.URLMustParse(url), Version: "123"}
func newProxy(url string) *proxy.Proxy {
return &proxy.Proxy{URL: helper.URLMustParse(url), Version: "123"}
}
func TestProxyRequest(t *testing.T) {
......@@ -46,9 +46,8 @@ func TestProxyRequest(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream(ts.URL)
w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest)
newProxy(ts.URL).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 202)
helper.AssertResponseBody(t, w, "RESPONSE")
......@@ -64,9 +63,8 @@ func TestProxyError(t *testing.T) {
}
httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream("http://localhost:655575/")
w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest)
newProxy("http://localhost:655575/").ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502)
helper.AssertResponseBody(t, w, "dial tcp: invalid port 655575")
}
......@@ -81,8 +79,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err)
}
transport := &proxy.RoundTripper{
&http.Transport{
rt := &badgateway.RoundTripper{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
......@@ -93,8 +91,8 @@ func TestProxyReadTimeout(t *testing.T) {
},
}
p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"}
p := newProxy(ts.URL)
p.RoundTripper = rt
w := httptest.NewRecorder()
p.ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 502)
......@@ -113,10 +111,8 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err)
}
u := newUpstream(ts.URL)
w := httptest.NewRecorder()
u.Proxy().ServeHTTP(w, httpRequest)
newProxy(ts.URL).ServeHTTP(w, httpRequest)
helper.AssertResponseCode(t, w, 503)
helper.AssertResponseBody(t, w, "Request took too long")
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment