Commit a5d1eb8d authored by Andrew Svetlov's avatar Andrew Svetlov Committed by GitHub

bpo-34638: Store a weak reference to stream reader to break strong references loop (GH-9201)

Store a weak reference to stream readerfor breaking strong references

It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)
parent aca819fb
...@@ -3,6 +3,8 @@ __all__ = ( ...@@ -3,6 +3,8 @@ __all__ = (
'open_connection', 'start_server') 'open_connection', 'start_server')
import socket import socket
import sys
import weakref
if hasattr(socket, 'AF_UNIX'): if hasattr(socket, 'AF_UNIX'):
__all__ += ('open_unix_connection', 'start_unix_server') __all__ += ('open_unix_connection', 'start_unix_server')
...@@ -10,6 +12,7 @@ if hasattr(socket, 'AF_UNIX'): ...@@ -10,6 +12,7 @@ if hasattr(socket, 'AF_UNIX'):
from . import coroutines from . import coroutines
from . import events from . import events
from . import exceptions from . import exceptions
from . import format_helpers
from . import protocols from . import protocols
from .log import logger from .log import logger
from .tasks import sleep from .tasks import sleep
...@@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): ...@@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
call inappropriate methods of the protocol.) call inappropriate methods of the protocol.)
""" """
_source_traceback = None
def __init__(self, stream_reader, client_connected_cb=None, loop=None): def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop) super().__init__(loop=loop)
self._stream_reader = stream_reader if stream_reader is not None:
self._stream_reader_wr = weakref.ref(stream_reader,
self._on_reader_gc)
self._source_traceback = stream_reader._source_traceback
else:
self._stream_reader_wr = None
if client_connected_cb is not None:
# This is a stream created by the `create_server()` function.
# Keep a strong reference to the reader until a connection
# is established.
self._strong_reader = stream_reader
self._reject_connection = False
self._stream_writer = None self._stream_writer = None
self._transport = None
self._client_connected_cb = client_connected_cb self._client_connected_cb = client_connected_cb
self._over_ssl = False self._over_ssl = False
self._closed = self._loop.create_future() self._closed = self._loop.create_future()
def _on_reader_gc(self, wr):
transport = self._transport
if transport is not None:
# connection_made was called
context = {
'message': ('An open stream object is being garbage '
'collected; call "stream.close()" explicitly.')
}
if self._source_traceback:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
transport.abort()
else:
self._reject_connection = True
self._stream_reader_wr = None
def _untrack_reader(self):
self._stream_reader_wr = None
@property
def _stream_reader(self):
if self._stream_reader_wr is None:
return None
return self._stream_reader_wr()
def connection_made(self, transport): def connection_made(self, transport):
self._stream_reader.set_transport(transport) if self._reject_connection:
context = {
'message': ('An open stream was garbage collected prior to '
'establishing network connection; '
'call "stream.close()" explicitly.')
}
if self._source_traceback:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
transport.abort()
return
self._transport = transport
reader = self._stream_reader
if reader is not None:
reader.set_transport(transport)
self._over_ssl = transport.get_extra_info('sslcontext') is not None self._over_ssl = transport.get_extra_info('sslcontext') is not None
if self._client_connected_cb is not None: if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self, self._stream_writer = StreamWriter(transport, self,
self._stream_reader, reader,
self._loop) self._loop)
res = self._client_connected_cb(self._stream_reader, res = self._client_connected_cb(reader,
self._stream_writer) self._stream_writer)
if coroutines.iscoroutine(res): if coroutines.iscoroutine(res):
self._loop.create_task(res) self._loop.create_task(res)
self._strong_reader = None
def connection_lost(self, exc): def connection_lost(self, exc):
if self._stream_reader is not None: reader = self._stream_reader
if reader is not None:
if exc is None: if exc is None:
self._stream_reader.feed_eof() reader.feed_eof()
else: else:
self._stream_reader.set_exception(exc) reader.set_exception(exc)
if not self._closed.done(): if not self._closed.done():
if exc is None: if exc is None:
self._closed.set_result(None) self._closed.set_result(None)
else: else:
self._closed.set_exception(exc) self._closed.set_exception(exc)
super().connection_lost(exc) super().connection_lost(exc)
self._stream_reader = None self._stream_reader_wr = None
self._stream_writer = None self._stream_writer = None
self._transport = None
def data_received(self, data): def data_received(self, data):
self._stream_reader.feed_data(data) reader = self._stream_reader
if reader is not None:
reader.feed_data(data)
def eof_received(self): def eof_received(self):
self._stream_reader.feed_eof() reader = self._stream_reader
if reader is not None:
reader.feed_eof()
if self._over_ssl: if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received: # Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received() # "returning true from eof_received()
...@@ -282,6 +345,9 @@ class StreamWriter: ...@@ -282,6 +345,9 @@ class StreamWriter:
return self._transport.can_write_eof() return self._transport.can_write_eof()
def close(self): def close(self):
# a reader can be garbage collected
# after connection closing
self._protocol._untrack_reader()
return self._transport.close() return self._transport.close()
def is_closing(self): def is_closing(self):
...@@ -318,6 +384,8 @@ class StreamWriter: ...@@ -318,6 +384,8 @@ class StreamWriter:
class StreamReader: class StreamReader:
_source_traceback = None
def __init__(self, limit=_DEFAULT_LIMIT, loop=None): def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature; # The line length limit is a security feature;
# it also doubles as half the buffer limit. # it also doubles as half the buffer limit.
...@@ -336,6 +404,9 @@ class StreamReader: ...@@ -336,6 +404,9 @@ class StreamReader:
self._exception = None self._exception = None
self._transport = None self._transport = None
self._paused = False self._paused = False
if self._loop.get_debug():
self._source_traceback = format_helpers.extract_stack(
sys._getframe(1))
def __repr__(self): def __repr__(self):
info = ['StreamReader'] info = ['StreamReader']
......
...@@ -36,6 +36,11 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -36,6 +36,11 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
info.append(f'stderr={self.stderr!r}') info.append(f'stderr={self.stderr!r}')
return '<{}>'.format(' '.join(info)) return '<{}>'.format(' '.join(info))
def _untrack_reader(self):
# StreamWriter.close() expects the protocol
# to have this method defined.
pass
def connection_made(self, transport): def connection_made(self, transport):
self._transport = transport self._transport = transport
......
...@@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase): ...@@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase):
self.assertIs(stream._loop, m_events.get_event_loop.return_value) self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def _basetest_open_connection(self, open_connection_fut): def _basetest_open_connection(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
reader, writer = self.loop.run_until_complete(open_connection_fut) reader, writer = self.loop.run_until_complete(open_connection_fut)
writer.write(b'GET / HTTP/1.0\r\n\r\n') writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline() f = reader.readline()
...@@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase): ...@@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase):
data = self.loop.run_until_complete(f) data = self.loop.run_until_complete(f)
self.assertTrue(data.endswith(b'\r\n\r\nTest message')) self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close() writer.close()
self.assertEqual(messages, [])
def test_open_connection(self): def test_open_connection(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
...@@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase): ...@@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase):
self._basetest_open_connection(conn_fut) self._basetest_open_connection(conn_fut)
def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
try: try:
reader, writer = self.loop.run_until_complete(open_connection_fut) reader, writer = self.loop.run_until_complete(open_connection_fut)
finally: finally:
...@@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase): ...@@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase):
self.assertTrue(data.endswith(b'\r\n\r\nTest message')) self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
writer.close() writer.close()
self.assertEqual(messages, [])
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self): def test_open_connection_no_loop_ssl(self):
...@@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase): ...@@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase):
self._basetest_open_connection_no_loop_ssl(conn_fut) self._basetest_open_connection_no_loop_ssl(conn_fut)
def _basetest_open_connection_error(self, open_connection_fut): def _basetest_open_connection_error(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
reader, writer = self.loop.run_until_complete(open_connection_fut) reader, writer = self.loop.run_until_complete(open_connection_fut)
writer._protocol.connection_lost(ZeroDivisionError()) writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read() f = reader.read()
...@@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase): ...@@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase):
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
writer.close() writer.close()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(messages, [])
def test_open_connection_error(self): def test_open_connection_error(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
...@@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase): ...@@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase):
writer.close() writer.close()
return msgback return msgback
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
# test the server variant with a coroutine as client handler # test the server variant with a coroutine as client handler
server = MyServer(self.loop) server = MyServer(self.loop)
addr = server.start() addr = server.start()
...@@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase): ...@@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase):
server.stop() server.stop()
self.assertEqual(msg, b"hello world!\n") self.assertEqual(msg, b"hello world!\n")
self.assertEqual(messages, [])
@support.skip_unless_bind_unix_socket @support.skip_unless_bind_unix_socket
def test_start_unix_server(self): def test_start_unix_server(self):
...@@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase): ...@@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase):
writer.close() writer.close()
return msgback return msgback
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
# test the server variant with a coroutine as client handler # test the server variant with a coroutine as client handler
with test_utils.unix_socket_path() as path: with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path) server = MyServer(self.loop, path)
...@@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase): ...@@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase):
server.stop() server.stop()
self.assertEqual(msg, b"hello world!\n") self.assertEqual(msg, b"hello world!\n")
self.assertEqual(messages, [])
@unittest.skipIf(sys.platform == 'win32', "Don't have pipes") @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
def test_read_all_from_pipe_reader(self): def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example # See asyncio issue 168. This test is derived from the example
...@@ -893,6 +912,58 @@ os.close(fd) ...@@ -893,6 +912,58 @@ os.close(fd)
wr.close() wr.close()
self.loop.run_until_complete(wr.wait_closed()) self.loop.run_until_complete(wr.wait_closed())
def test_del_stream_before_sock_closing(self):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
rd, wr = self.loop.run_until_complete(
asyncio.open_connection(*httpd.address, loop=self.loop))
sock = wr.get_extra_info('socket')
self.assertNotEqual(sock.fileno(), -1)
wr.write(b'GET / HTTP/1.0\r\n\r\n')
f = rd.readline()
data = self.loop.run_until_complete(f)
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
# drop refs to reader/writer
del rd
del wr
gc.collect()
# make a chance to close the socket
test_utils.run_briefly(self.loop)
self.assertEqual(1, len(messages))
self.assertEqual(sock.fileno(), -1)
self.assertEqual(1, len(messages))
self.assertEqual('An open stream object is being garbage '
'collected; call "stream.close()" explicitly.',
messages[0]['message'])
def test_del_stream_before_connection_made(self):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
rd = asyncio.StreamReader(loop=self.loop)
pr = asyncio.StreamReaderProtocol(rd, loop=self.loop)
del rd
gc.collect()
tr, _ = self.loop.run_until_complete(
self.loop.create_connection(
lambda: pr, *httpd.address))
sock = tr.get_extra_info('socket')
self.assertEqual(sock.fileno(), -1)
self.assertEqual(1, len(messages))
self.assertEqual('An open stream was garbage collected prior to '
'establishing network connection; '
'call "stream.close()" explicitly.',
messages[0]['message'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Store a weak reference to stream reader to break strong references loop
between reader and protocol. It allows to detect and close the socket if
the stream is deleted (garbage collected) without ``close()`` call.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment