Commit d5aeccf9 authored by Victor Stinner

asyncio, Tulip issue 205: Fix a race condition in BaseSelectorEventLoop.sock_connect()

There is a race condition in create_connection() when it is used with
wait_for() to apply a timeout. sock_connect() registers the file descriptor of
the socket to be notified of write events (if connect() raises
BlockingIOError). When create_connection() is cancelled with a TimeoutError,
the sock_connect() coroutine gets the exception, but it doesn't unregister the
file descriptor for write events. create_connection() gets the TimeoutError and
closes the socket.

If you call create_connection() again, the new socket will likely get the same
file descriptor, which is still registered in the selector. When sock_connect()
calls add_writer(), it tries to modify the existing entry instead of creating a
new one.

This issue was originally reported in the Trollius project, but the bug comes
from Tulip in fact (Trollius is based on Tulip):
https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for

This change fixes the race condition. It also makes sock_connect() more
reliable (and portable) if sock.connect() raises an InterruptedError.
parent 41f3c3f2
...@@ -8,6 +8,7 @@ __all__ = ['BaseSelectorEventLoop'] ...@@ -8,6 +8,7 @@ __all__ = ['BaseSelectorEventLoop']
import collections import collections
import errno import errno
import functools
import socket import socket
try: try:
import ssl import ssl
...@@ -345,26 +346,43 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -345,26 +346,43 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
except ValueError as err: except ValueError as err:
fut.set_exception(err) fut.set_exception(err)
else: else:
self._sock_connect(fut, False, sock, address) self._sock_connect(fut, sock, address)
return fut return fut
def _sock_connect(self, fut, registered, sock, address): def _sock_connect(self, fut, sock, address):
fd = sock.fileno() fd = sock.fileno()
if registered: try:
self.remove_writer(fd) while True:
try:
sock.connect(address)
except InterruptedError:
continue
else:
break
except BlockingIOError:
fut.add_done_callback(functools.partial(self._sock_connect_done,
sock))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(None)
def _sock_connect_done(self, sock, fut):
self.remove_writer(sock.fileno())
def _sock_connect_cb(self, fut, sock, address):
if fut.cancelled(): if fut.cancelled():
return return
try: try:
if not registered: err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
# First time around. if err != 0:
sock.connect(address) # Jump to any except clause below.
else: raise OSError(err, 'Connect call failed %s' % (address,))
err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
# Jump to the except clause below.
raise OSError(err, 'Connect call failed %s' % (address,))
except (BlockingIOError, InterruptedError): except (BlockingIOError, InterruptedError):
self.add_writer(fd, self._sock_connect, fut, True, sock, address) # socket is still registered, the callback will be retried later
pass
except Exception as exc: except Exception as exc:
fut.set_exception(exc) fut.set_exception(exc)
else: else:
......
...@@ -40,8 +40,9 @@ def list_to_buffer(l=()): ...@@ -40,8 +40,9 @@ def list_to_buffer(l=()):
class BaseSelectorEventLoopTests(test_utils.TestCase): class BaseSelectorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
selector = mock.Mock() self.selector = mock.Mock()
self.loop = TestBaseSelectorEventLoop(selector) self.selector.select.return_value = []
self.loop = TestBaseSelectorEventLoop(self.selector)
self.set_event_loop(self.loop, cleanup=False) self.set_event_loop(self.loop, cleanup=False)
def test_make_socket_transport(self): def test_make_socket_transport(self):
...@@ -303,63 +304,92 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): ...@@ -303,63 +304,92 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
self.assertIsInstance(f, asyncio.Future) self.assertIsInstance(f, asyncio.Future)
self.assertEqual( self.assertEqual(
(f, False, sock, ('127.0.0.1', 8080)), (f, sock, ('127.0.0.1', 8080)),
self.loop._sock_connect.call_args[0]) self.loop._sock_connect.call_args[0])
def test_sock_connect_timeout(self):
# Tulip issue #205: sock_connect() must unregister the socket on
# timeout error
# prepare mocks
self.loop.add_writer = mock.Mock()
self.loop.remove_writer = mock.Mock()
sock = test_utils.mock_nonblocking_socket()
sock.connect.side_effect = BlockingIOError
# first call to sock_connect() registers the socket
fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
self.assertTrue(sock.connect.called)
self.assertTrue(self.loop.add_writer.called)
self.assertEqual(len(fut._callbacks), 1)
# on timeout, the socket must be unregistered
sock.connect.reset_mock()
fut.set_exception(asyncio.TimeoutError)
with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(fut)
self.assertTrue(self.loop.remove_writer.called)
def test__sock_connect(self): def test__sock_connect(self):
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = mock.Mock() sock = mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertIsNone(f.result()) self.assertIsNone(f.result())
self.assertTrue(sock.connect.called) self.assertTrue(sock.connect.called)
def test__sock_connect_canceled_fut(self): def test__sock_connect_cb_cancelled_fut(self):
sock = mock.Mock() sock = mock.Mock()
self.loop.remove_writer = mock.Mock()
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
self.assertFalse(sock.connect.called) self.assertFalse(sock.getsockopt.called)
def test__sock_connect_writer(self):
# check that the fd is registered and then unregistered
self.loop._process_events = mock.Mock()
self.loop.add_writer = mock.Mock()
self.loop.remove_writer = mock.Mock()
def test__sock_connect_unregister(self):
sock = mock.Mock() sock = mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.connect.side_effect = BlockingIOError
sock.getsockopt.return_value = 0
address = ('127.0.0.1', 8080)
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() self.loop._sock_connect(f, sock, address)
self.assertTrue(self.loop.add_writer.called)
self.assertEqual(10, self.loop.add_writer.call_args[0][0])
self.loop.remove_writer = mock.Mock() self.loop._sock_connect_cb(f, sock, address)
self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) # need to run the event loop to execute _sock_connect_done() callback
self.loop.run_until_complete(f)
self.assertEqual((10,), self.loop.remove_writer.call_args[0]) self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_connect_tryagain(self): def test__sock_connect_cb_tryagain(self):
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = mock.Mock() sock = mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.EAGAIN sock.getsockopt.return_value = errno.EAGAIN
self.loop.add_writer = mock.Mock() # check that the exception is handled
self.loop.remove_writer = mock.Mock() self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
self.assertEqual(
(10, self.loop._sock_connect, f,
True, sock, ('127.0.0.1', 8080)),
self.loop.add_writer.call_args[0])
def test__sock_connect_exception(self): def test__sock_connect_cb_exception(self):
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = mock.Mock() sock = mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.ENOTCONN sock.getsockopt.return_value = errno.ENOTCONN
self.loop.remove_writer = mock.Mock() self.loop.remove_writer = mock.Mock()
self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
self.assertIsInstance(f.exception(), OSError) self.assertIsInstance(f.exception(), OSError)
def test_sock_accept(self): def test_sock_accept(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment