Commit 29d1f3b8 authored by Mikio Hara's avatar Mikio Hara

net: add socket system call hooks for testing

This change adds socket system call hooks to existing test cases for
simulating a bit complicated network conditions to help making timeout
and dual IP stack test cases work more properly in followup changes.

Also test cases print debugging information in non-short mode like the
following:

Leaked goroutines:
net.TestWriteTimeout.func2(0xc20802a5a0, 0xc20801d000, 0x1000, 0x1000, 0xc2081d2ae0)
	/go/src/net/timeout_test.go:170 +0x98
created by net.TestWriteTimeout
	/go/src/net/timeout_test.go:173 +0x745
net.runDatagramPacketConnServer(0xc2080730e0, 0x2bd270, 0x3, 0x2c1770, 0xb, 0xc2081d2ba0, 0xc2081d2c00)
	/go/src/net/server_test.go:398 +0x667
created by net.TestTimeoutUDP
	/go/src/net/timeout_test.go:247 +0xc9
	(snip)

Leaked sockets:
3: {Cookie:615726511685632 Err:<nil> SocketErr:0}
5: {Cookie:7934075906097152 Err:<nil> SocketErr:0}

Socket statistical information:
{Family:1 Type:805306370 Protocol:0 Opened:17 Accepted:0 Connected:5 Closed:17}
{Family:2 Type:805306369 Protocol:0 Opened:450 Accepted:234 Connected:279 Closed:636}
{Family:1 Type:805306369 Protocol:0 Opened:11 Accepted:5 Connected:5 Closed:16}
{Family:28 Type:805306369 Protocol:0 Opened:95 Accepted:22 Connected:16 Closed:116}
{Family:2 Type:805306370 Protocol:0 Opened:84 Accepted:0 Connected:34 Closed:83}
{Family:28 Type:805306370 Protocol:0 Opened:52 Accepted:0 Connected:4 Closed:52}

Change-Id: I0e84be59a0699bc31245c78e2249423459b8cdda
Reviewed-on: https://go-review.googlesource.com/6390Reviewed-by: default avatarIan Lance Taylor <iant@golang.org>
parent fa85a720
...@@ -219,18 +219,27 @@ func TestReloadResolvConfChange(t *testing.T) { ...@@ -219,18 +219,27 @@ func TestReloadResolvConfChange(t *testing.T) {
} }
func BenchmarkGoLookupIP(b *testing.B) { func BenchmarkGoLookupIP(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
goLookupIP("www.example.com") goLookupIP("www.example.com")
} }
} }
func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
goLookupIP("some.nonexistent") goLookupIP("some.nonexistent")
} }
} }
func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
onceLoadConfig.Do(loadDefaultConfig) onceLoadConfig.Do(loadDefaultConfig)
if cfg.dnserr != nil || cfg.dnsConfig == nil { if cfg.dnserr != nil || cfg.dnsConfig == nil {
b.Fatalf("loadConfig failed: %v", cfg.dnserr) b.Fatalf("loadConfig failed: %v", cfg.dnserr)
......
...@@ -69,6 +69,9 @@ func TestDNSNames(t *testing.T) { ...@@ -69,6 +69,9 @@ func TestDNSNames(t *testing.T) {
} }
func BenchmarkDNSNames(b *testing.B) { func BenchmarkDNSNames(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
benchmarks := append(tests, []testCase{ benchmarks := append(tests, []testCase{
{strings.Repeat("a", 63), true}, {strings.Repeat("a", 63), true},
{strings.Repeat("a", 64), false}, {strings.Repeat("a", 64), false},
......
...@@ -72,7 +72,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error { ...@@ -72,7 +72,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
// Do not need to call fd.writeLock here, // Do not need to call fd.writeLock here,
// because fd is not yet accessible to user, // because fd is not yet accessible to user,
// so no concurrent operations are possible. // so no concurrent operations are possible.
switch err := syscall.Connect(fd.sysfd, ra); err { switch err := connectFunc(fd.sysfd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN: case nil, syscall.EISCONN:
if !deadline.IsZero() && deadline.Before(time.Now()) { if !deadline.IsZero() && deadline.Before(time.Now()) {
...@@ -114,7 +114,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error { ...@@ -114,7 +114,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
if err := fd.pd.WaitWrite(); err != nil { if err := fd.pd.WaitWrite(); err != nil {
return err return err
} }
nerr, err := syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) nerr, err := getsockoptIntFunc(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil { if err != nil {
return err return err
} }
...@@ -130,9 +130,9 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error { ...@@ -130,9 +130,9 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
func (fd *netFD) destroy() { func (fd *netFD) destroy() {
// Poller may want to unregister fd in readiness notification mechanism, // Poller may want to unregister fd in readiness notification mechanism,
// so this must be executed before closesocket. // so this must be executed before closeFunc.
fd.pd.Close() fd.pd.Close()
closesocket(fd.sysfd) closeFunc(fd.sysfd)
fd.sysfd = -1 fd.sysfd = -1
runtime.SetFinalizer(fd, nil) runtime.SetFinalizer(fd, nil)
} }
...@@ -417,7 +417,7 @@ func (fd *netFD) accept() (netfd *netFD, err error) { ...@@ -417,7 +417,7 @@ func (fd *netFD) accept() (netfd *netFD, err error) {
} }
if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil { if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
closesocket(s) closeFunc(s)
return nil, err return nil, err
} }
if err = netfd.init(); err != nil { if err = netfd.init(); err != nil {
...@@ -492,7 +492,3 @@ func (fd *netFD) dup() (f *os.File, err error) { ...@@ -492,7 +492,3 @@ func (fd *netFD) dup() (f *os.File, err error) {
return os.NewFile(uintptr(ns), fd.name()), nil return os.NewFile(uintptr(ns), fd.name()), nil
} }
func closesocket(s int) error {
return syscall.Close(s)
}
...@@ -69,10 +69,6 @@ func sysInit() { ...@@ -69,10 +69,6 @@ func sysInit() {
} }
} }
func closesocket(s syscall.Handle) error {
return syscall.Closesocket(s)
}
func canUseConnectEx(net string) bool { func canUseConnectEx(net string) bool {
switch net { switch net {
case "udp", "udp4", "udp6", "ip", "ip4", "ip6": case "udp", "udp4", "udp6", "ip", "ip4", "ip6":
...@@ -336,7 +332,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error { ...@@ -336,7 +332,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
defer fd.setWriteDeadline(noDeadline) defer fd.setWriteDeadline(noDeadline)
} }
if !canUseConnectEx(fd.net) { if !canUseConnectEx(fd.net) {
return syscall.Connect(fd.sysfd, ra) return connectFunc(fd.sysfd, ra)
} }
// ConnectEx windows API requires an unconnected, previously bound socket. // ConnectEx windows API requires an unconnected, previously bound socket.
if la == nil { if la == nil {
...@@ -356,7 +352,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error { ...@@ -356,7 +352,7 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
o := &fd.wop o := &fd.wop
o.sa = ra o.sa = ra
_, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error { _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
return syscall.ConnectEx(o.fd.sysfd, o.sa, nil, 0, nil, &o.o) return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
}) })
if err != nil { if err != nil {
return err return err
...@@ -370,9 +366,9 @@ func (fd *netFD) destroy() { ...@@ -370,9 +366,9 @@ func (fd *netFD) destroy() {
return return
} }
// Poller may want to unregister fd in readiness notification mechanism, // Poller may want to unregister fd in readiness notification mechanism,
// so this must be executed before closesocket. // so this must be executed before closeFunc.
fd.pd.Close() fd.pd.Close()
closesocket(fd.sysfd) closeFunc(fd.sysfd)
fd.sysfd = syscall.InvalidHandle fd.sysfd = syscall.InvalidHandle
// no need for a finalizer anymore // no need for a finalizer anymore
runtime.SetFinalizer(fd, nil) runtime.SetFinalizer(fd, nil)
...@@ -540,7 +536,7 @@ func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD ...@@ -540,7 +536,7 @@ func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD
// Associate our new socket with IOCP. // Associate our new socket with IOCP.
netfd, err := newFD(s, fd.family, fd.sotype, fd.net) netfd, err := newFD(s, fd.family, fd.sotype, fd.net)
if err != nil { if err != nil {
closesocket(s) closeFunc(s)
return nil, &OpError{"accept", fd.net, fd.laddr, err} return nil, &OpError{"accept", fd.net, fd.laddr, err}
} }
if err := netfd.init(); err != nil { if err := netfd.init(); err != nil {
......
...@@ -18,13 +18,13 @@ func newFileFD(f *os.File) (*netFD, error) { ...@@ -18,13 +18,13 @@ func newFileFD(f *os.File) (*netFD, error) {
} }
if err = syscall.SetNonblock(fd, true); err != nil { if err = syscall.SetNonblock(fd, true); err != nil {
closesocket(fd) closeFunc(fd)
return nil, err return nil, err
} }
sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
if err != nil { if err != nil {
closesocket(fd) closeFunc(fd)
return nil, os.NewSyscallError("getsockopt", err) return nil, os.NewSyscallError("getsockopt", err)
} }
...@@ -33,7 +33,7 @@ func newFileFD(f *os.File) (*netFD, error) { ...@@ -33,7 +33,7 @@ func newFileFD(f *os.File) (*netFD, error) {
lsa, _ := syscall.Getsockname(fd) lsa, _ := syscall.Getsockname(fd)
switch lsa.(type) { switch lsa.(type) {
default: default:
closesocket(fd) closeFunc(fd)
return nil, syscall.EINVAL return nil, syscall.EINVAL
case *syscall.SockaddrInet4: case *syscall.SockaddrInet4:
family = syscall.AF_INET family = syscall.AF_INET
...@@ -64,7 +64,7 @@ func newFileFD(f *os.File) (*netFD, error) { ...@@ -64,7 +64,7 @@ func newFileFD(f *os.File) (*netFD, error) {
netfd, err := newFD(fd, family, sotype, laddr.Network()) netfd, err := newFD(fd, family, sotype, laddr.Network())
if err != nil { if err != nil {
closesocket(fd) closeFunc(fd)
return nil, err return nil, err
} }
if err := netfd.init(); err != nil { if err := netfd.init(); err != nil {
......
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build freebsd linux
package net
import "syscall"
var (
// Placeholders for socket system calls.
accept4Func func(int, int) (int, syscall.Sockaddr, error) = syscall.Accept4
)
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris
package net
import "syscall"
var (
// Placeholders for socket system calls.
socketFunc func(int, int, int) (int, error) = syscall.Socket
closeFunc func(int) error = syscall.Close
connectFunc func(int, syscall.Sockaddr) error = syscall.Connect
acceptFunc func(int) (int, syscall.Sockaddr, error) = syscall.Accept
getsockoptIntFunc func(int, int, int) (int, error) = syscall.GetsockoptInt
)
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import "syscall"
var (
// Placeholders for socket system calls.
socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
closeFunc func(syscall.Handle) error = syscall.Closesocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
connectExFunc func(syscall.Handle, syscall.Sockaddr, *byte, uint32, *uint32, *syscall.Overlapped) error = syscall.ConnectEx
)
...@@ -229,6 +229,9 @@ func testMulticastAddrs(t *testing.T, ifmat []Addr) (nmaf4, nmaf6 int) { ...@@ -229,6 +229,9 @@ func testMulticastAddrs(t *testing.T, ifmat []Addr) (nmaf4, nmaf6 int) {
} }
func BenchmarkInterfaces(b *testing.B) { func BenchmarkInterfaces(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := Interfaces(); err != nil { if _, err := Interfaces(); err != nil {
b.Fatal(err) b.Fatal(err)
...@@ -237,6 +240,9 @@ func BenchmarkInterfaces(b *testing.B) { ...@@ -237,6 +240,9 @@ func BenchmarkInterfaces(b *testing.B) {
} }
func BenchmarkInterfaceByIndex(b *testing.B) { func BenchmarkInterfaceByIndex(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
ifi := loopbackInterface() ifi := loopbackInterface()
if ifi == nil { if ifi == nil {
b.Skip("loopback interface not found") b.Skip("loopback interface not found")
...@@ -249,6 +255,9 @@ func BenchmarkInterfaceByIndex(b *testing.B) { ...@@ -249,6 +255,9 @@ func BenchmarkInterfaceByIndex(b *testing.B) {
} }
func BenchmarkInterfaceByName(b *testing.B) { func BenchmarkInterfaceByName(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
ifi := loopbackInterface() ifi := loopbackInterface()
if ifi == nil { if ifi == nil {
b.Skip("loopback interface not found") b.Skip("loopback interface not found")
...@@ -261,6 +270,9 @@ func BenchmarkInterfaceByName(b *testing.B) { ...@@ -261,6 +270,9 @@ func BenchmarkInterfaceByName(b *testing.B) {
} }
func BenchmarkInterfaceAddrs(b *testing.B) { func BenchmarkInterfaceAddrs(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := InterfaceAddrs(); err != nil { if _, err := InterfaceAddrs(); err != nil {
b.Fatal(err) b.Fatal(err)
...@@ -269,6 +281,9 @@ func BenchmarkInterfaceAddrs(b *testing.B) { ...@@ -269,6 +281,9 @@ func BenchmarkInterfaceAddrs(b *testing.B) {
} }
func BenchmarkInterfacesAndAddrs(b *testing.B) { func BenchmarkInterfacesAndAddrs(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
ifi := loopbackInterface() ifi := loopbackInterface()
if ifi == nil { if ifi == nil {
b.Skip("loopback interface not found") b.Skip("loopback interface not found")
...@@ -281,6 +296,9 @@ func BenchmarkInterfacesAndAddrs(b *testing.B) { ...@@ -281,6 +296,9 @@ func BenchmarkInterfacesAndAddrs(b *testing.B) {
} }
func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) { func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
ifi := loopbackInterface() ifi := loopbackInterface()
if ifi == nil { if ifi == nil {
b.Skip("loopback interface not found") b.Skip("loopback interface not found")
......
...@@ -53,6 +53,9 @@ func TestParseIP(t *testing.T) { ...@@ -53,6 +53,9 @@ func TestParseIP(t *testing.T) {
} }
func BenchmarkParseIP(b *testing.B) { func BenchmarkParseIP(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tt := range parseIPTests { for _, tt := range parseIPTests {
ParseIP(tt.in) ParseIP(tt.in)
...@@ -108,6 +111,9 @@ func TestIPString(t *testing.T) { ...@@ -108,6 +111,9 @@ func TestIPString(t *testing.T) {
} }
func BenchmarkIPString(b *testing.B) { func BenchmarkIPString(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tt := range ipStringTests { for _, tt := range ipStringTests {
if tt.in != nil { if tt.in != nil {
...@@ -158,6 +164,9 @@ func TestIPMaskString(t *testing.T) { ...@@ -158,6 +164,9 @@ func TestIPMaskString(t *testing.T) {
} }
func BenchmarkIPMaskString(b *testing.B) { func BenchmarkIPMaskString(b *testing.B) {
uninstallTestHooks()
defer installTestHooks()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tt := range ipMaskStringTests { for _, tt := range ipMaskStringTests {
tt.in.String() tt.in.String()
......
...@@ -14,12 +14,12 @@ import ( ...@@ -14,12 +14,12 @@ import (
) )
func probeIPv4Stack() bool { func probeIPv4Stack() bool {
s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) s, err := socketFunc(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
switch err { switch err {
case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT: case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT:
return false return false
case nil: case nil:
closesocket(s) closeFunc(s)
} }
return true return true
} }
...@@ -50,11 +50,11 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { ...@@ -50,11 +50,11 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
} }
for i := range probes { for i := range probes {
s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) s, err := socketFunc(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
if err != nil { if err != nil {
continue continue
} }
defer closesocket(s) defer closeFunc(s)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value) syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value)
sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6) sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6)
if err != nil { if err != nil {
......
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build freebsd linux
package net
func init() {
extraTestHookInstallers = append(extraTestHookInstallers, installAccept4TestHook)
extraTestHookUninstallers = append(extraTestHookUninstallers, uninstallAccept4TestHook)
}
var (
// Placeholders for saving original socket system calls.
origAccept4 = accept4Func
)
func installAccept4TestHook() {
accept4Func = sw.Accept4
}
func uninstallAccept4TestHook() {
accept4Func = origAccept4
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
func installTestHooks() {}
func uninstallTestHooks() {}
func forceCloseSockets() {}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"fmt"
"net/internal/socktest"
"os"
"runtime"
"sort"
"strings"
"testing"
)
var sw socktest.Switch
func TestMain(m *testing.M) {
installTestHooks()
st := m.Run()
if !testing.Short() {
printLeakedGoroutines()
printLeakedSockets()
printSocketStats()
}
forceCloseSockets()
uninstallTestHooks()
os.Exit(st)
}
func printLeakedGoroutines() {
gss := leakedGoroutines()
if len(gss) == 0 {
return
}
fmt.Fprintf(os.Stderr, "Leaked goroutines:\n")
for _, gs := range gss {
fmt.Fprintf(os.Stderr, "%v\n", gs)
}
fmt.Fprintf(os.Stderr, "\n")
}
// leakedGoroutines returns a list of remaining goroutins used in test
// cases.
func leakedGoroutines() []string {
var gss []string
b := make([]byte, 2<<20)
b = b[:runtime.Stack(b, true)]
for _, s := range strings.Split(string(b), "\n\n") {
ss := strings.SplitN(s, "\n", 2)
if len(ss) != 2 {
continue
}
stack := strings.TrimSpace(ss[1])
if !strings.Contains(stack, "created by net") {
continue
}
gss = append(gss, stack)
}
sort.Strings(gss)
return gss
}
func printLeakedSockets() {
sos := sw.Sockets()
if len(sos) == 0 {
return
}
fmt.Fprintf(os.Stderr, "Leaked sockets:\n")
for s, so := range sos {
fmt.Fprintf(os.Stderr, "%v: %+v\n", s, so)
}
fmt.Fprintf(os.Stderr, "\n")
}
func printSocketStats() {
sts := sw.Stats()
if len(sts) == 0 {
return
}
fmt.Fprintf(os.Stderr, "Socket statistical information:\n")
for _, st := range sts {
fmt.Fprintf(os.Stderr, "%+v\n", st)
}
fmt.Fprintf(os.Stderr, "\n")
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris
package net
var (
// Placeholders for saving original socket system calls.
origSocket = socketFunc
origClose = closeFunc
origConnect = connectFunc
origAccept = acceptFunc
origGetsockoptInt = getsockoptIntFunc
extraTestHookInstallers []func()
extraTestHookUninstallers []func()
)
func installTestHooks() {
socketFunc = sw.Socket
closeFunc = sw.Close
connectFunc = sw.Connect
acceptFunc = sw.Accept
getsockoptIntFunc = sw.GetsockoptInt
for _, fn := range extraTestHookInstallers {
fn()
}
}
func uninstallTestHooks() {
socketFunc = origSocket
closeFunc = origClose
connectFunc = origConnect
acceptFunc = origAccept
getsockoptIntFunc = origGetsockoptInt
for _, fn := range extraTestHookUninstallers {
fn()
}
}
func forceCloseSockets() {
for s := range sw.Sockets() {
closeFunc(s)
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
var (
// Placeholders for saving original socket system calls.
origSocket = socketFunc
origClosesocket = closeFunc
origConnect = connectFunc
origConnectEx = connectExFunc
)
func installTestHooks() {
socketFunc = sw.Socket
closeFunc = sw.Closesocket
connectFunc = sw.Connect
connectExFunc = sw.ConnectEx
}
func uninstallTestHooks() {
socketFunc = origSocket
closeFunc = origClosesocket
connectFunc = origConnect
connectExFunc = origConnectEx
}
func forceCloseSockets() {
for s := range sw.Sockets() {
closeFunc(s)
}
}
...@@ -14,7 +14,7 @@ import "syscall" ...@@ -14,7 +14,7 @@ import "syscall"
// Wrapper around the socket system call that marks the returned file // Wrapper around the socket system call that marks the returned file
// descriptor as nonblocking and close-on-exec. // descriptor as nonblocking and close-on-exec.
func sysSocket(family, sotype, proto int) (int, error) { func sysSocket(family, sotype, proto int) (int, error) {
s, err := syscall.Socket(family, sotype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, proto) s, err := socketFunc(family, sotype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, proto)
// On Linux the SOCK_NONBLOCK and SOCK_CLOEXEC flags were // On Linux the SOCK_NONBLOCK and SOCK_CLOEXEC flags were
// introduced in 2.6.27 kernel and on FreeBSD both flags were // introduced in 2.6.27 kernel and on FreeBSD both flags were
// introduced in 10 kernel. If we get an EINVAL error on Linux // introduced in 10 kernel. If we get an EINVAL error on Linux
...@@ -26,7 +26,7 @@ func sysSocket(family, sotype, proto int) (int, error) { ...@@ -26,7 +26,7 @@ func sysSocket(family, sotype, proto int) (int, error) {
// See ../syscall/exec_unix.go for description of ForkLock. // See ../syscall/exec_unix.go for description of ForkLock.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
s, err = syscall.Socket(family, sotype, proto) s, err = socketFunc(family, sotype, proto)
if err == nil { if err == nil {
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
} }
...@@ -35,7 +35,7 @@ func sysSocket(family, sotype, proto int) (int, error) { ...@@ -35,7 +35,7 @@ func sysSocket(family, sotype, proto int) (int, error) {
return -1, err return -1, err
} }
if err = syscall.SetNonblock(s, true); err != nil { if err = syscall.SetNonblock(s, true); err != nil {
syscall.Close(s) closeFunc(s)
return -1, err return -1, err
} }
return s, nil return s, nil
...@@ -44,7 +44,7 @@ func sysSocket(family, sotype, proto int) (int, error) { ...@@ -44,7 +44,7 @@ func sysSocket(family, sotype, proto int) (int, error) {
// Wrapper around the accept system call that marks the returned file // Wrapper around the accept system call that marks the returned file
// descriptor as nonblocking and close-on-exec. // descriptor as nonblocking and close-on-exec.
func accept(s int) (int, syscall.Sockaddr, error) { func accept(s int) (int, syscall.Sockaddr, error) {
ns, sa, err := syscall.Accept4(s, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) ns, sa, err := accept4Func(s, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
// On Linux the accept4 system call was introduced in 2.6.28 // On Linux the accept4 system call was introduced in 2.6.28
// kernel and on FreeBSD it was introduced in 10 kernel. If we // kernel and on FreeBSD it was introduced in 10 kernel. If we
// get an ENOSYS error on both Linux and FreeBSD, or EINVAL // get an ENOSYS error on both Linux and FreeBSD, or EINVAL
...@@ -63,7 +63,7 @@ func accept(s int) (int, syscall.Sockaddr, error) { ...@@ -63,7 +63,7 @@ func accept(s int) (int, syscall.Sockaddr, error) {
// because we have put fd.sysfd into non-blocking mode. // because we have put fd.sysfd into non-blocking mode.
// However, a call to the File method will put it back into // However, a call to the File method will put it back into
// blocking mode. We can't take that risk, so no use of ForkLock here. // blocking mode. We can't take that risk, so no use of ForkLock here.
ns, sa, err = syscall.Accept(s) ns, sa, err = acceptFunc(s)
if err == nil { if err == nil {
syscall.CloseOnExec(ns) syscall.CloseOnExec(ns)
} }
...@@ -71,7 +71,7 @@ func accept(s int) (int, syscall.Sockaddr, error) { ...@@ -71,7 +71,7 @@ func accept(s int) (int, syscall.Sockaddr, error) {
return -1, nil, err return -1, nil, err
} }
if err = syscall.SetNonblock(ns, true); err != nil { if err = syscall.SetNonblock(ns, true); err != nil {
syscall.Close(ns) closeFunc(ns)
return -1, nil, err return -1, nil, err
} }
return ns, sa, nil return ns, sa, nil
......
...@@ -42,11 +42,11 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s ...@@ -42,11 +42,11 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
return nil, err return nil, err
} }
if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil { if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil {
closesocket(s) closeFunc(s)
return nil, err return nil, err
} }
if fd, err = newFD(s, family, sotype, net); err != nil { if fd, err = newFD(s, family, sotype, net); err != nil {
closesocket(s) closeFunc(s)
return nil, err return nil, err
} }
......
...@@ -12,10 +12,10 @@ func maxListenerBacklog() int { ...@@ -12,10 +12,10 @@ func maxListenerBacklog() int {
return syscall.SOMAXCONN return syscall.SOMAXCONN
} }
func sysSocket(f, t, p int) (syscall.Handle, error) { func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
// See ../syscall/exec_unix.go for description of ForkLock. // See ../syscall/exec_unix.go for description of ForkLock.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
s, err := syscall.Socket(f, t, p) s, err := socketFunc(family, sotype, proto)
if err == nil { if err == nil {
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
} }
......
...@@ -16,7 +16,7 @@ import "syscall" ...@@ -16,7 +16,7 @@ import "syscall"
func sysSocket(family, sotype, proto int) (int, error) { func sysSocket(family, sotype, proto int) (int, error) {
// See ../syscall/exec_unix.go for description of ForkLock. // See ../syscall/exec_unix.go for description of ForkLock.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()
s, err := syscall.Socket(family, sotype, proto) s, err := socketFunc(family, sotype, proto)
if err == nil { if err == nil {
syscall.CloseOnExec(s) syscall.CloseOnExec(s)
} }
...@@ -25,7 +25,7 @@ func sysSocket(family, sotype, proto int) (int, error) { ...@@ -25,7 +25,7 @@ func sysSocket(family, sotype, proto int) (int, error) {
return -1, err return -1, err
} }
if err = syscall.SetNonblock(s, true); err != nil { if err = syscall.SetNonblock(s, true); err != nil {
syscall.Close(s) closeFunc(s)
return -1, err return -1, err
} }
return s, nil return s, nil
...@@ -39,7 +39,7 @@ func accept(s int) (int, syscall.Sockaddr, error) { ...@@ -39,7 +39,7 @@ func accept(s int) (int, syscall.Sockaddr, error) {
// because we have put fd.sysfd into non-blocking mode. // because we have put fd.sysfd into non-blocking mode.
// However, a call to the File method will put it back into // However, a call to the File method will put it back into
// blocking mode. We can't take that risk, so no use of ForkLock here. // blocking mode. We can't take that risk, so no use of ForkLock here.
ns, sa, err := syscall.Accept(s) ns, sa, err := acceptFunc(s)
if err == nil { if err == nil {
syscall.CloseOnExec(ns) syscall.CloseOnExec(ns)
} }
...@@ -47,7 +47,7 @@ func accept(s int) (int, syscall.Sockaddr, error) { ...@@ -47,7 +47,7 @@ func accept(s int) (int, syscall.Sockaddr, error) {
return -1, nil, err return -1, nil, err
} }
if err = syscall.SetNonblock(ns, true); err != nil { if err = syscall.SetNonblock(ns, true); err != nil {
syscall.Close(ns) closeFunc(ns)
return -1, nil, err return -1, nil, err
} }
return ns, sa, nil return ns, sa, nil
......
...@@ -59,6 +59,9 @@ func BenchmarkTCP6PersistentTimeout(b *testing.B) { ...@@ -59,6 +59,9 @@ func BenchmarkTCP6PersistentTimeout(b *testing.B) {
} }
func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) {
uninstallTestHooks()
defer installTestHooks()
const msgLen = 512 const msgLen = 512
conns := b.N conns := b.N
numConcurrent := runtime.GOMAXPROCS(-1) * 2 numConcurrent := runtime.GOMAXPROCS(-1) * 2
...@@ -172,6 +175,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) { ...@@ -172,6 +175,8 @@ func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) {
// The benchmark stresses concurrent reading and writing to the same connection. // The benchmark stresses concurrent reading and writing to the same connection.
// Such pattern is used in net/http and net/rpc. // Such pattern is used in net/http and net/rpc.
uninstallTestHooks()
defer installTestHooks()
b.StopTimer() b.StopTimer()
P := runtime.GOMAXPROCS(0) P := runtime.GOMAXPROCS(0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment