Commit b857265f authored by Leonard Hecker's avatar Leonard Hecker

proxy: Fixed support for TLS verification of WebSocket connections

parent 153d4a5a
...@@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { ...@@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
MaxIdleConnsPerHost: -1, MaxIdleConnsPerHost: -1,
} }
if b, _ := base.(*http.Transport); b != nil { if b, _ := base.(*http.Transport); b != nil {
tlsClientConfig := b.TLSClientConfig
if tlsClientConfig.NextProtos != nil {
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
tlsClientConfig.NextProtos = nil
}
t.Proxy = b.Proxy t.Proxy = b.Proxy
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) t.TLSClientConfig = tlsClientConfig
t.TLSClientConfig.NextProtos = nil
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
t.Dial = b.Dial t.Dial = b.Dial
t.DialTLS = b.DialTLS t.DialTLS = b.DialTLS
...@@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { ...@@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
dial := getTransportDial(t) dial := getTransportDial(t)
dialTLS := getTransportDialTLS(t) dialTLS := getTransportDialTLS(t)
t.Dial = func(network, addr string) (net.Conn, error) { t.Dial = func(network, addr string) (net.Conn, error) {
c, err := dial(network, addr) c, err := dial(network, addr)
hj.Conn = c hj.Conn = c
return &hijackedConn{c, hj}, err return &hijackedConn{c, hj}, err
} }
t.DialTLS = func(network, addr string) (net.Conn, error) {
if dialTLS != nil { c, err := dialTLS(network, addr)
t.DialTLS = func(network, addr string) (net.Conn, error) { hj.Conn = c
c, err := dialTLS(network, addr) return &hijackedConn{c, hj}, err
hj.Conn = c
return &hijackedConn{c, hj}, err
}
} }
return hj return hj
...@@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e ...@@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
return defaultDialer.Dial return defaultDialer.Dial
} }
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil // getTransportDial always returns a TLS Dialer
// and defaults to the existing t.DialTLS. // and defaults to the existing t.DialTLS.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.DialTLS != nil { if t.DialTLS != nil {
return t.DialTLS return t.DialTLS
} }
if t.TLSClientConfig == nil {
return nil
}
// newConnHijackerTransport will modify t.Dial after calling this method // newConnHijackerTransport will modify t.Dial after calling this method
// => Create a backup reference. // => Create a backup reference.
plainDial := getTransportDial(t) plainDial := getTransportDial(t)
// The following DialTLS implementation stems from the Go stdlib and
// is identical to what happens if DialTLS is not provided.
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
return func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) {
plainConn, err := plainDial(network, addr) plainConn, err := plainDial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsConn := tls.Client(plainConn, t.TLSClientConfig) tlsClientConfig := t.TLSClientConfig
if tlsClientConfig == nil {
tlsClientConfig = &tls.Config{}
}
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
tlsClientConfig.ServerName = stripPort(addr)
}
tlsConn := tls.Client(plainConn, tlsClientConfig)
errc := make(chan error, 2) errc := make(chan error, 2)
var timer *time.Timer var timer *time.Timer
if d := t.TLSHandshakeTimeout; d != 0 { if d := t.TLSHandshakeTimeout; d != 0 {
...@@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn ...@@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
plainConn.Close() plainConn.Close()
return nil, err return nil, err
} }
if !t.TLSClientConfig.InsecureSkipVerify { if !tlsClientConfig.InsecureSkipVerify {
serverName := t.TLSClientConfig.ServerName hostname := tlsClientConfig.ServerName
if serverName == "" { if hostname == "" {
serverName = addr hostname = stripPort(addr)
idx := strings.LastIndex(serverName, ":")
if idx != -1 {
serverName = serverName[:idx]
}
} }
if err := tlsConn.VerifyHostname(serverName); err != nil { if err := tlsConn.VerifyHostname(hostname); err != nil {
plainConn.Close() plainConn.Close()
return nil, err return nil, err
} }
...@@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn ...@@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
} }
} }
// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func stripPort(address string) string {
// Keep in mind that the address might be a IPv6 address
// and thus contain a colon, but not have a port.
portIdx := strings.LastIndex(address, ":")
ipv6Idx := strings.LastIndex(address, "]")
if portIdx > ipv6Idx {
address = address[:portIdx]
}
return address
}
type tlsHandshakeTimeoutError struct{} type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true } func (tlsHandshakeTimeoutError) Timeout() bool { return true }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment