Commit f1ba7fa3 authored by Matt Holt's avatar Matt Holt

Merge pull request #467 from eiszfuchs/feature/proxy-socket

proxy: Support unix sockets
parents 57ffe5a6 7091a209
...@@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second ...@@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second
// ServeHTTP satisfies the middleware.Handler interface. // ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, upstream := range p.Upstreams { for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) { if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) {
var replacer middleware.Replacer var replacer middleware.Replacer
......
...@@ -3,6 +3,7 @@ package proxy ...@@ -3,6 +3,7 @@ package proxy
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
...@@ -13,7 +14,9 @@ import ( ...@@ -13,7 +14,9 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"runtime"
"time" "time"
"path/filepath"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
...@@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { ...@@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
} }
} }
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
t.Fatalf("Unable to get absolute path: %v", err)
}
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
......
...@@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string { ...@@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string {
return a + b return a + b
} }
// Though the relevant directive prefix is just "unix:", url.Parse
// will - assuming the regular URL scheme - add additional slashes
// as if "unix" was a request protocol.
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the ambiguous trimming.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
return func(network, addr string) (conn net.Conn, err error) {
return net.Dial("unix", hostName[len("unix://"):])
}
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the // URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir", // target's path is "/base" and the incoming request was for "/dir",
...@@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string { ...@@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string {
func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
targetQuery := target.RawQuery targetQuery := target.RawQuery
director := func(req *http.Request) { director := func(req *http.Request) {
req.URL.Scheme = target.Scheme if target.Scheme == "unix" {
req.URL.Host = target.Host // to make Dial work with unix URL,
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" { if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery req.URL.RawQuery = targetQuery + req.URL.RawQuery
...@@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { ...@@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without) req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
} }
} }
return &ReverseProxy{Director: director} rp := &ReverseProxy{Director: director}
if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
}
}
return rp
} }
func copyHeader(dst, src http.Header) { func copyHeader(dst, src http.Header) {
......
...@@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { ...@@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
upstream.Hosts = make([]*UpstreamHost, len(to)) upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to { for i, host := range to {
if !strings.HasPrefix(host, "http") { if !strings.HasPrefix(host, "http") &&
!strings.HasPrefix(host, "unix:") {
host = "http://" + host host = "http://" + host
} }
uh := &UpstreamHost{ uh := &UpstreamHost{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment