Commit 47bbea71 authored by Victor Stinner's avatar Victor Stinner

asyncio: sync with Tulip

* _SelectorTransport constructor: extra parameter is now optional
* Fix _SelectorDatagramTransport constructor. Only start reading after
  connection_made() has been called.
* Fix _SelectorSslTransport.close(). Don't call protocol.connection_lost() if
  protocol.connection_made() was not called yet: if the SSL handshake failed or
  is still in progress. The close() method can be called if the creation of the
  connection is cancelled, by a timeout for example.
parent 7b5a900e
...@@ -467,7 +467,7 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -467,7 +467,7 @@ class _SelectorTransport(transports._FlowControlMixin,
_buffer_factory = bytearray # Constructs initial value for self._buffer. _buffer_factory = bytearray # Constructs initial value for self._buffer.
def __init__(self, loop, sock, protocol, extra, server=None): def __init__(self, loop, sock, protocol, extra=None, server=None):
super().__init__(extra, loop) super().__init__(extra, loop)
self._extra['socket'] = sock self._extra['socket'] = sock
self._extra['sockname'] = sock.getsockname() self._extra['sockname'] = sock.getsockname()
...@@ -479,6 +479,7 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -479,6 +479,7 @@ class _SelectorTransport(transports._FlowControlMixin,
self._sock = sock self._sock = sock
self._sock_fd = sock.fileno() self._sock_fd = sock.fileno()
self._protocol = protocol self._protocol = protocol
self._protocol_connected = True
self._server = server self._server = server
self._buffer = self._buffer_factory() self._buffer = self._buffer_factory()
self._conn_lost = 0 # Set when call to connection_lost scheduled. self._conn_lost = 0 # Set when call to connection_lost scheduled.
...@@ -555,6 +556,7 @@ class _SelectorTransport(transports._FlowControlMixin, ...@@ -555,6 +556,7 @@ class _SelectorTransport(transports._FlowControlMixin,
def _call_connection_lost(self, exc): def _call_connection_lost(self, exc):
try: try:
if self._protocol_connected:
self._protocol.connection_lost(exc) self._protocol.connection_lost(exc)
finally: finally:
self._sock.close() self._sock.close()
...@@ -718,6 +720,8 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -718,6 +720,8 @@ class _SelectorSslTransport(_SelectorTransport):
sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
super().__init__(loop, sslsock, protocol, extra, server) super().__init__(loop, sslsock, protocol, extra, server)
# the protocol connection is only made after the SSL handshake
self._protocol_connected = False
self._server_hostname = server_hostname self._server_hostname = server_hostname
self._waiter = waiter self._waiter = waiter
...@@ -797,6 +801,7 @@ class _SelectorSslTransport(_SelectorTransport): ...@@ -797,6 +801,7 @@ class _SelectorSslTransport(_SelectorTransport):
self._read_wants_write = False self._read_wants_write = False
self._write_wants_read = False self._write_wants_read = False
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._protocol_connected = True
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(self._wakeup_waiter) self._loop.call_soon(self._wakeup_waiter)
...@@ -928,8 +933,10 @@ class _SelectorDatagramTransport(_SelectorTransport): ...@@ -928,8 +933,10 @@ class _SelectorDatagramTransport(_SelectorTransport):
waiter=None, extra=None): waiter=None, extra=None):
super().__init__(loop, sock, protocol, extra) super().__init__(loop, sock, protocol, extra)
self._address = address self._address = address
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._sock_fd, self._read_ready)
if waiter is not None: if waiter is not None:
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
......
...@@ -1427,7 +1427,7 @@ class SelectorSslTransportTests(test_utils.TestCase): ...@@ -1427,7 +1427,7 @@ class SelectorSslTransportTests(test_utils.TestCase):
self.assertFalse(tr.can_write_eof()) self.assertFalse(tr.can_write_eof())
self.assertRaises(NotImplementedError, tr.write_eof) self.assertRaises(NotImplementedError, tr.write_eof)
def test_close(self): def check_close(self):
tr = self._make_one() tr = self._make_one()
tr.close() tr.close()
...@@ -1439,6 +1439,19 @@ class SelectorSslTransportTests(test_utils.TestCase): ...@@ -1439,6 +1439,19 @@ class SelectorSslTransportTests(test_utils.TestCase):
self.assertEqual(tr._conn_lost, 1) self.assertEqual(tr._conn_lost, 1)
self.assertEqual(1, self.loop.remove_reader_count[1]) self.assertEqual(1, self.loop.remove_reader_count[1])
test_utils.run_briefly(self.loop)
def test_close(self):
self.check_close()
self.assertTrue(self.protocol.connection_made.called)
self.assertTrue(self.protocol.connection_lost.called)
def test_close_not_connected(self):
self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
self.check_close()
self.assertFalse(self.protocol.connection_made.called)
self.assertFalse(self.protocol.connection_lost.called)
@unittest.skipIf(ssl is None, 'No SSL support') @unittest.skipIf(ssl is None, 'No SSL support')
def test_server_hostname(self): def test_server_hostname(self):
self.ssl_transport(server_hostname='localhost') self.ssl_transport(server_hostname='localhost')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment