Commit bd748961 authored by Jacob Vosmaer's avatar Jacob Vosmaer

Move Proxy into an internal package

parent 3dd96bd4
...@@ -5,6 +5,7 @@ In this file we handle 'git archive' downloads ...@@ -5,6 +5,7 @@ In this file we handle 'git archive' downloads
package main package main
import ( import (
"./internal/helper"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -103,22 +104,22 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -103,22 +104,22 @@ func handleGetArchive(w http.ResponseWriter, r *http.Request, a *apiResponse) {
setArchiveHeaders(w, format, archiveFilename) setArchiveHeaders(w, format, archiveFilename)
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if _, err := io.Copy(w, archiveReader); err != nil { if _, err := io.Copy(w, archiveReader); err != nil {
logError(fmt.Errorf("handleGetArchive: read: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: read: %v", err))
return return
} }
if err := archiveCmd.Wait(); err != nil { if err := archiveCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: archiveCmd: %v", err))
return return
} }
if compressCmd != nil { if compressCmd != nil {
if err := compressCmd.Wait(); err != nil { if err := compressCmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetArchive: compressCmd: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: compressCmd: %v", err))
return return
} }
} }
if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil { if err := finalizeCachedArchive(tempFile, a.ArchivePath); err != nil {
logError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err)) helper.LogError(fmt.Errorf("handleGetArchive: finalize cached archive: %v", err))
return return
} }
} }
......
package main package main
import ( import (
"./internal/proxy"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
...@@ -15,7 +16,7 @@ func (api *API) newUpstreamRequest(r *http.Request, body io.Reader, suffix strin ...@@ -15,7 +16,7 @@ func (api *API) newUpstreamRequest(r *http.Request, body io.Reader, suffix strin
authReq := &http.Request{ authReq := &http.Request{
Method: r.Method, Method: r.Method,
URL: &url, URL: &url,
Header: headerClone(r.Header), Header: proxy.HeaderClone(r.Header),
} }
if body != nil { if body != nil {
authReq.Body = ioutil.NopCloser(body) authReq.Body = ioutil.NopCloser(body)
......
...@@ -23,7 +23,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api ...@@ -23,7 +23,7 @@ func runPreAuthorizeHandler(t *testing.T, suffix string, url *regexp.Regexp, api
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
api := newUpstream(ts.URL, nil).API api := newUpstream(ts.URL, "").API
response := httptest.NewRecorder() response := httptest.NewRecorder()
api.preAuthorizeHandler(okHandler, suffix)(response, httpRequest) api.preAuthorizeHandler(okHandler, suffix)(response, httpRequest)
......
...@@ -5,6 +5,7 @@ In this file we handle the Git 'smart HTTP' protocol ...@@ -5,6 +5,7 @@ In this file we handle the Git 'smart HTTP' protocol
package main package main
import ( import (
"./internal/helper"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -69,19 +70,19 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -69,19 +70,19 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *apiResponse) {
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return
if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktLine: %v", err))
return return
} }
if err := pktFlush(w); err != nil { if err := pktFlush(w); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: pktFlush: %v", err))
return return
} }
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handleGetInfoRefs: wait for %v: %v", cmd.Args, err))
return return
} }
} }
...@@ -136,11 +137,11 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) { ...@@ -136,11 +137,11 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *apiResponse) {
// This io.Copy may take a long time, both for Git push and pull. // This io.Copy may take a long time, both for Git push and pull.
if _, err := io.Copy(w, stdout); err != nil { if _, err := io.Copy(w, stdout); err != nil {
logError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC read from %v: %v", cmd.Args, err))
return return
} }
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
logError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err)) helper.LogError(fmt.Errorf("handlePostRPC wait for %v: %v", cmd.Args, err))
return return
} }
} }
......
...@@ -5,9 +5,8 @@ Miscellaneous helpers: logging, errors, subprocesses ...@@ -5,9 +5,8 @@ Miscellaneous helpers: logging, errors, subprocesses
package main package main
import ( import (
"errors" "./internal/helper"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
...@@ -17,11 +16,7 @@ import ( ...@@ -17,11 +16,7 @@ import (
func fail500(w http.ResponseWriter, err error) { func fail500(w http.ResponseWriter, err error) {
http.Error(w, "Internal server error", 500) http.Error(w, "Internal server error", 500)
logError(err) helper.LogError(err)
}
func logError(err error) {
log.Printf("error: %v", err)
} }
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
...@@ -71,36 +66,6 @@ func setNoCacheHeaders(header http.Header) { ...@@ -71,36 +66,6 @@ func setNoCacheHeaders(header http.Header) {
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
} }
func openFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
// Borrowed from: net/http/server.go // Borrowed from: net/http/server.go
// Return the canonical path for p, eliminating . and .. elements. // Return the canonical path for p, eliminating . and .. elements.
func cleanURIPath(p string) string { func cleanURIPath(p string) string {
......
package helper
import (
"errors"
"log"
"os"
)
func LogError(err error) {
log.Printf("error: %v", err)
}
func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
return
}
defer func() {
if err != nil {
file.Close()
}
}()
fi, err = file.Stat()
if err != nil {
return
}
// The os.Open can also open directories
if fi.IsDir() {
err = &os.PathError{
Op: "open",
Path: path,
Err: errors.New("path is directory"),
}
return
}
return
}
package main package proxy
import ( import (
"../helper"
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httputil"
"net/url"
) )
type proxyRoundTripper struct { type Proxy struct {
reverseProxy *httputil.ReverseProxy
version string
}
func NewProxy(url *url.URL, transport http.RoundTripper, version string) *Proxy {
// Modify a copy of url
proxyURL := *url
proxyURL.Path = ""
p := Proxy{reverseProxy: httputil.NewSingleHostReverseProxy(&proxyURL), version: version}
p.reverseProxy.Transport = transport
return &p
}
type RoundTripper struct {
transport http.RoundTripper transport http.RoundTripper
} }
func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { func NewRoundTripper(transport http.RoundTripper) *RoundTripper {
res, err = p.transport.RoundTrip(r) return &RoundTripper{transport: transport}
}
func (rt *RoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) {
res, err = rt.transport.RoundTrip(r)
// httputil.ReverseProxy translates all errors from this // httputil.ReverseProxy translates all errors from this
// RoundTrip function into 500 errors. But the most likely error // RoundTrip function into 500 errors. But the most likely error
...@@ -21,7 +42,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -21,7 +42,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
// instead of 500s we catch the RoundTrip error here and inject a // instead of 500s we catch the RoundTrip error here and inject a
// 502 response. // 502 response.
if err != nil { if err != nil {
logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) helper.LogError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err))
res = &http.Response{ res = &http.Response{
StatusCode: http.StatusBadGateway, StatusCode: http.StatusBadGateway,
...@@ -41,7 +62,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err ...@@ -41,7 +62,7 @@ func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err
return return
} }
func headerClone(h http.Header) http.Header { func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h)) h2 := make(http.Header, len(h))
for k, vv := range h { for k, vv := range h {
vv2 := make([]string, len(vv)) vv2 := make([]string, len(vv))
...@@ -54,12 +75,12 @@ func headerClone(h http.Header) http.Header { ...@@ -54,12 +75,12 @@ func headerClone(h http.Header) http.Header {
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Clone request // Clone request
req := *r req := *r
req.Header = headerClone(r.Header) req.Header = HeaderClone(r.Header)
// Set Workhorse version // Set Workhorse version
req.Header.Set("Gitlab-Workhorse", Version) req.Header.Set("Gitlab-Workhorse", p.version)
rw := newSendFileResponseWriter(w, &req) rw := newSendFileResponseWriter(w, &req)
defer rw.Flush() defer rw.Flush()
p.ReverseProxy.ServeHTTP(&rw, &req) p.reverseProxy.ServeHTTP(&rw, &req)
} }
...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the ...@@ -4,9 +4,10 @@ via the X-Sendfile mechanism. All that is needed in the Rails code is the
'send_file' method. 'send_file' method.
*/ */
package main package proxy
import ( import (
"../helper"
"log" "log"
"net/http" "net/http"
) )
...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) { ...@@ -63,7 +64,7 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
// Serve the file // Serve the file
log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI) log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI)
content, fi, err := openFile(file) content, fi, err := helper.OpenFile(file)
if err != nil { if err != nil {
http.NotFound(s.rw, s.req) http.NotFound(s.rw, s.req)
return return
......
...@@ -153,23 +153,6 @@ func main() { ...@@ -153,23 +153,6 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
// Create Proxy Transport
authTransport := http.DefaultTransport
if *authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", *authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := &proxyRoundTripper{transport: authTransport}
// The profiler will only be activated by HTTP requests. HTTP // The profiler will only be activated by HTTP requests. HTTP
// requests can only reach the profiler if we start a listener. So by // requests can only reach the profiler if we start a listener. So by
// having no profiler HTTP listener by default, the profiler is // having no profiler HTTP listener by default, the profiler is
...@@ -180,7 +163,7 @@ func main() { ...@@ -180,7 +163,7 @@ func main() {
}() }()
} }
upstream := newUpstream(*authBackend, proxyTransport) upstream := newUpstream(*authBackend, *authSocket)
compileRoutes(upstream) compileRoutes(upstream)
log.Fatal(http.Serve(listener, upstream)) log.Fatal(http.Serve(listener, upstream))
} }
...@@ -326,7 +326,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se ...@@ -326,7 +326,7 @@ func testAuthServer(url *regexp.Regexp, code int, body interface{}) *httptest.Se
} }
func startWorkhorseServer(authBackend string) *httptest.Server { func startWorkhorseServer(authBackend string) *httptest.Server {
u := newUpstream(authBackend, nil) u := newUpstream(authBackend, "")
compileRoutes(u) compileRoutes(u)
return httptest.NewServer(u) return httptest.NewServer(u)
} }
......
package main package main
import ( import (
"./internal/proxy"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"regexp" "regexp"
"testing" "testing"
"time" "time"
...@@ -39,7 +41,7 @@ func TestProxyRequest(t *testing.T) { ...@@ -39,7 +41,7 @@ func TestProxyRequest(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, "")
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 202) assertResponseCode(t, w, 202)
...@@ -57,11 +59,7 @@ func TestProxyError(t *testing.T) { ...@@ -57,11 +59,7 @@ func TestProxyError(t *testing.T) {
} }
httpRequest.Header.Set("Custom-Header", "test") httpRequest.Header.Set("Custom-Header", "test")
transport := proxyRoundTripper{ u := newUpstream("http://localhost:655575/", "")
transport: http.DefaultTransport,
}
u := newUpstream("http://localhost:655575/", &transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
assertResponseCode(t, w, 502) assertResponseCode(t, w, 502)
...@@ -78,8 +76,8 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -78,8 +76,8 @@ func TestProxyReadTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{ transport := proxy.NewRoundTripper(
transport: &http.Transport{ &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
...@@ -88,9 +86,14 @@ func TestProxyReadTimeout(t *testing.T) { ...@@ -88,9 +86,14 @@ func TestProxyReadTimeout(t *testing.T) {
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: time.Millisecond, ResponseHeaderTimeout: time.Millisecond,
}, },
} )
u := newUpstream(ts.URL, transport) u := newUpstream(ts.URL, "")
url, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
u.Proxy = proxy.NewProxy(url, transport, "123")
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
...@@ -110,11 +113,7 @@ func TestProxyHandlerTimeout(t *testing.T) { ...@@ -110,11 +113,7 @@ func TestProxyHandlerTimeout(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
transport := &proxyRoundTripper{ u := newUpstream(ts.URL, "")
transport: http.DefaultTransport,
}
u := newUpstream(ts.URL, transport)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u.Proxy.ServeHTTP(w, httpRequest) u.Proxy.ServeHTTP(w, httpRequest)
......
package main package main
import ( import (
"./internal/helper"
"log" "log"
"net/http" "net/http"
"os" "os"
...@@ -36,7 +37,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou ...@@ -36,7 +37,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
// Serve pre-gzipped assets // Serve pre-gzipped assets
if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") {
content, fi, err = openFile(file + ".gz") content, fi, err = helper.OpenFile(file + ".gz")
if err == nil { if err == nil {
w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Encoding", "gzip")
} }
...@@ -44,7 +45,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou ...@@ -44,7 +45,7 @@ func (u *upstream) handleServeFile(documentRoot *string, cache CacheMode, notFou
// If not found, open the original file // If not found, open the original file
if content == nil || err != nil { if content == nil || err != nil {
content, fi, err = openFile(file) content, fi, err = helper.OpenFile(file)
} }
if err != nil { if err != nil {
if notFoundHandler != nil { if notFoundHandler != nil {
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"testing" "testing"
) )
var dummyUpstream = newUpstream("http://localhost", nil) var dummyUpstream = newUpstream("http://localhost", "")
func TestServingNonExistingFile(t *testing.T) { func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory" dir := "/path/to/non/existing/directory"
......
...@@ -51,7 +51,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) { ...@@ -51,7 +51,7 @@ func TestUploadHandlerForwardingRawData(t *testing.T) {
response := httptest.NewRecorder() response := httptest.NewRecorder()
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, "")
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest) handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
...@@ -126,7 +126,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) { ...@@ -126,7 +126,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) httpRequest.Header.Set("Content-Type", writer.FormDataContentType())
httpRequest.Header.Set(tempPathHeader, tempPath) httpRequest.Header.Set(tempPathHeader, tempPath)
response := httptest.NewRecorder() response := httptest.NewRecorder()
u := newUpstream(ts.URL, nil) u := newUpstream(ts.URL, "")
handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest) handleFileUploads(u.Proxy).ServeHTTP(response, httpRequest)
assertResponseCode(t, response, 202) assertResponseCode(t, response, 202)
......
...@@ -7,12 +7,14 @@ In this file we handle request routing and interaction with the authBackend. ...@@ -7,12 +7,14 @@ In this file we handle request routing and interaction with the authBackend.
package main package main
import ( import (
"./internal/proxy"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
"net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"time"
) )
type serviceHandleFunc func(http.ResponseWriter, *http.Request, *apiResponse) type serviceHandleFunc func(http.ResponseWriter, *http.Request, *apiResponse)
...@@ -24,15 +26,11 @@ type API struct { ...@@ -24,15 +26,11 @@ type API struct {
type upstream struct { type upstream struct {
API *API API *API
Proxy *Proxy Proxy *proxy.Proxy
authBackend string authBackend string
relativeURLRoot string relativeURLRoot string
} }
type Proxy struct {
ReverseProxy *httputil.ReverseProxy
}
type apiResponse struct { type apiResponse struct {
// GL_ID is an environment variable used by gitlab-shell hooks during 'git // GL_ID is an environment variable used by gitlab-shell hooks during 'git
// push' and 'git pull' // push' and 'git pull'
...@@ -61,16 +59,7 @@ type apiResponse struct { ...@@ -61,16 +59,7 @@ type apiResponse struct {
TempPath string TempPath string
} }
func newProxy(url *url.URL, transport http.RoundTripper) *Proxy { func newUpstream(authBackend string, authSocket string) *upstream {
// Modify a copy of url
proxyURL := *url
proxyURL.Path = ""
proxy := Proxy{ReverseProxy: httputil.NewSingleHostReverseProxy(&proxyURL)}
proxy.ReverseProxy.Transport = transport
return &proxy
}
func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream {
parsedURL, err := url.Parse(authBackend) parsedURL, err := url.Parse(authBackend)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
...@@ -81,10 +70,27 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream ...@@ -81,10 +70,27 @@ func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream
relativeURLRoot += "/" relativeURLRoot += "/"
} }
// Create Proxy Transport
authTransport := http.DefaultTransport
if authSocket != "" {
dialer := &net.Dialer{
// The values below are taken from http.DefaultTransport
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
authTransport = &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return dialer.Dial("unix", authSocket)
},
ResponseHeaderTimeout: *responseHeadersTimeout,
}
}
proxyTransport := proxy.NewRoundTripper(authTransport)
up := &upstream{ up := &upstream{
authBackend: authBackend, authBackend: authBackend,
API: &API{Client: &http.Client{Transport: authTransport}, URL: parsedURL}, API: &API{Client: &http.Client{Transport: proxyTransport}, URL: parsedURL},
Proxy: newProxy(parsedURL, authTransport), Proxy: proxy.NewProxy(parsedURL, proxyTransport, Version),
relativeURLRoot: relativeURLRoot, relativeURLRoot: relativeURLRoot,
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment