Commit 2934262f authored by Victor Stinner's avatar Victor Stinner

asyncio: sync with Tulip

* Cleanup gather(): use cancelled() method instead of using private Future
  attribute
* Fix _UnixReadPipeTransport and _UnixWritePipeTransport. Only start reading
  when connection_made() has been called.
* Issue #23333: Fix BaseSelectorEventLoop._accept_connection(). Close the
  transport on error. In debug mode, log errors using call_exception_handler()
parent 54a231d5
...@@ -22,6 +22,7 @@ from . import futures ...@@ -22,6 +22,7 @@ from . import futures
from . import selectors from . import selectors
from . import transports from . import transports
from . import sslproto from . import sslproto
from .coroutines import coroutine
from .log import logger from .log import logger
...@@ -181,16 +182,47 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -181,16 +182,47 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
else: else:
raise # The event loop will catch, log and ignore it. raise # The event loop will catch, log and ignore it.
else: else:
extra = {'peername': addr}
accept = self._accept_connection2(protocol_factory, conn, extra,
sslcontext, server)
self.create_task(accept)
@coroutine
def _accept_connection2(self, protocol_factory, conn, extra,
sslcontext=None, server=None):
protocol = None
transport = None
try:
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self)
if sslcontext: if sslcontext:
self._make_ssl_transport( transport = self._make_ssl_transport(
conn, protocol, sslcontext, conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra={'peername': addr}, server=server) server_side=True, extra=extra, server=server)
else: else:
self._make_socket_transport( transport = self._make_socket_transport(
conn, protocol , extra={'peername': addr}, conn, protocol, waiter=waiter, extra=extra,
server=server) server=server)
# It's now up to the protocol to handle the connection.
try:
yield from waiter
except:
transport.close()
raise
# It's now up to the protocol to handle the connection.
except Exception as exc:
if self.get_debug():
context = {
'message': ('Error on transport creation '
'for incoming connection'),
'exception': exc,
}
if protocol is not None:
context['protocol'] = protocol
if transport is not None:
context['transport'] = transport
self.call_exception_handler(context)
def add_reader(self, fd, callback, *args): def add_reader(self, fd, callback, *args):
"""Add a reader callback.""" """Add a reader callback."""
......
...@@ -592,7 +592,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): ...@@ -592,7 +592,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
fut.exception() fut.exception()
return return
if fut._state == futures._CANCELLED: if fut.cancelled():
res = futures.CancelledError() res = futures.CancelledError()
if not return_exceptions: if not return_exceptions:
outer.set_exception(res) outer.set_exception(res)
......
...@@ -298,8 +298,10 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -298,8 +298,10 @@ class _UnixReadPipeTransport(transports.ReadTransport):
_set_nonblocking(self._fileno) _set_nonblocking(self._fileno)
self._protocol = protocol self._protocol = protocol
self._closing = False self._closing = False
self._loop.add_reader(self._fileno, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._fileno, self._read_ready)
if waiter is not None: if waiter is not None:
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
...@@ -401,13 +403,16 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -401,13 +403,16 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._conn_lost = 0 self._conn_lost = 0
self._closing = False # Set when close() or write_eof() called. self._closing = False # Set when close() or write_eof() called.
# On AIX, the reader trick only works for sockets. self._loop.call_soon(self._protocol.connection_made, self)
# On other platforms it works for pipes and sockets.
# (Exception: OS X 10.4? Issue #19294.) # On AIX, the reader trick (to be notified when the read end of the
# socket is closed) only works for sockets. On other platforms it
# works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.)
if is_socket or not sys.platform.startswith("aix"): if is_socket or not sys.platform.startswith("aix"):
self._loop.add_reader(self._fileno, self._read_ready) # only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._fileno, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
......
...@@ -886,13 +886,18 @@ class EventLoopTestsMixin: ...@@ -886,13 +886,18 @@ class EventLoopTestsMixin:
if hasattr(sslcontext_client, 'check_hostname'): if hasattr(sslcontext_client, 'check_hostname'):
sslcontext_client.check_hostname = True sslcontext_client.check_hostname = True
# no CA loaded # no CA loaded
f_c = self.loop.create_connection(MyProto, host, port, f_c = self.loop.create_connection(MyProto, host, port,
ssl=sslcontext_client) ssl=sslcontext_client)
with test_utils.disable_logger(): with mock.patch.object(self.loop, 'call_exception_handler'):
with self.assertRaisesRegex(ssl.SSLError, with test_utils.disable_logger():
'certificate verify failed '): with self.assertRaisesRegex(ssl.SSLError,
self.loop.run_until_complete(f_c) 'certificate verify failed '):
self.loop.run_until_complete(f_c)
# execute the loop to log the connection error
test_utils.run_briefly(self.loop)
# close connection # close connection
self.assertIsNone(proto.transport) self.assertIsNone(proto.transport)
...@@ -919,15 +924,20 @@ class EventLoopTestsMixin: ...@@ -919,15 +924,20 @@ class EventLoopTestsMixin:
f_c = self.loop.create_unix_connection(MyProto, path, f_c = self.loop.create_unix_connection(MyProto, path,
ssl=sslcontext_client, ssl=sslcontext_client,
server_hostname='invalid') server_hostname='invalid')
with test_utils.disable_logger(): with mock.patch.object(self.loop, 'call_exception_handler'):
with self.assertRaisesRegex(ssl.SSLError, with test_utils.disable_logger():
'certificate verify failed '): with self.assertRaisesRegex(ssl.SSLError,
self.loop.run_until_complete(f_c) 'certificate verify failed '):
self.loop.run_until_complete(f_c)
# execute the loop to log the connection error
test_utils.run_briefly(self.loop)
# close connection # close connection
self.assertIsNone(proto.transport) self.assertIsNone(proto.transport)
server.close() server.close()
def test_legacy_create_unix_server_ssl_verify_failed(self): def test_legacy_create_unix_server_ssl_verify_failed(self):
with test_utils.force_legacy_ssl_support(): with test_utils.force_legacy_ssl_support():
self.test_create_unix_server_ssl_verify_failed() self.test_create_unix_server_ssl_verify_failed()
...@@ -949,11 +959,12 @@ class EventLoopTestsMixin: ...@@ -949,11 +959,12 @@ class EventLoopTestsMixin:
# incorrect server_hostname # incorrect server_hostname
f_c = self.loop.create_connection(MyProto, host, port, f_c = self.loop.create_connection(MyProto, host, port,
ssl=sslcontext_client) ssl=sslcontext_client)
with test_utils.disable_logger(): with mock.patch.object(self.loop, 'call_exception_handler'):
with self.assertRaisesRegex( with test_utils.disable_logger():
ssl.CertificateError, with self.assertRaisesRegex(
"hostname '127.0.0.1' doesn't match 'localhost'"): ssl.CertificateError,
self.loop.run_until_complete(f_c) "hostname '127.0.0.1' doesn't match 'localhost'"):
self.loop.run_until_complete(f_c)
# close connection # close connection
proto.transport.close() proto.transport.close()
......
...@@ -350,16 +350,13 @@ class UnixReadPipeTransportTests(test_utils.TestCase): ...@@ -350,16 +350,13 @@ class UnixReadPipeTransportTests(test_utils.TestCase):
return transport return transport
def test_ctor(self): def test_ctor(self):
tr = self.read_pipe_transport() waiter = asyncio.Future(loop=self.loop)
self.loop.assert_reader(5, tr._read_ready) tr = self.read_pipe_transport(waiter=waiter)
test_utils.run_briefly(self.loop) self.loop.run_until_complete(waiter)
self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): self.protocol.connection_made.assert_called_with(tr)
fut = asyncio.Future(loop=self.loop) self.loop.assert_reader(5, tr._read_ready)
tr = self.read_pipe_transport(waiter=fut) self.assertIsNone(waiter.result())
test_utils.run_briefly(self.loop)
self.assertIsNone(fut.result())
@mock.patch('os.read') @mock.patch('os.read')
def test__read_ready(self, m_read): def test__read_ready(self, m_read):
...@@ -502,17 +499,13 @@ class UnixWritePipeTransportTests(test_utils.TestCase): ...@@ -502,17 +499,13 @@ class UnixWritePipeTransportTests(test_utils.TestCase):
return transport return transport
def test_ctor(self): def test_ctor(self):
tr = self.write_pipe_transport() waiter = asyncio.Future(loop=self.loop)
self.loop.assert_reader(5, tr._read_ready) tr = self.write_pipe_transport(waiter=waiter)
test_utils.run_briefly(self.loop) self.loop.run_until_complete(waiter)
self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): self.protocol.connection_made.assert_called_with(tr)
fut = asyncio.Future(loop=self.loop)
tr = self.write_pipe_transport(waiter=fut)
self.loop.assert_reader(5, tr._read_ready) self.loop.assert_reader(5, tr._read_ready)
test_utils.run_briefly(self.loop) self.assertEqual(None, waiter.result())
self.assertEqual(None, fut.result())
def test_can_write_eof(self): def test_can_write_eof(self):
tr = self.write_pipe_transport() tr = self.write_pipe_transport()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment