Commit 41afa01c authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)

Merge branch 'urlparser' into 'master'

Fix backend URL parsing

Our fancy custom flag parser broke in Go 1.7. Also, the new 'strict
TCP connection' feature was not handling backend hosts like
`localhost:3000` (it was panicking).

Fixes:

https://gitlab.com/gitlab-org/gitlab-workhorse/issues/54
https://gitlab.com/gitlab-org/gitlab-workhorse/issues/53

See merge request !61
parents cdcabf45 7bd9cde0
package main
import (
"fmt"
"net/url"
)
func parseAuthBackend(authBackend string) (*url.URL, error) {
backendURL, err := url.Parse(authBackend)
if err != nil {
return nil, err
}
if backendURL.Host == "" {
backendURL, err = url.Parse("http://" + authBackend)
if err != nil {
return nil, err
}
}
if backendURL.Scheme != "http" {
return nil, fmt.Errorf("invalid scheme, only 'http' is allowed: %q", authBackend)
}
if backendURL.Host == "" {
return nil, fmt.Errorf("missing host in %q", authBackend)
}
return backendURL, nil
}
package main
import (
"testing"
)
func TestParseAuthBackend(t *testing.T) {
failures := []string{
"",
"ftp://localhost",
"https://example.com",
}
for _, example := range failures {
if _, err := parseAuthBackend(example); err == nil {
t.Errorf("error expected for %q", example)
}
}
successes := []struct{ input, host, scheme string }{
{"http://localhost:8080", "localhost:8080", "http"},
{"localhost:3000", "localhost:3000", "http"},
{"http://localhost", "localhost", "http"},
{"localhost", "localhost", "http"},
}
for _, example := range successes {
result, err := parseAuthBackend(example.input)
if err != nil {
t.Errorf("parse %q: %v", example.input, err)
break
}
if result.Host != example.host {
t.Errorf("example %q: expected %q, got %q", example.input, example.host, result.Host)
}
if result.Scheme != example.scheme {
t.Errorf("example %q: expected %q, got %q", example.input, example.scheme, result.Scheme)
}
}
}
...@@ -64,7 +64,7 @@ func mustParseAddress(address, scheme string) string { ...@@ -64,7 +64,7 @@ func mustParseAddress(address, scheme string) string {
} }
} }
panic("could not parse host:port from address and scheme") panic(fmt.Errorf("could not parse host:port from address %q and scheme %q", address, scheme))
} }
func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func (t *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
......
...@@ -34,7 +34,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit") ...@@ -34,7 +34,7 @@ var printVersion = flag.Bool("version", false, "Print version and exit")
var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server")
var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)")
var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket") var listenUmask = flag.Int("listenUmask", 0, "Umask for Unix socket")
var authBackend = URLFlag("authBackend", upstream.DefaultBackend, "Authentication/authorization backend") var authBackend = flag.String("authBackend", upstream.DefaultBackend.String(), "Authentication/authorization backend")
var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at")
var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'")
var documentRoot = flag.String("documentRoot", "public", "Path to static files content") var documentRoot = flag.String("documentRoot", "public", "Path to static files content")
...@@ -55,6 +55,12 @@ func main() { ...@@ -55,6 +55,12 @@ func main() {
os.Exit(0) os.Exit(0)
} }
backendURL, err := parseAuthBackend(*authBackend)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid authBackend: %v\n", err)
os.Exit(1)
}
log.Printf("Starting %s", version) log.Printf("Starting %s", version)
// Good housekeeping for Unix sockets: unlink before binding // Good housekeeping for Unix sockets: unlink before binding
...@@ -83,7 +89,7 @@ func main() { ...@@ -83,7 +89,7 @@ func main() {
} }
up := upstream.NewUpstream( up := upstream.NewUpstream(
*authBackend, backendURL,
*authSocket, *authSocket,
Version, Version,
*documentRoot, *documentRoot,
......
package main
import (
"flag"
"net/url"
)
type urlFlag struct {
*url.URL
}
func (u *urlFlag) Set(s string) error {
myURL, err := url.Parse(s)
if err != nil {
return err
}
u.URL = myURL
return nil
}
func URLFlag(name string, value *url.URL, usage string) **url.URL {
f := &urlFlag{value}
flag.CommandLine.Var(f, name, usage)
return &f.URL
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment