Commit ae66b95f authored by Jacob Vosmaer's avatar Jacob Vosmaer

Make authBackend a *url.URL

parent ede3979f
...@@ -27,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -27,7 +27,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
api := upstream.New(ts.URL, "", "123", time.Second).API api := upstream.New(helper.URLMustParse(ts.URL), "", "123", time.Second).API
response := httptest.NewRecorder() response := httptest.NewRecorder()
api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest) api.PreAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
type API struct { type API struct {
*http.Client *http.Client
*url.URL URL *url.URL
Version string Version string
} }
...@@ -49,7 +49,7 @@ type Response struct { ...@@ -49,7 +49,7 @@ type Response struct {
} }
func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) { func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) {
url := *api.URL url := *api.URL // Make a copy of api.URL
url.Path = r.URL.RequestURI() + suffix url.Path = r.URL.RequestURI() + suffix
authReq := &http.Request{ authReq := &http.Request{
Method: r.Method, Method: r.Method,
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"errors" "errors"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
) )
...@@ -51,3 +52,11 @@ func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) { ...@@ -51,3 +52,11 @@ func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
return return
} }
func URLMustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
log.Fatalf("urlMustParse: %q %v", s, err)
}
return u
}
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
...@@ -13,7 +12,7 @@ import ( ...@@ -13,7 +12,7 @@ import (
) )
type Proxy struct { type Proxy struct {
URL string URL *url.URL
Version string Version string
Transport http.RoundTripper Transport http.RoundTripper
_reverseProxy *httputil.ReverseProxy _reverseProxy *httputil.ReverseProxy
...@@ -26,13 +25,9 @@ func (p *Proxy) reverseProxy() *httputil.ReverseProxy { ...@@ -26,13 +25,9 @@ func (p *Proxy) reverseProxy() *httputil.ReverseProxy {
} }
func (p *Proxy) configureReverseProxy() { func (p *Proxy) configureReverseProxy() {
// Modify a copy of url u := *p.URL // Make a copy of p.URL
url, err := url.Parse(p.URL) u.Path = ""
if err != nil { p._reverseProxy = httputil.NewSingleHostReverseProxy(&u)
log.Fatalf("configureReverseProxy: %v", err)
}
url.Path = ""
p._reverseProxy = httputil.NewSingleHostReverseProxy(url)
p._reverseProxy.Transport = p.Transport p._reverseProxy.Transport = p.Transport
} }
......
...@@ -57,7 +57,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
handleFileUploads(&proxy.Proxy{URL: ts.URL, Version: "123"}).ServeHTTP(response, httpRequest) handleFileUploads(&proxy.Proxy{URL: helper.URLMustParse(ts.URL), Version: "123"}).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202) helper.AssertResponseCode(t, response, 202)
if response.Body.String() != "RESPONSE" { if response.Body.String() != "RESPONSE" {
t.Fatal("Expected RESPONSE in response body") t.Fatal("Expected RESPONSE in response body")
...@@ -131,7 +131,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -131,7 +131,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
handleFileUploads(&proxy.Proxy{URL: ts.URL, Version: "123"}).ServeHTTP(response, httpRequest) handleFileUploads(&proxy.Proxy{URL: helper.URLMustParse(ts.URL), Version: "123"}).ServeHTTP(response, httpRequest)
helper.AssertResponseCode(t, response, 202) helper.AssertResponseCode(t, response, 202)
if _, err := os.Stat(filePath); !os.IsNotExist(err) { if _, err := os.Stat(filePath); !os.IsNotExist(err) {
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"../api" "../api"
"../proxy" "../proxy"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
...@@ -30,13 +29,8 @@ type Upstream struct { ...@@ -30,13 +29,8 @@ type Upstream struct {
routes []route routes []route
} }
func New(authBackend string, authSocket string, version string, responseHeadersTimeout time.Duration) *Upstream { func New(authBackend *url.URL, authSocket string, version string, responseHeadersTimeout time.Duration) *Upstream {
parsedURL, err := url.Parse(authBackend) relativeURLRoot := authBackend.Path
if err != nil {
log.Fatalln(err)
}
relativeURLRoot := parsedURL.Path
if !strings.HasSuffix(relativeURLRoot, "/") { if !strings.HasSuffix(relativeURLRoot, "/") {
relativeURLRoot += "/" relativeURLRoot += "/"
} }
...@@ -61,7 +55,7 @@ func New(authBackend string, authSocket string, version string, responseHeadersT ...@@ -61,7 +55,7 @@ func New(authBackend string, authSocket string, version string, responseHeadersT
up := &Upstream{ up := &Upstream{
API: &api.API{ API: &api.API{
Client: &http.Client{Transport: proxyTransport}, Client: &http.Client{Transport: proxyTransport},
URL: parsedURL, URL: authBackend,
Version: version, Version: version,
}, },
Proxy: &proxy.Proxy{URL: authBackend, Transport: proxyTransport, Version: version}, Proxy: &proxy.Proxy{URL: authBackend, Transport: proxyTransport, Version: version},
......
...@@ -33,7 +33,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -33,7 +33,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022")
var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend") var authBackend = URLFlag("authBackend", "http://localhost:8080", "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
...@@ -81,7 +81,7 @@ func main() { ...@@ -81,7 +81,7 @@ func main() {
}() }()
} }
up := upstream.New(*authBackend, *authSocket, Version, *responseHeadersTimeout) up := upstream.New(authBackend, *authSocket, Version, *responseHeadersTimeout)
up.DocumentRoot = *documentRoot up.DocumentRoot = *documentRoot
up.DevelopmentMode = *developmentMode up.DevelopmentMode = *developmentMode
......
...@@ -311,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -311,7 +311,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
u := upstream.New(authBackend, "", "123", time.Second) u := upstream.New(helper.URLMustParse(authBackend), "", "123", time.Second)
return httptest.NewServer(u) return httptest.NewServer(u)
} }
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
func newUpstream(url string) *upstream.Upstream { func newUpstream(url string) *upstream.Upstream {
return upstream.New(url, "", "123", time.Second) return upstream.New(helper.URLMustParse(url), "", "123", time.Second)
} }
func TestProxyRequest(t *testing.T) { func TestProxyRequest(t *testing.T) {
...@@ -93,7 +93,7 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -93,7 +93,7 @@ func TestProxyReadTimeout(t *testing.T) {
}, },
) )
p := &proxy.Proxy{URL: ts.URL, Transport: transport, Version: "123"} p := &proxy.Proxy{URL: helper.URLMustParse(ts.URL), Transport: transport, Version: "123"}
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ServeHTTP(w, httpRequest) p.ServeHTTP(w, httpRequest)
......
package main
import (
"flag"
"log"
"net/url"
)
type urlFlag struct{ *url.URL }
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value string, usage string) *url.URL {
u, err := url.Parse(value)
if err != nil {
log.Fatalf("URLFlag: invalid default: %q %v", value, err)
}
f := urlFlag{u}
flag.CommandLine.Var(&f, name, usage)
return f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment