Commit 9f9ad21a authored by Leonard Hecker's avatar Leonard Hecker

Fixed #1292: Failure to proxy WebSockets over HTTPS

This issue was caused by connHijackerTransport trying to record HTTP
response headers by "hijacking" the Read() method of the plain net.Conn.
This does not simply work over TLS though since this will record the TLS
handshake and encrypted data instead of the actual content.
This commit fixes the problem by providing an alternative transport.DialTLS
which correctly hijacks the overlying tls.Conn instead.
parent 53635ba5
...@@ -27,6 +27,11 @@ import ( ...@@ -27,6 +27,11 @@ import (
"github.com/mholt/caddy/caddyhttp/httpserver" "github.com/mholt/caddy/caddyhttp/httpserver"
) )
var defaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
var bufferPool = sync.Pool{New: createBuffer} var bufferPool = sync.Pool{New: createBuffer}
func createBuffer() interface{} { func createBuffer() interface{} {
...@@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -135,11 +140,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// just use default transport, to avoid creating // just use default transport, to avoid creating
// a brand new transport // a brand new transport
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: defaultDialer.Dial,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
} }
...@@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * ...@@ -162,11 +164,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
func (rp *ReverseProxy) UseInsecureTransport() { func (rp *ReverseProxy) UseInsecureTransport() {
if rp.Transport == nil { if rp.Transport == nil {
transport := &http.Transport{ transport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{ Dial: defaultDialer.Dial,
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
...@@ -341,51 +340,148 @@ type connHijackerTransport struct { ...@@ -341,51 +340,148 @@ type connHijackerTransport struct {
} }
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
transport := &http.Transport{ t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
MaxIdleConnsPerHost: -1, MaxIdleConnsPerHost: -1,
} }
if base != nil { if b, _ := base.(*http.Transport); b != nil {
if baseTransport, ok := base.(*http.Transport); ok { t.Proxy = b.Proxy
transport.Proxy = baseTransport.Proxy t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
transport.TLSClientConfig = baseTransport.TLSClientConfig t.TLSClientConfig.NextProtos = nil
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
transport.Dial = baseTransport.Dial t.Dial = b.Dial
transport.DialTLS = baseTransport.DialTLS t.DialTLS = b.DialTLS
transport.MaxIdleConnsPerHost = -1 } else {
t.Proxy = http.ProxyFromEnvironment
t.TLSHandshakeTimeout = 10 * time.Second
}
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
dial := getTransportDial(t)
dialTLS := getTransportDialTLS(t)
t.Dial = func(network, addr string) (net.Conn, error) {
c, err := dial(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
if dialTLS != nil {
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
} }
} }
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
oldDial := transport.Dial return hj
oldDialTLS := transport.DialTLS }
if oldDial == nil {
oldDial = (&net.Dialer{ // getTransportDial always returns a plain Dialer
Timeout: 30 * time.Second, // and defaults to the existing t.Dial.
KeepAlive: 30 * time.Second, func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
}).Dial if t.Dial != nil {
return t.Dial
}
return defaultDialer.Dial
}
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
// and defaults to the existing t.DialTLS.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
} }
hjTransport.Dial = func(network, addr string) (net.Conn, error) { if t.TLSClientConfig == nil {
c, err := oldDial(network, addr) return nil
hjTransport.Conn = c
return &hijackedConn{c, hjTransport}, err
} }
if oldDialTLS != nil {
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { // newConnHijackerTransport will modify t.Dial after calling this method
c, err := oldDialTLS(network, addr) // => Create a backup reference.
hjTransport.Conn = c plainDial := getTransportDial(t)
return &hijackedConn{c, hjTransport}, err
return func(network, addr string) (net.Conn, error) {
plainConn, err := plainDial(network, addr)
if err != nil {
return nil, err
} }
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
errc := make(chan error, 2)
var timer *time.Timer
if d := t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
return nil, err
}
if !t.TLSClientConfig.InsecureSkipVerify {
serverName := t.TLSClientConfig.ServerName
if serverName == "" {
serverName = addr
idx := strings.LastIndex(serverName, ":")
if idx != -1 {
serverName = serverName[:idx]
}
}
if err := tlsConn.VerifyHostname(serverName); err != nil {
plainConn.Close()
return nil, err
}
}
return tlsConn, nil
}
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
// cloneTLSClientConfig is like cloneTLSConfig but omits
// the fields SessionTicketsDisabled and SessionTicketKey.
// This makes it safe to call cloneTLSClientConfig on a config
// in active use by a server.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
} }
return hjTransport
} }
func requestIsWebsocket(req *http.Request) bool { func requestIsWebsocket(req *http.Request) bool {
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
} }
type writeFlusher interface { type writeFlusher interface {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment