Commit ab485051 authored by Yury Selivanov's avatar Yury Selivanov

Issue #28369: Raise an error when transport's FD is used with add_reader

parent 77e3f63e
This diff is collapsed.
...@@ -13,6 +13,8 @@ import tempfile ...@@ -13,6 +13,8 @@ import tempfile
import threading import threading
import time import time
import unittest import unittest
import weakref
from unittest import mock from unittest import mock
from http.server import HTTPServer from http.server import HTTPServer
...@@ -300,6 +302,8 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -300,6 +302,8 @@ class TestLoop(base_events.BaseEventLoop):
self.writers = {} self.writers = {}
self.reset_counters() self.reset_counters()
self._transports = weakref.WeakValueDictionary()
def time(self): def time(self):
return self._time return self._time
...@@ -318,10 +322,10 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -318,10 +322,10 @@ class TestLoop(base_events.BaseEventLoop):
else: # pragma: no cover else: # pragma: no cover
raise AssertionError("Time generator is not finished") raise AssertionError("Time generator is not finished")
def add_reader(self, fd, callback, *args): def _add_reader(self, fd, callback, *args):
self.readers[fd] = events.Handle(callback, args, self) self.readers[fd] = events.Handle(callback, args, self)
def remove_reader(self, fd): def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1 self.remove_reader_count[fd] += 1
if fd in self.readers: if fd in self.readers:
del self.readers[fd] del self.readers[fd]
...@@ -337,10 +341,10 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -337,10 +341,10 @@ class TestLoop(base_events.BaseEventLoop):
assert handle._args == args, '{!r} != {!r}'.format( assert handle._args == args, '{!r} != {!r}'.format(
handle._args, args) handle._args, args)
def add_writer(self, fd, callback, *args): def _add_writer(self, fd, callback, *args):
self.writers[fd] = events.Handle(callback, args, self) self.writers[fd] = events.Handle(callback, args, self)
def remove_writer(self, fd): def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1 self.remove_writer_count[fd] += 1
if fd in self.writers: if fd in self.writers:
del self.writers[fd] del self.writers[fd]
...@@ -356,6 +360,36 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -356,6 +360,36 @@ class TestLoop(base_events.BaseEventLoop):
assert handle._args == args, '{!r} != {!r}'.format( assert handle._args == args, '{!r} != {!r}'.format(
handle._args, args) handle._args, args)
def _ensure_fd_no_transport(self, fd):
try:
transport = self._transports[fd]
except KeyError:
pass
else:
raise RuntimeError(
'File descriptor {!r} is used by transport {!r}'.format(
fd, transport))
def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
self._ensure_fd_no_transport(fd)
return self._add_reader(fd, callback, *args)
def remove_reader(self, fd):
"""Remove a reader callback."""
self._ensure_fd_no_transport(fd)
return self._remove_reader(fd)
def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
self._ensure_fd_no_transport(fd)
return self._add_writer(fd, callback, *args)
def remove_writer(self, fd):
"""Remove a writer callback."""
self._ensure_fd_no_transport(fd)
return self._remove_writer(fd)
def reset_counters(self): def reset_counters(self):
self.remove_reader_count = collections.defaultdict(int) self.remove_reader_count = collections.defaultdict(int)
self.remove_writer_count = collections.defaultdict(int) self.remove_writer_count = collections.defaultdict(int)
......
...@@ -321,7 +321,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -321,7 +321,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called # only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader, self._loop.call_soon(self._loop._add_reader,
self._fileno, self._read_ready) self._fileno, self._read_ready)
if waiter is not None: if waiter is not None:
# only wake up the waiter when connection_made() has been called # only wake up the waiter when connection_made() has been called
...@@ -364,15 +364,15 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -364,15 +364,15 @@ class _UnixReadPipeTransport(transports.ReadTransport):
if self._loop.get_debug(): if self._loop.get_debug():
logger.info("%r was closed by peer", self) logger.info("%r was closed by peer", self)
self._closing = True self._closing = True
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
self._loop.call_soon(self._protocol.eof_received) self._loop.call_soon(self._protocol.eof_received)
self._loop.call_soon(self._call_connection_lost, None) self._loop.call_soon(self._call_connection_lost, None)
def pause_reading(self): def pause_reading(self):
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
def resume_reading(self): def resume_reading(self):
self._loop.add_reader(self._fileno, self._read_ready) self._loop._add_reader(self._fileno, self._read_ready)
def set_protocol(self, protocol): def set_protocol(self, protocol):
self._protocol = protocol self._protocol = protocol
...@@ -412,7 +412,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): ...@@ -412,7 +412,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
def _close(self, exc): def _close(self, exc):
self._closing = True self._closing = True
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
self._loop.call_soon(self._call_connection_lost, exc) self._loop.call_soon(self._call_connection_lost, exc)
def _call_connection_lost(self, exc): def _call_connection_lost(self, exc):
...@@ -457,7 +457,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -457,7 +457,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
# works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.) # works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.)
if is_socket or (is_fifo and not sys.platform.startswith("aix")): if is_socket or (is_fifo and not sys.platform.startswith("aix")):
# only start reading when connection_made() has been called # only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader, self._loop.call_soon(self._loop._add_reader,
self._fileno, self._read_ready) self._fileno, self._read_ready)
if waiter is not None: if waiter is not None:
...@@ -530,7 +530,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -530,7 +530,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
return return
elif n > 0: elif n > 0:
data = memoryview(data)[n:] data = memoryview(data)[n:]
self._loop.add_writer(self._fileno, self._write_ready) self._loop._add_writer(self._fileno, self._write_ready)
self._buffer += data self._buffer += data
self._maybe_pause_protocol() self._maybe_pause_protocol()
...@@ -547,15 +547,15 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -547,15 +547,15 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._conn_lost += 1 self._conn_lost += 1
# Remove writer here, _fatal_error() doesn't it # Remove writer here, _fatal_error() doesn't it
# because _buffer is empty. # because _buffer is empty.
self._loop.remove_writer(self._fileno) self._loop._remove_writer(self._fileno)
self._fatal_error(exc, 'Fatal write error on pipe transport') self._fatal_error(exc, 'Fatal write error on pipe transport')
else: else:
if n == len(self._buffer): if n == len(self._buffer):
self._buffer.clear() self._buffer.clear()
self._loop.remove_writer(self._fileno) self._loop._remove_writer(self._fileno)
self._maybe_resume_protocol() # May append to buffer. self._maybe_resume_protocol() # May append to buffer.
if self._closing: if self._closing:
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
self._call_connection_lost(None) self._call_connection_lost(None)
return return
elif n > 0: elif n > 0:
...@@ -570,7 +570,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -570,7 +570,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
assert self._pipe assert self._pipe
self._closing = True self._closing = True
if not self._buffer: if not self._buffer:
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
self._loop.call_soon(self._call_connection_lost, None) self._loop.call_soon(self._call_connection_lost, None)
def set_protocol(self, protocol): def set_protocol(self, protocol):
...@@ -616,9 +616,9 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, ...@@ -616,9 +616,9 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
def _close(self, exc=None): def _close(self, exc=None):
self._closing = True self._closing = True
if self._buffer: if self._buffer:
self._loop.remove_writer(self._fileno) self._loop._remove_writer(self._fileno)
self._buffer.clear() self._buffer.clear()
self._loop.remove_reader(self._fileno) self._loop._remove_reader(self._fileno)
self._loop.call_soon(self._call_connection_lost, exc) self._loop.call_soon(self._call_connection_lost, exc)
def _call_connection_lost(self, exc): def _call_connection_lost(self, exc):
......
...@@ -1148,10 +1148,10 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1148,10 +1148,10 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
m_socket.getaddrinfo = socket.getaddrinfo m_socket.getaddrinfo = socket.getaddrinfo
sock = m_socket.socket.return_value sock = m_socket.socket.return_value
self.loop.add_reader = mock.Mock() self.loop._add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False self.loop._add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock() self.loop._add_writer = mock.Mock()
self.loop.add_writer._is_coroutine = False self.loop._add_writer._is_coroutine = False
coro = self.loop.create_connection(asyncio.Protocol, '1.2.3.4', 80) coro = self.loop.create_connection(asyncio.Protocol, '1.2.3.4', 80)
t, p = self.loop.run_until_complete(coro) t, p = self.loop.run_until_complete(coro)
...@@ -1194,10 +1194,10 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1194,10 +1194,10 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
m_socket.getaddrinfo = socket.getaddrinfo m_socket.getaddrinfo = socket.getaddrinfo
sock = m_socket.socket.return_value sock = m_socket.socket.return_value
self.loop.add_reader = mock.Mock() self.loop._add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False self.loop._add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock() self.loop._add_writer = mock.Mock()
self.loop.add_writer._is_coroutine = False self.loop._add_writer._is_coroutine = False
for service, port in ('http', 80), (b'http', 80): for service, port in ('http', 80), (b'http', 80):
coro = self.loop.create_connection(asyncio.Protocol, coro = self.loop.create_connection(asyncio.Protocol,
...@@ -1614,8 +1614,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1614,8 +1614,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
m_socket.getaddrinfo = getaddrinfo m_socket.getaddrinfo = getaddrinfo
m_socket.socket.return_value.bind = bind = mock.Mock() m_socket.socket.return_value.bind = bind = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop._add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False self.loop._add_reader._is_coroutine = False
reuseport_supported = hasattr(socket, 'SO_REUSEPORT') reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
...@@ -1646,13 +1646,13 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1646,13 +1646,13 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
sock = mock.Mock() sock = mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files')
self.loop.remove_reader = mock.Mock() self.loop._remove_reader = mock.Mock()
self.loop.call_later = mock.Mock() self.loop.call_later = mock.Mock()
self.loop._accept_connection(MyProto, sock) self.loop._accept_connection(MyProto, sock)
self.assertTrue(m_log.error.called) self.assertTrue(m_log.error.called)
self.assertFalse(sock.close.called) self.assertFalse(sock.close.called)
self.loop.remove_reader.assert_called_with(10) self.loop._remove_reader.assert_called_with(10)
self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving # self.loop._start_serving
mock.ANY, mock.ANY,
......
...@@ -350,6 +350,9 @@ Library ...@@ -350,6 +350,9 @@ Library
no loop attached. no loop attached.
Patch by Vincent Michel. Patch by Vincent Michel.
- Issue #28369: Raise RuntimeError when transport's FD is used with
add_reader, add_writer, etc.
IDLE IDLE
---- ----
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment