Commit 1f8b2a69 authored by Adam Langley's avatar Adam Langley

crypto/tls: add DialWithDialer.

While reviewing uses of the lower-level Client API in code, I found
that in many cases, code was using Client only because it needed a
timeout on the connection. DialWithDialer allows a timeout (and
 other values) to be specified without resorting to the low-level API.

LGTM=r
R=golang-codereviews, r, bradfitz
CC=golang-codereviews
https://golang.org/cl/68920045
parent b3e0a8df
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"time"
) )
// Server returns a new TLS server side connection // Server returns a new TLS server side connection
...@@ -76,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) { ...@@ -76,24 +77,51 @@ func Listen(network, laddr string, config *Config) (net.Listener, error) {
return NewListener(l, config), nil return NewListener(l, config), nil
} }
// Dial connects to the given network address using net.Dial type timeoutError struct{}
// and then initiates a TLS handshake, returning the resulting
// TLS connection. func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
// Dial interprets a nil configuration as equivalent to func (timeoutError) Timeout() bool { return true }
// the zero configuration; see the documentation of Config func (timeoutError) Temporary() bool { return true }
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) { // DialWithDialer connects to the given network address using dialer.Dial and
raddr := addr // then initiates a TLS handshake, returning the resulting TLS connection. Any
c, err := net.Dial(network, raddr) // timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := dialer.Deadline.Sub(time.Now())
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
rawConn, err := dialer.Dial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
colonPos := strings.LastIndex(raddr, ":") colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 { if colonPos == -1 {
colonPos = len(raddr) colonPos = len(addr)
} }
hostname := raddr[:colonPos] hostname := addr[:colonPos]
if config == nil { if config == nil {
config = defaultConfig() config = defaultConfig()
...@@ -106,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) { ...@@ -106,14 +134,37 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
c.ServerName = hostname c.ServerName = hostname
config = &c config = &c
} }
conn := Client(c, config)
if err = conn.Handshake(); err != nil { conn := Client(rawConn, config)
c.Close()
if timeout == 0 {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
}()
err = <-errChannel
}
if err != nil {
rawConn.Close()
return nil, err return nil, err
} }
return conn, nil return conn, nil
} }
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair of // LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. // files. The files must contain PEM encoded data.
func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) { func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
......
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
package tls package tls
import ( import (
"net"
"strings"
"testing" "testing"
"time"
) )
var rsaCertPEM = `-----BEGIN CERTIFICATE----- var rsaCertPEM = `-----BEGIN CERTIFICATE-----
...@@ -105,3 +108,45 @@ func TestX509MixedKeyPair(t *testing.T) { ...@@ -105,3 +108,45 @@ func TestX509MixedKeyPair(t *testing.T) {
t.Error("Load of ECDSA certificate succeeded with RSA private key") t.Error("Load of ECDSA certificate succeeded with RSA private key")
} }
} }
func TestDialTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
listener, err = net.Listen("tcp6", "[::1]:0")
}
if err != nil {
t.Fatal(err)
}
addr := listener.Addr().String()
defer listener.Close()
complete := make(chan bool)
defer close(complete)
go func() {
conn, err := listener.Accept()
if err != nil {
t.Error(err)
return
}
<-complete
conn.Close()
}()
dialer := &net.Dialer{
Timeout: 10 * time.Millisecond,
}
if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil {
t.Fatal("DialWithTimeout completed successfully")
}
if !strings.Contains(err.Error(), "timed out") {
t.Errorf("resulting error not a timeout: %s", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment