Commit 24a83d35 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: add Dialer.Cancel to cancel pending dials

Dialer.Cancel is a new optional <-chan struct{} channel whose closure
indicates that the dial should be canceled. It is compatible with the
x/net/context and http.Request.Cancel types.

Tested by hand with:

package main

    import (
            "log"
            "net"
            "time"
    )

    func main() {
            log.Printf("start.")
            var d net.Dialer
            cancel := make(chan struct{})
            time.AfterFunc(2*time.Second, func() {
                    log.Printf("timeout firing")
                    close(cancel)
            })
            d.Cancel = cancel
            c, err := d.Dial("tcp", "192.168.0.1:22")
            if err != nil {
                    log.Print(err)
                    return
            }
            log.Fatalf("unexpected connect: %v", c)
    }

Which says:

    2015/12/14 22:24:58 start.
    2015/12/14 22:25:00 timeout firing
    2015/12/14 22:25:00 dial tcp 192.168.0.1:22: operation was canceled

Fixes #11225

Change-Id: I2ef39e3a540e29fe6bfec03ab7a629a6b187fcb3
Reviewed-on: https://go-review.googlesource.com/17821Reviewed-by: default avatarRuss Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 479c47e4
......@@ -57,6 +57,11 @@ type Dialer struct {
// If zero, keep-alives are not enabled. Network protocols
// that do not support keep-alives ignore this field.
KeepAlive time.Duration
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
Cancel <-chan struct{}
}
// Return either now+Timeout or Deadline, whichever comes first.
......@@ -361,7 +366,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
c, err = testHookDialTCP(ctx.network, la, ra, deadline)
c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel)
case *UDPAddr:
la, _ := la.(*UDPAddr)
c, err = dialUDP(ctx.network, la, ra, deadline)
......
......@@ -5,6 +5,7 @@
package net
import (
"internal/testenv"
"io"
"net/internal/socktest"
"runtime"
......@@ -236,8 +237,8 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline.
func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
c, err := dialTCP(net, laddr, raddr, deadline)
func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
c, err := dialTCP(net, laddr, raddr, deadline, cancel)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
time.Sleep(deadline.Sub(time.Now()))
}
......@@ -716,3 +717,64 @@ func TestDialerKeepAlive(t *testing.T) {
}
}
}
func TestDialCancel(t *testing.T) {
if runtime.GOOS == "plan9" || runtime.GOOS == "nacl" {
// plan9 is not implemented and nacl doesn't have
// external network access.
t.Skip("skipping on %s", runtime.GOOS)
}
onGoBuildFarm := testenv.Builder() != ""
if testing.Short() && !onGoBuildFarm {
t.Skip("skipping in short mode")
}
blackholeIPPort := JoinHostPort(slowDst4, "1234")
if !supportsIPv4 {
blackholeIPPort = JoinHostPort(slowDst6, "1234")
}
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
const cancelTick = 5 // the timer tick we cancel the dial at
const timeoutTick = 100
var d Dialer
cancel := make(chan struct{})
d.Cancel = cancel
errc := make(chan error, 1)
connc := make(chan Conn, 1)
go func() {
if c, err := d.Dial("tcp", blackholeIPPort); err != nil {
errc <- err
} else {
connc <- c
}
}()
ticks := 0
for {
select {
case <-ticker.C:
ticks++
if ticks == cancelTick {
close(cancel)
}
if ticks == timeoutTick {
t.Fatal("timeout waiting for dial to fail")
}
case c := <-connc:
c.Close()
t.Fatal("unexpected successful connection")
case err := <-errc:
if ticks < cancelTick {
t.Fatalf("dial error after %d ticks (%d before cancel sent): %v",
ticks, cancelTick-ticks, err)
}
if oe, ok := err.(*OpError); !ok || oe.Err != errCanceled {
t.Fatalf("dial error = %v (%T); want OpError with Err == errCanceled", err, err)
}
return // success.
}
}
}
......@@ -68,7 +68,7 @@ func (fd *netFD) name() string {
return fd.net + ":" + ls + "->" + rs
}
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
......@@ -102,6 +102,19 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline)
}
if cancel != nil {
done := make(chan bool)
defer close(done)
go func() {
select {
case <-cancel:
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.setWriteDeadline(aLongTimeAgo)
case <-done:
}
}()
}
for {
// Performing multiple connect system calls on a
// non-blocking socket under Unix variants does not
......@@ -112,6 +125,11 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
// succeeded or failed. See issue 7474 for further
// details.
if err := fd.pd.WaitWrite(); err != nil {
select {
case <-cancel:
return errCanceled
default:
}
return err
}
nerr, err := getsockoptIntFunc(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
......
......@@ -320,7 +320,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
runtime.SetFinalizer(fd, (*netFD).Close)
}
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
......@@ -351,14 +351,38 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
// Call ConnectEx API.
o := &fd.wop
o.sa = ra
if cancel != nil {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-cancel:
// TODO(bradfitz,brainman): cancel the dial operation
// somehow. Brad doesn't know Windows but is going to
// try this:
if canCancelIO {
syscall.CancelIoEx(o.fd.sysfd, &o.o)
} else {
wsrv.req <- ioSrvReq{o, nil}
<-o.errc
}
case <-done:
}
}()
}
_, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
})
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("connectex", err)
select {
case <-cancel:
return errCanceled
default:
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("connectex", err)
}
return err
}
return err
}
// Refresh socket properties.
return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))))
......
......@@ -220,7 +220,7 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
if raddr == nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial")
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel)
if err != nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -241,7 +241,7 @@ func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
default:
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(netProto)}
}
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen")
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
}
......
......@@ -156,9 +156,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
// Internet sockets (TCP, UDP, IP)
func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string) (fd *netFD, err error) {
func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) {
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline)
return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel)
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
......
......@@ -426,7 +426,16 @@ func (e *OpError) Error() string {
return s
}
var noDeadline = time.Time{}
var (
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancelation of dials.
aLongTimeAgo = time.Unix(233431200, 0)
// nonDeadline and noCancel are just zero values for
// readability with functions taking too many parameters.
noDeadline = time.Time{}
noCancel = (chan struct{})(nil)
)
type timeout interface {
Timeout() bool
......
......@@ -34,7 +34,7 @@ type sockaddr interface {
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time) (fd *netFD, err error) {
func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
......@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
return fd, nil
}
}
if err := fd.dial(laddr, raddr, deadline); err != nil {
if err := fd.dial(laddr, raddr, deadline, cancel); err != nil {
fd.Close()
return nil, err
}
......@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil }
}
func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error {
var err error
var lsa syscall.Sockaddr
if laddr != nil {
......@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
if rsa, err = raddr.sockaddr(fd.family); err != nil {
return err
}
if err := fd.connect(lsa, rsa, deadline); err != nil {
if err := fd.connect(lsa, rsa, deadline, cancel); err != nil {
return err
}
fd.isConnected = true
......
......@@ -107,13 +107,14 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is
// used as the local address for the connection.
func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
return dialTCP(net, laddr, raddr, noDeadline)
return dialTCP(net, laddr, raddr, noDeadline, noCancel)
}
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
if !deadline.IsZero() {
panic("net.dialTCP: deadline not implemented on Plan 9")
}
// TODO(bradfitz,0intro): also use the cancel channel.
switch net {
case "tcp", "tcp4", "tcp6":
default:
......
......@@ -164,11 +164,11 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
return dialTCP(net, laddr, raddr, noDeadline)
return dialTCP(net, laddr, raddr, noDeadline, noCancel)
}
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
......@@ -198,7 +198,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e
if err == nil {
fd.Close()
}
fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
}
if err != nil {
......@@ -326,7 +326,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil {
laddr = &TCPAddr{}
}
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen")
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
}
......
......@@ -189,7 +189,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
}
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial")
fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel)
if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
......@@ -212,7 +212,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
if laddr == nil {
laddr = &UDPAddr{}
}
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
}
......@@ -239,7 +239,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
}
fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr, Err: err}
}
......
......@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
return nil, errors.New("unknown mode: " + mode)
}
fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline)
fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel)
if err != nil {
return nil, err
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment