Commit 28ee1796 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net: prevent cancelation goroutine from adjusting fd timeout after connect

This was previously fixed in https://golang.org/cl/21497 but not enough.

Fixes #16523

Change-Id: I678543a656304c82d654e25e12fb094cd6cc87e8
Reviewed-on: https://go-review.googlesource.com/25330
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: default avatarJoe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 2629446d
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package net
import (
"context"
"syscall"
"testing"
"time"
)
// Issue 16523
func TestDialContextCancelRace(t *testing.T) {
oldConnectFunc := connectFunc
oldGetsockoptIntFunc := getsockoptIntFunc
oldTestHookCanceledDial := testHookCanceledDial
defer func() {
connectFunc = oldConnectFunc
getsockoptIntFunc = oldGetsockoptIntFunc
testHookCanceledDial = oldTestHookCanceledDial
}()
ln, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
}
listenerDone := make(chan struct{})
go func() {
defer close(listenerDone)
c, err := ln.Accept()
if err == nil {
c.Close()
}
}()
defer func() { <-listenerDone }()
defer ln.Close()
sawCancel := make(chan bool, 1)
testHookCanceledDial = func() {
sawCancel <- true
}
ctx, cancelCtx := context.WithCancel(context.Background())
connectFunc = func(fd int, addr syscall.Sockaddr) error {
err := oldConnectFunc(fd, addr)
t.Logf("connect(%d, addr) = %v", fd, err)
if err == nil {
// On some operating systems, localhost
// connects _sometimes_ succeed immediately.
// Prevent that, so we exercise the code path
// we're interested in testing. This seems
// harmless. It makes FreeBSD 10.10 work when
// run with many iterations. It failed about
// half the time previously.
return syscall.EINPROGRESS
}
return err
}
getsockoptIntFunc = func(fd, level, opt int) (val int, err error) {
val, err = oldGetsockoptIntFunc(fd, level, opt)
t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err)
if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 {
t.Logf("canceling context")
// Cancel the context at just the moment which
// caused the race in issue 16523.
cancelCtx()
// And wait for the "interrupter" goroutine to
// cancel the dial by messing with its write
// timeout before returning.
select {
case <-sawCancel:
t.Logf("saw cancel")
case <-time.After(5 * time.Second):
t.Errorf("didn't see cancel after 5 seconds")
}
}
return
}
var d Dialer
c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
if err == nil {
c.Close()
t.Fatal("unexpected successful dial; want context canceled error")
}
select {
case <-ctx.Done():
case <-time.After(5 * time.Second):
t.Fatal("expected context to be canceled")
}
oe, ok := err.(*OpError)
if !ok || oe.Op != "dial" {
t.Fatalf("Dial error = %#v; want dial *OpError", err)
}
if oe.Err != ctx.Err() {
t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err())
}
}
...@@ -64,7 +64,7 @@ func (fd *netFD) name() string { ...@@ -64,7 +64,7 @@ func (fd *netFD) name() string {
return fd.net + ":" + ls + "->" + rs return fd.net + ":" + ls + "->" + rs
} }
func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) {
// Do not need to call fd.writeLock here, // Do not need to call fd.writeLock here,
// because fd is not yet accessible to user, // because fd is not yet accessible to user,
// so no concurrent operations are possible. // so no concurrent operations are possible.
...@@ -101,21 +101,44 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { ...@@ -101,21 +101,44 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error {
defer fd.setWriteDeadline(noDeadline) defer fd.setWriteDeadline(noDeadline)
} }
// Wait for the goroutine converting context.Done into a write timeout // Start the "interrupter" goroutine, if this context might be canceled.
// to exist, otherwise our caller might cancel the context and // (The background context cannot)
// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial. //
done := make(chan bool) // must be unbuffered // The interrupter goroutine waits for the context to be done and
defer func() { done <- true }() // interrupts the dial (by altering the fd's write deadline, which
go func() { // wakes up waitWrite).
select { if ctx != context.Background() {
case <-ctx.Done(): // Wait for the interrupter goroutine to exit before returning
// Force the runtime's poller to immediately give // from connect.
// up waiting for writability. done := make(chan struct{})
fd.setWriteDeadline(aLongTimeAgo) interruptRes := make(chan error)
<-done defer func() {
case <-done: close(done)
} if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
}() // The interrupter goroutine called setWriteDeadline,
// but the connect code below had returned from
// waitWrite already and did a successful connect (ret
// == nil). Because we've now poisoned the connection
// by making it unwritable, don't return a successful
// dial. This was issue 16523.
ret = ctxErr
fd.Close() // prevent a leak
}
}()
go func() {
select {
case <-ctx.Done():
// Force the runtime's poller to immediately give up
// waiting for writability, unblocking waitWrite
// below.
fd.setWriteDeadline(aLongTimeAgo)
testHookCanceledDial()
interruptRes <- ctx.Err()
case <-done:
interruptRes <- nil
}
}()
}
for { for {
// Performing multiple connect system calls on a // Performing multiple connect system calls on a
......
...@@ -9,7 +9,8 @@ package net ...@@ -9,7 +9,8 @@ package net
import "syscall" import "syscall"
var ( var (
testHookDialChannel = func() {} // see golang.org/issue/5349 testHookDialChannel = func() {} // for golang.org/issue/5349
testHookCanceledDial = func() {} // for golang.org/issue/16523
// Placeholders for socket system calls. // Placeholders for socket system calls.
socketFunc func(int, int, int) (int, error) = syscall.Socket socketFunc func(int, int, int) (int, error) = syscall.Socket
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment