Commit c73701de authored by Victor Stinner's avatar Victor Stinner

asyncio: Refactor tests: add a base TestCase class

parent d6f02fc6
......@@ -11,6 +11,7 @@ import sys
import tempfile
import threading
import time
import unittest
from unittest import mock
from http.server import HTTPServer
......@@ -379,3 +380,20 @@ def get_function_source(func):
if source is None:
raise ValueError("unable to get the source of %r" % (func,))
return source
class TestCase(unittest.TestCase):
def set_event_loop(self, loop, *, cleanup=True):
assert loop is not None
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
if cleanup:
self.addCleanup(loop.close)
def new_test_loop(self, gen=None):
loop = TestLoop(gen)
self.set_event_loop(loop)
return loop
def tearDown(self):
events.set_event_loop(None)
......@@ -19,12 +19,12 @@ MOCK_ANY = mock.ANY
PY34 = sys.version_info >= (3, 4)
class BaseEventLoopTests(unittest.TestCase):
class BaseEventLoopTests(test_utils.TestCase):
def setUp(self):
self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock()
asyncio.set_event_loop(None)
self.set_event_loop(self.loop)
def test_not_implemented(self):
m = mock.Mock()
......@@ -548,14 +548,11 @@ class MyDatagramProto(asyncio.DatagramProtocol):
self.done.set_result(None)
class BaseEventLoopWithSelectorTests(unittest.TestCase):
class BaseEventLoopWithSelectorTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)
@mock.patch('asyncio.base_events.socket')
def test_create_connection_multiple_errors(self, m_socket):
......
......@@ -224,7 +224,7 @@ class EventLoopTestsMixin:
def setUp(self):
super().setUp()
self.loop = self.create_event_loop()
asyncio.set_event_loop(None)
self.set_event_loop(self.loop)
def tearDown(self):
# just in case if we have transport close callbacks
......@@ -1629,14 +1629,14 @@ class SubprocessTestsMixin:
if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):
def create_event_loop(self):
return asyncio.ProactorEventLoop()
......@@ -1691,7 +1691,7 @@ else:
if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(
......@@ -1716,7 +1716,7 @@ else:
if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector())
......@@ -1724,7 +1724,7 @@ else:
if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector())
......@@ -1732,7 +1732,7 @@ else:
# Should always exist.
class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector())
......
......@@ -13,14 +13,10 @@ def _fakefunc(f):
return f
class FutureTests(unittest.TestCase):
class FutureTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def test_initial_state(self):
f = asyncio.Future(loop=self.loop)
......@@ -30,12 +26,9 @@ class FutureTests(unittest.TestCase):
self.assertTrue(f.cancelled())
def test_init_constructor_default_loop(self):
try:
asyncio.set_event_loop(self.loop)
f = asyncio.Future()
self.assertIs(f._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
f = asyncio.Future()
self.assertIs(f._loop, self.loop)
def test_constructor_positional(self):
# Make sure Future doesn't accept a positional argument
......@@ -264,14 +257,10 @@ class FutureTests(unittest.TestCase):
self.assertTrue(f2.cancelled())
class FutureDoneCallbackTests(unittest.TestCase):
class FutureDoneCallbackTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def run_briefly(self):
test_utils.run_briefly(self.loop)
......
......@@ -17,14 +17,10 @@ STR_RGX_REPR = (
RGX_REPR = re.compile(STR_RGX_REPR)
class LockTests(unittest.TestCase):
class LockTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def test_ctor_loop(self):
loop = mock.Mock()
......@@ -35,12 +31,9 @@ class LockTests(unittest.TestCase):
self.assertIs(lock._loop, self.loop)
def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop)
def test_repr(self):
lock = asyncio.Lock(loop=self.loop)
......@@ -240,14 +233,10 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked())
class EventTests(unittest.TestCase):
class EventTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def test_ctor_loop(self):
loop = mock.Mock()
......@@ -258,12 +247,9 @@ class EventTests(unittest.TestCase):
self.assertIs(ev._loop, self.loop)
def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
ev = asyncio.Event()
self.assertIs(ev._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
ev = asyncio.Event()
self.assertIs(ev._loop, self.loop)
def test_repr(self):
ev = asyncio.Event(loop=self.loop)
......@@ -376,14 +362,10 @@ class EventTests(unittest.TestCase):
self.assertTrue(t.result())
class ConditionTests(unittest.TestCase):
class ConditionTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def test_ctor_loop(self):
loop = mock.Mock()
......@@ -394,12 +376,9 @@ class ConditionTests(unittest.TestCase):
self.assertIs(cond._loop, self.loop)
def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop)
def test_wait(self):
cond = asyncio.Condition(loop=self.loop)
......@@ -678,14 +657,10 @@ class ConditionTests(unittest.TestCase):
self.assertFalse(cond.locked())
class SemaphoreTests(unittest.TestCase):
class SemaphoreTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
def test_ctor_loop(self):
loop = mock.Mock()
......@@ -696,12 +671,9 @@ class SemaphoreTests(unittest.TestCase):
self.assertIs(sem._loop, self.loop)
def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop)
def test_initial_value_zero(self):
sem = asyncio.Semaphore(0, loop=self.loop)
......
......@@ -12,10 +12,10 @@ from asyncio.proactor_events import _ProactorDuplexPipeTransport
from asyncio import test_utils
class ProactorSocketTransportTests(unittest.TestCase):
class ProactorSocketTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.proactor = mock.Mock()
self.loop._proactor = self.proactor
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
......@@ -343,7 +343,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
tr.close()
class BaseProactorEventLoopTests(unittest.TestCase):
class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self):
self.sock = mock.Mock(socket.socket)
......@@ -356,6 +356,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
return (self.ssock, self.csock)
self.loop = EventLoop(self.proactor)
self.set_event_loop(self.loop, cleanup=False)
@mock.patch.object(BaseProactorEventLoop, 'call_soon')
@mock.patch.object(BaseProactorEventLoop, '_socketpair')
......
......@@ -7,14 +7,10 @@ import asyncio
from asyncio import test_utils
class _QueueTestBase(unittest.TestCase):
class _QueueTestBase(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()
class QueueBasicTests(_QueueTestBase):
......@@ -32,8 +28,7 @@ class QueueBasicTests(_QueueTestBase):
self.assertAlmostEqual(0.2, when)
yield 0.1
loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close)
loop = self.new_test_loop(gen)
q = asyncio.Queue(loop=loop)
self.assertTrue(fn(q).startswith('<Queue'), fn(q))
......@@ -80,12 +75,9 @@ class QueueBasicTests(_QueueTestBase):
self.assertIs(q._loop, self.loop)
def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
q = asyncio.Queue()
self.assertIs(q._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
q = asyncio.Queue()
self.assertIs(q._loop, self.loop)
def test_repr(self):
self._test_repr_or_str(repr, True)
......@@ -126,8 +118,7 @@ class QueueBasicTests(_QueueTestBase):
self.assertAlmostEqual(0.02, when)
yield 0.01
loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close)
loop = self.new_test_loop(gen)
q = asyncio.Queue(maxsize=2, loop=loop)
self.assertEqual(2, q.maxsize)
......@@ -194,8 +185,7 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.01, when)
yield 0.01
loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close)
loop = self.new_test_loop(gen)
q = asyncio.Queue(loop=loop)
started = asyncio.Event(loop=loop)
......@@ -241,8 +231,7 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.061, when)
yield 0.05
loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close)
loop = self.new_test_loop(gen)
q = asyncio.Queue(loop=loop)
......@@ -302,8 +291,7 @@ class QueuePutTests(_QueueTestBase):
self.assertAlmostEqual(0.01, when)
yield 0.01
loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close)
loop = self.new_test_loop(gen)
q = asyncio.Queue(maxsize=1, loop=loop)
started = asyncio.Event(loop=loop)
......
......@@ -37,11 +37,12 @@ def list_to_buffer(l=()):
return bytearray().join(l)
class BaseSelectorEventLoopTests(unittest.TestCase):
class BaseSelectorEventLoopTests(test_utils.TestCase):
def setUp(self):
selector = mock.Mock()
self.loop = TestBaseSelectorEventLoop(selector)
self.set_event_loop(self.loop, cleanup=False)
def test_make_socket_transport(self):
m = mock.Mock()
......@@ -597,10 +598,10 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_writer.assert_called_with(1)
class SelectorTransportTests(unittest.TestCase):
class SelectorTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7
......@@ -684,14 +685,14 @@ class SelectorTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(2, sys.getrefcount(self.loop),
self.assertEqual(3, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
class SelectorSocketTransportTests(unittest.TestCase):
class SelectorSocketTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7
......@@ -1061,10 +1062,10 @@ class SelectorSocketTransportTests(unittest.TestCase):
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorSslTransportTests(unittest.TestCase):
class SelectorSslTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7
......@@ -1396,10 +1397,10 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase):
_SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
class SelectorDatagramTransportTests(unittest.TestCase):
class SelectorDatagramTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = mock.Mock(spec_set=socket.socket)
self.sock.fileno.return_value = 7
......
......@@ -15,13 +15,13 @@ import asyncio
from asyncio import test_utils
class StreamReaderTests(unittest.TestCase):
class StreamReaderTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n'
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.set_event_loop(self.loop)
def tearDown(self):
# just in case if we have transport close callbacks
......@@ -29,6 +29,7 @@ class StreamReaderTests(unittest.TestCase):
self.loop.close()
gc.collect()
super().tearDown()
@mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events):
......
from asyncio import subprocess
from asyncio import test_utils
import asyncio
import signal
import sys
......@@ -151,21 +152,21 @@ if sys.platform != 'win32':
policy = asyncio.get_event_loop_policy()
policy.set_child_watcher(None)
self.loop.close()
policy.set_event_loop(None)
super().tearDown()
class SubprocessSafeWatcherTests(SubprocessWatcherMixin,
unittest.TestCase):
test_utils.TestCase):
Watcher = unix_events.SafeChildWatcher
class SubprocessFastWatcherTests(SubprocessWatcherMixin,
unittest.TestCase):
test_utils.TestCase):
Watcher = unix_events.FastChildWatcher
else:
# Windows
class SubprocessProactorTests(SubprocessMixin, unittest.TestCase):
class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
def setUp(self):
policy = asyncio.get_event_loop_policy()
......@@ -178,6 +179,7 @@ else:
policy = asyncio.get_event_loop_policy()
self.loop.close()
policy.set_event_loop(None)
super().tearDown()
if __name__ == '__main__':
......
This diff is collapsed.
......@@ -29,14 +29,11 @@ MOCK_ANY = mock.ANY
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(unittest.TestCase):
class SelectorEventLoopSignalTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)
def test_check_signal(self):
self.assertRaises(
......@@ -208,14 +205,11 @@ class SelectorEventLoopSignalTests(unittest.TestCase):
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
......@@ -304,10 +298,10 @@ class SelectorEventLoopUnixSocketTests(unittest.TestCase):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase):
class UnixReadPipeTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5
......@@ -451,7 +445,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self):
......@@ -468,14 +462,14 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
class UnixWritePipeTransportTests(unittest.TestCase):
class UnixWritePipeTransportTests(test_utils.TestCase):
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5
......@@ -737,7 +731,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self):
......@@ -753,7 +747,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop),
self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop)))
def test_close(self):
......@@ -834,7 +828,7 @@ class ChildWatcherTestsMixin:
ignore_warnings = mock.patch.object(log.logger, "warning")
def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.running = False
self.zombies = {}
......@@ -1392,7 +1386,7 @@ class ChildWatcherTestsMixin:
# attach a new loop
old_loop = self.loop
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
patch = mock.patch.object
with patch(old_loop, "remove_signal_handler") as m_old_remove, \
......@@ -1447,7 +1441,7 @@ class ChildWatcherTestsMixin:
self.assertFalse(callback3.called)
# attach a new loop
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
with mock.patch.object(
self.loop, "add_signal_handler") as m_add_signal_handler:
......@@ -1505,12 +1499,12 @@ class ChildWatcherTestsMixin:
self.assertFalse(self.watcher._zombies)
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self):
return asyncio.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self):
return asyncio.FastChildWatcher()
......
......@@ -26,15 +26,11 @@ class UpperProto(asyncio.Protocol):
self.trans.close()
class ProactorTests(unittest.TestCase):
class ProactorTests(test_utils.TestCase):
def setUp(self):
self.loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
self.loop = None
self.set_event_loop(self.loop)
def test_close(self):
a, b = self.loop._socketpair()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment