Commit 4c6082df authored by Nimi Wariboko Jr's avatar Nimi Wariboko Jr Committed by GitHub

Merge pull request #987 from nemothekid/proxy/single-webconn

Proxy: Single WebSocket connection
parents fffc1bed 88980664
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
...@@ -102,7 +103,8 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { ...@@ -102,7 +103,8 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
// No-op websocket backend simply allows the WS connection to be // No-op websocket backend simply allows the WS connection to be
// accepted then it will be immediately closed. Perfect for testing. // accepted then it will be immediately closed. Perfect for testing.
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) var connCount int32
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
defer wsNop.Close() defer wsNop.Close()
// Get proxy to use for the test // Get proxy to use for the test
...@@ -135,6 +137,9 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { ...@@ -135,6 +137,9 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
if !bytes.Equal(actual, expected) { if !bytes.Equal(actual, expected) {
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
} }
if atomic.LoadInt32(&connCount) != 1 {
t.Errorf("Expected 1 websocket connection, got %d", connCount)
}
} }
func TestWebSocketReverseProxyFromWSClient(t *testing.T) { func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
......
...@@ -186,9 +186,80 @@ var hopHeaders = []string{ ...@@ -186,9 +186,80 @@ var hopHeaders = []string{
type respUpdateFn func(resp *http.Response) type respUpdateFn func(resp *http.Response)
type hijackedConn struct {
net.Conn
hj *connHijackerTransport
}
func (c *hijackedConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
c.hj.Replay = append(c.hj.Replay, b[:n]...)
return
}
func (c *hijackedConn) Close() error {
return nil
}
type connHijackerTransport struct {
*http.Transport
Conn net.Conn
Replay []byte
}
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
DisableKeepAlives: true,
}
if base != nil {
if baseTransport, ok := base.(*http.Transport); ok {
transport.Proxy = baseTransport.Proxy
transport.TLSClientConfig = baseTransport.TLSClientConfig
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
transport.Dial = baseTransport.Dial
transport.DialTLS = baseTransport.DialTLS
transport.DisableKeepAlives = true
}
}
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
oldDial := transport.Dial
oldDialTLS := transport.DialTLS
if oldDial == nil {
oldDial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
}
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
c, err := oldDial(network, addr)
hjTransport.Conn = c
return &hijackedConn{c, hjTransport}, err
}
if oldDialTLS != nil {
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := oldDialTLS(network, addr)
hjTransport.Conn = c
return &hijackedConn{c, hjTransport}, err
}
}
return hjTransport
}
func requestIsWebsocket(req *http.Request) bool {
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport transport := p.Transport
if transport == nil { if requestIsWebsocket(outreq) {
transport = newConnHijackerTransport(transport)
} else if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
} }
...@@ -219,14 +290,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r ...@@ -219,14 +290,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
} }
defer conn.Close() defer conn.Close()
backendConn, err := net.Dial("tcp", outreq.URL.Host) var backendConn net.Conn
if err != nil { if hj, ok := transport.(*connHijackerTransport); ok {
return err backendConn = hj.Conn
if _, err := conn.Write(hj.Replay); err != nil {
return err
}
bufferPool.Put(hj.Replay)
} else {
backendConn, err = net.Dial("tcp", outreq.URL.Host)
if err != nil {
return err
}
outreq.Write(backendConn)
} }
defer backendConn.Close() defer backendConn.Close()
outreq.Write(backendConn)
go func() { go func() {
io.Copy(backendConn, conn) // write tcp stream to backend. io.Copy(backendConn, conn) // write tcp stream to backend.
}() }()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment