Commit d5c2a621 authored by Yury Selivanov's avatar Yury Selivanov

asyncio: Skip getaddrinfo if host is already resolved.

getaddrinfo takes an exclusive lock on some platforms, causing clients to queue
up waiting for the lock if many names are being resolved concurrently. Users
may want to handle name resolution in their own code, for the sake of caching,
using an alternate resolver, or to measure DNS duration separately from
connection duration. Skip getaddrinfo if the "host" passed into
create_connection is already resolved.

See https://github.com/python/asyncio/pull/302 for details.

Patch by A. Jesse Jiryu Davis.
parent 8c084eb7
...@@ -16,8 +16,10 @@ to modify the meaning of the API call itself. ...@@ -16,8 +16,10 @@ to modify the meaning of the API call itself.
import collections import collections
import concurrent.futures import concurrent.futures
import functools
import heapq import heapq
import inspect import inspect
import ipaddress
import itertools import itertools
import logging import logging
import os import os
...@@ -70,49 +72,83 @@ def _format_pipe(fd): ...@@ -70,49 +72,83 @@ def _format_pipe(fd):
return repr(fd) return repr(fd)
def _check_resolved_address(sock, address): # Linux's sock.type is a bitmask that can include extra info about socket.
# Ensure that the address is already resolved to avoid the trap of hanging _SOCKET_TYPE_MASK = 0
# the entire event loop when the address requires doing a DNS lookup. if hasattr(socket, 'SOCK_NONBLOCK'):
# _SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
# getaddrinfo() is slow (around 10 us per call): this function should only if hasattr(socket, 'SOCK_CLOEXEC'):
# be called in debug mode _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
family = sock.family
if family == socket.AF_INET: @functools.lru_cache(maxsize=1024)
host, port = address def _ipaddr_info(host, port, family, type, proto):
elif family == socket.AF_INET6: # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo
host, port = address[:2] # blocks on an exclusive lock on some platforms, users might handle name
# resolution in their own code and pass in resolved IPs.
if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None:
return None
type &= ~_SOCKET_TYPE_MASK
if type == socket.SOCK_STREAM:
proto = socket.IPPROTO_TCP
elif type == socket.SOCK_DGRAM:
proto = socket.IPPROTO_UDP
else: else:
return return None
# On Windows, socket.inet_pton() is only available since Python 3.4
if hasattr(socket, 'inet_pton'): if hasattr(socket, 'inet_pton'):
# getaddrinfo() is slow and has known issue: prefer inet_pton() if family == socket.AF_UNSPEC:
# if available afs = [socket.AF_INET, socket.AF_INET6]
else:
afs = [family]
for af in afs:
# Linux's inet_pton doesn't accept an IPv6 zone index after host,
# like '::1%lo0', so strip it. If we happen to make an invalid
# address look valid, we fail later in sock.connect or sock.bind.
try: try:
socket.inet_pton(family, host) if af == socket.AF_INET6:
except OSError as exc: socket.inet_pton(af, host.partition('%')[0])
raise ValueError("address must be resolved (IP address), "
"got host %r: %s"
% (host, exc))
else: else:
# Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is socket.inet_pton(af, host)
# already resolved. return af, type, proto, '', (host, port)
type_mask = 0 except OSError:
if hasattr(socket, 'SOCK_NONBLOCK'): pass
type_mask |= socket.SOCK_NONBLOCK
if hasattr(socket, 'SOCK_CLOEXEC'): # "host" is not an IP address.
type_mask |= socket.SOCK_CLOEXEC return None
# No inet_pton. (On Windows it's only available since Python 3.4.)
# Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it
# still requires a lock on some platforms, and waiting for that lock could
# block the event loop. Use ipaddress instead, it's just text parsing.
try:
addr = ipaddress.IPv4Address(host)
except ValueError:
try: try:
socket.getaddrinfo(host, port, addr = ipaddress.IPv6Address(host.partition('%')[0])
family=family, except ValueError:
type=(sock.type & ~type_mask), return None
proto=sock.proto,
flags=socket.AI_NUMERICHOST) af = socket.AF_INET if addr.version == 4 else socket.AF_INET6
except socket.gaierror as err: if family not in (socket.AF_UNSPEC, af):
raise ValueError("address must be resolved (IP address), " # "host" is wrong IP version for "family".
"got host %r: %s" return None
% (host, err))
return af, type, proto, '', (host, port)
def _check_resolved_address(sock, address):
# Ensure that the address is already resolved to avoid the trap of hanging
# the entire event loop when the address requires doing a DNS lookup.
if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
return
host, port = address[:2]
if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None:
raise ValueError("address must be resolved (IP address),"
" got host %r" % host)
def _run_until_complete_cb(fut): def _run_until_complete_cb(fut):
...@@ -535,7 +571,12 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -535,7 +571,12 @@ class BaseEventLoop(events.AbstractEventLoop):
def getaddrinfo(self, host, port, *, def getaddrinfo(self, host, port, *,
family=0, type=0, proto=0, flags=0): family=0, type=0, proto=0, flags=0):
if self._debug: info = _ipaddr_info(host, port, family, type, proto)
if info is not None:
fut = futures.Future(loop=self)
fut.set_result([info])
return fut
elif self._debug:
return self.run_in_executor(None, self._getaddrinfo_debug, return self.run_in_executor(None, self._getaddrinfo_debug,
host, port, family, type, proto, flags) host, port, family, type, proto, flags)
else: else:
......
...@@ -441,7 +441,6 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -441,7 +441,6 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def sock_connect(self, sock, address): def sock_connect(self, sock, address):
try: try:
if self._debug:
base_events._check_resolved_address(sock, address) base_events._check_resolved_address(sock, address)
except ValueError as err: except ValueError as err:
fut = futures.Future(loop=self) fut = futures.Future(loop=self)
......
...@@ -397,7 +397,6 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -397,7 +397,6 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self) fut = futures.Future(loop=self)
try: try:
if self._debug:
base_events._check_resolved_address(sock, address) base_events._check_resolved_address(sock, address)
except ValueError as err: except ValueError as err:
fut.set_exception(err) fut.set_exception(err)
......
...@@ -446,9 +446,14 @@ def disable_logger(): ...@@ -446,9 +446,14 @@ def disable_logger():
finally: finally:
logger.setLevel(old_level) logger.setLevel(old_level)
def mock_nonblocking_socket():
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
family=socket.AF_INET):
"""Create a mock of a non-blocking socket.""" """Create a mock of a non-blocking socket."""
sock = mock.Mock(socket.socket) sock = mock.MagicMock(socket.socket)
sock.proto = proto
sock.type = type
sock.family = family
sock.gettimeout.return_value = 0.0 sock.gettimeout.return_value = 0.0
return sock return sock
......
...@@ -32,6 +32,120 @@ MOCK_ANY = mock.ANY ...@@ -32,6 +32,120 @@ MOCK_ANY = mock.ANY
PY34 = sys.version_info >= (3, 4) PY34 = sys.version_info >= (3, 4)
def mock_socket_module():
m_socket = mock.MagicMock(spec=socket)
for name in (
'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
):
if hasattr(socket, name):
setattr(m_socket, name, getattr(socket, name))
else:
delattr(m_socket, name)
m_socket.socket = mock.MagicMock()
m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
return m_socket
def patch_socket(f):
return mock.patch('asyncio.base_events.socket',
new_callable=mock_socket_module)(f)
class BaseEventTests(test_utils.TestCase):
def setUp(self):
super().setUp()
base_events._ipaddr_info.cache_clear()
def tearDown(self):
base_events._ipaddr_info.cache_clear()
super().tearDown()
def test_ipaddr_info(self):
UNSPEC = socket.AF_UNSPEC
INET = socket.AF_INET
INET6 = socket.AF_INET6
STREAM = socket.SOCK_STREAM
DGRAM = socket.SOCK_DGRAM
TCP = socket.IPPROTO_TCP
UDP = socket.IPPROTO_UDP
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))
# Socket type STREAM implies TCP protocol.
self.assertEqual(
(INET, STREAM, TCP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))
# Socket type DGRAM implies UDP protocol.
self.assertEqual(
(INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))
# No socket type.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))
# IPv4 address with family IPv6.
self.assertIsNone(
base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1)),
base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3', 1)),
base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))
# IPv6 address with family IPv4.
self.assertIsNone(
base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
# IPv6 address with zone index.
self.assertEqual(
(INET6, STREAM, TCP, '', ('::3%lo0', 1)),
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
@patch_socket
def test_ipaddr_info_no_inet_pton(self, m_socket):
del m_socket.inet_pton
self.test_ipaddr_info()
def test_check_resolved_address(self):
sock = socket.socket(socket.AF_INET)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
sock = socket.socket(socket.AF_INET6)
base_events._check_resolved_address(sock, ('::3', 1))
base_events._check_resolved_address(sock, ('::3%lo0', 1))
self.assertRaises(ValueError,
base_events._check_resolved_address, sock, ('foo', 1))
def test_check_resolved_sock_type(self):
# Ensure we ignore extra flags in sock.type.
if hasattr(socket, 'SOCK_NONBLOCK'):
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
if hasattr(socket, 'SOCK_CLOEXEC'):
sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
base_events._check_resolved_address(sock, ('1.2.3.4', 1))
class BaseEventLoopTests(test_utils.TestCase): class BaseEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
...@@ -875,7 +989,12 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -875,7 +989,12 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop) self.set_event_loop(self.loop)
@mock.patch('asyncio.base_events.socket') def tearDown(self):
# Clear mocked constants like AF_INET from the cache.
base_events._ipaddr_info.cache_clear()
super().tearDown()
@patch_socket
def test_create_connection_multiple_errors(self, m_socket): def test_create_connection_multiple_errors(self, m_socket):
class MyProto(asyncio.Protocol): class MyProto(asyncio.Protocol):
...@@ -908,7 +1027,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -908,7 +1027,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_connection_timeout(self, m_socket): def test_create_connection_timeout(self, m_socket):
# Ensure that the socket is closed on timeout # Ensure that the socket is closed on timeout
sock = mock.Mock() sock = mock.Mock()
...@@ -986,7 +1105,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -986,7 +1105,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
with self.assertRaises(OSError): with self.assertRaises(OSError):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_connection_multiple_errors_local_addr(self, m_socket): def test_create_connection_multiple_errors_local_addr(self, m_socket):
def bind(addr): def bind(addr):
...@@ -1018,6 +1137,46 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1018,6 +1137,46 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
self.assertTrue(m_socket.socket.return_value.close.called) self.assertTrue(m_socket.socket.return_value.close.called)
def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton):
# Test the fallback code, even if this system has inet_pton.
if not allow_inet_pton:
del m_socket.inet_pton
def getaddrinfo(*args, **kw):
self.fail('should not have called getaddrinfo')
m_socket.getaddrinfo = getaddrinfo
sock = m_socket.socket.return_value
self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock()
self.loop.add_writer._is_coroutine = False
coro = self.loop.create_connection(MyProto, '1.2.3.4', 80)
self.loop.run_until_complete(coro)
sock.connect.assert_called_with(('1.2.3.4', 80))
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
proto=m_socket.IPPROTO_TCP,
type=m_socket.SOCK_STREAM)
sock.family = socket.AF_INET6
coro = self.loop.create_connection(MyProto, '::2', 80)
self.loop.run_until_complete(coro)
sock.connect.assert_called_with(('::2', 80))
m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
proto=m_socket.IPPROTO_TCP,
type=m_socket.SOCK_STREAM)
@patch_socket
def test_create_connection_ip_addr(self, m_socket):
self._test_create_connection_ip_addr(m_socket, True)
@patch_socket
def test_create_connection_no_inet_pton(self, m_socket):
self._test_create_connection_ip_addr(m_socket, False)
def test_create_connection_no_local_addr(self): def test_create_connection_no_local_addr(self):
@asyncio.coroutine @asyncio.coroutine
def getaddrinfo(host, *args, **kw): def getaddrinfo(host, *args, **kw):
...@@ -1153,11 +1312,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1153,11 +1312,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
f = self.loop.create_server(MyProto, '0.0.0.0', 0) f = self.loop.create_server(MyProto, '0.0.0.0', 0)
self.assertRaises(OSError, self.loop.run_until_complete, f) self.assertRaises(OSError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_server_nosoreuseport(self, m_socket): def test_create_server_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_STREAM = socket.SOCK_STREAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock() m_socket.socket.return_value = mock.Mock()
...@@ -1166,7 +1323,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1166,7 +1323,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(ValueError, self.loop.run_until_complete, f) self.assertRaises(ValueError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_server_cant_bind(self, m_socket): def test_create_server_cant_bind(self, m_socket):
class Err(OSError): class Err(OSError):
...@@ -1182,7 +1339,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1182,7 +1339,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertRaises(OSError, self.loop.run_until_complete, fut)
self.assertTrue(m_sock.close.called) self.assertTrue(m_sock.close.called)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_datagram_endpoint_no_addrinfo(self, m_socket): def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo.return_value = []
m_socket.getaddrinfo._is_coroutine = False m_socket.getaddrinfo._is_coroutine = False
...@@ -1211,7 +1368,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1211,7 +1368,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_datagram_endpoint_socket_err(self, m_socket): def test_create_datagram_endpoint_socket_err(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo m_socket.getaddrinfo = socket.getaddrinfo
m_socket.socket.side_effect = OSError m_socket.socket.side_effect = OSError
...@@ -1234,7 +1391,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1234,7 +1391,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, coro) ValueError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_datagram_endpoint_setblk_err(self, m_socket): def test_create_datagram_endpoint_setblk_err(self, m_socket):
m_socket.socket.return_value.setblocking.side_effect = OSError m_socket.socket.return_value.setblocking.side_effect = OSError
...@@ -1250,12 +1407,11 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1250,12 +1407,11 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
asyncio.DatagramProtocol) asyncio.DatagramProtocol)
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_datagram_endpoint_cant_bind(self, m_socket): def test_create_datagram_endpoint_cant_bind(self, m_socket):
class Err(OSError): class Err(OSError):
pass pass
m_socket.AF_INET6 = socket.AF_INET6
m_socket.getaddrinfo = socket.getaddrinfo m_socket.getaddrinfo = socket.getaddrinfo
m_sock = m_socket.socket.return_value = mock.Mock() m_sock = m_socket.socket.return_value = mock.Mock()
m_sock.bind.side_effect = Err m_sock.bind.side_effect = Err
...@@ -1369,11 +1525,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1369,11 +1525,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop.run_until_complete(protocol.done) self.loop.run_until_complete(protocol.done)
self.assertEqual('CLOSED', protocol.state) self.assertEqual('CLOSED', protocol.state)
@mock.patch('asyncio.base_events.socket') @patch_socket
def test_create_datagram_endpoint_nosoreuseport(self, m_socket): def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock() m_socket.socket.return_value = mock.Mock()
...@@ -1385,6 +1538,29 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1385,6 +1538,29 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@patch_socket
def test_create_datagram_endpoint_ip_addr(self, m_socket):
def getaddrinfo(*args, **kw):
self.fail('should not have called getaddrinfo')
m_socket.getaddrinfo = getaddrinfo
m_socket.socket.return_value.bind = bind = mock.Mock()
self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('1.2.3.4', 0),
reuse_address=False,
reuse_port=reuseport_supported)
self.loop.run_until_complete(coro)
bind.assert_called_with(('1.2.3.4', 0))
m_socket.socket.assert_called_with(family=m_socket.AF_INET,
proto=m_socket.IPPROTO_UDP,
type=m_socket.SOCK_DGRAM)
def test_accept_connection_retry(self): def test_accept_connection_retry(self):
sock = mock.Mock() sock = mock.Mock()
sock.accept.side_effect = BlockingIOError() sock.accept.side_effect = BlockingIOError()
......
...@@ -1573,10 +1573,6 @@ class EventLoopTestsMixin: ...@@ -1573,10 +1573,6 @@ class EventLoopTestsMixin:
'selector': self.loop._selector.__class__.__name__}) 'selector': self.loop._selector.__class__.__name__})
def test_sock_connect_address(self): def test_sock_connect_address(self):
# In debug mode, sock_connect() must ensure that the address is already
# resolved (call _check_resolved_address())
self.loop.set_debug(True)
addresses = [(socket.AF_INET, ('www.python.org', 80))] addresses = [(socket.AF_INET, ('www.python.org', 80))]
if support.IPV6_ENABLED: if support.IPV6_ENABLED:
addresses.extend(( addresses.extend((
......
...@@ -436,7 +436,7 @@ class ProactorSocketTransportTests(test_utils.TestCase): ...@@ -436,7 +436,7 @@ class ProactorSocketTransportTests(test_utils.TestCase):
class BaseProactorEventLoopTests(test_utils.TestCase): class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.sock = mock.Mock(socket.socket) self.sock = test_utils.mock_nonblocking_socket()
self.proactor = mock.Mock() self.proactor = mock.Mock()
self.ssock, self.csock = mock.Mock(), mock.Mock() self.ssock, self.csock = mock.Mock(), mock.Mock()
...@@ -491,8 +491,8 @@ class BaseProactorEventLoopTests(test_utils.TestCase): ...@@ -491,8 +491,8 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
self.proactor.send.assert_called_with(self.sock, b'data') self.proactor.send.assert_called_with(self.sock, b'data')
def test_sock_connect(self): def test_sock_connect(self):
self.loop.sock_connect(self.sock, 123) self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
self.proactor.connect.assert_called_with(self.sock, 123) self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
def test_sock_accept(self): def test_sock_accept(self):
self.loop.sock_accept(self.sock) self.loop.sock_accept(self.sock)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment