Commit fa73779b authored by Victor Stinner's avatar Victor Stinner

asyncio: Fix _SelectorSocketTransport constructor

Only start reading when connection_made() has been called:
protocol.data_received() must not be called before protocol.connection_made().
parent f07801bb
...@@ -578,8 +578,10 @@ class _SelectorSocketTransport(_SelectorTransport): ...@@ -578,8 +578,10 @@ class _SelectorSocketTransport(_SelectorTransport):
self._eof = False self._eof = False
self._paused = False self._paused = False
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._sock_fd, self._read_ready)
if waiter is not None: if waiter is not None:
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
......
...@@ -59,6 +59,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): ...@@ -59,6 +59,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = mock.Mock() m = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
transport = self.loop._make_socket_transport(m, asyncio.Protocol()) transport = self.loop._make_socket_transport(m, asyncio.Protocol())
self.assertIsInstance(transport, _SelectorSocketTransport) self.assertIsInstance(transport, _SelectorSocketTransport)
close_transport(transport) close_transport(transport)
...@@ -67,6 +68,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): ...@@ -67,6 +68,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
m = mock.Mock() m = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock() self.loop.add_writer = mock.Mock()
self.loop.remove_reader = mock.Mock() self.loop.remove_reader = mock.Mock()
self.loop.remove_writer = mock.Mock() self.loop.remove_writer = mock.Mock()
...@@ -770,20 +772,24 @@ class SelectorSocketTransportTests(test_utils.TestCase): ...@@ -770,20 +772,24 @@ class SelectorSocketTransportTests(test_utils.TestCase):
return transport return transport
def test_ctor(self): def test_ctor(self):
tr = self.socket_transport() waiter = asyncio.Future(loop=self.loop)
tr = self.socket_transport(waiter=waiter)
self.loop.run_until_complete(waiter)
self.loop.assert_reader(7, tr._read_ready) self.loop.assert_reader(7, tr._read_ready)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.protocol.connection_made.assert_called_with(tr) self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): def test_ctor_with_waiter(self):
fut = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
self.socket_transport(waiter=waiter)
self.loop.run_until_complete(waiter)
self.socket_transport(waiter=fut) self.assertIsNone(waiter.result())
test_utils.run_briefly(self.loop)
self.assertIsNone(fut.result())
def test_pause_resume_reading(self): def test_pause_resume_reading(self):
tr = self.socket_transport() tr = self.socket_transport()
test_utils.run_briefly(self.loop)
self.assertFalse(tr._paused) self.assertFalse(tr._paused)
self.loop.assert_reader(7, tr._read_ready) self.loop.assert_reader(7, tr._read_ready)
tr.pause_reading() tr.pause_reading()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment