Commit c73701de authored by Victor Stinner's avatar Victor Stinner

asyncio: Refactor tests: add a base TestCase class

parent d6f02fc6
...@@ -11,6 +11,7 @@ import sys ...@@ -11,6 +11,7 @@ import sys
import tempfile import tempfile
import threading import threading
import time import time
import unittest
from unittest import mock from unittest import mock
from http.server import HTTPServer from http.server import HTTPServer
...@@ -379,3 +380,20 @@ def get_function_source(func): ...@@ -379,3 +380,20 @@ def get_function_source(func):
if source is None: if source is None:
raise ValueError("unable to get the source of %r" % (func,)) raise ValueError("unable to get the source of %r" % (func,))
return source return source
class TestCase(unittest.TestCase):
def set_event_loop(self, loop, *, cleanup=True):
assert loop is not None
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
if cleanup:
self.addCleanup(loop.close)
def new_test_loop(self, gen=None):
loop = TestLoop(gen)
self.set_event_loop(loop)
return loop
def tearDown(self):
events.set_event_loop(None)
...@@ -19,12 +19,12 @@ MOCK_ANY = mock.ANY ...@@ -19,12 +19,12 @@ MOCK_ANY = mock.ANY
PY34 = sys.version_info >= (3, 4) PY34 = sys.version_info >= (3, 4)
class BaseEventLoopTests(unittest.TestCase): class BaseEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = base_events.BaseEventLoop() self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock() self.loop._selector = mock.Mock()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def test_not_implemented(self): def test_not_implemented(self):
m = mock.Mock() m = mock.Mock()
...@@ -548,14 +548,11 @@ class MyDatagramProto(asyncio.DatagramProtocol): ...@@ -548,14 +548,11 @@ class MyDatagramProto(asyncio.DatagramProtocol):
self.done.set_result(None) self.done.set_result(None)
class BaseEventLoopWithSelectorTests(unittest.TestCase): class BaseEventLoopWithSelectorTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@mock.patch('asyncio.base_events.socket') @mock.patch('asyncio.base_events.socket')
def test_create_connection_multiple_errors(self, m_socket): def test_create_connection_multiple_errors(self, m_socket):
......
...@@ -224,7 +224,7 @@ class EventLoopTestsMixin: ...@@ -224,7 +224,7 @@ class EventLoopTestsMixin:
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.loop = self.create_event_loop() self.loop = self.create_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
...@@ -1629,14 +1629,14 @@ class SubprocessTestsMixin: ...@@ -1629,14 +1629,14 @@ class SubprocessTestsMixin:
if sys.platform == 'win32': if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop() return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin, class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.ProactorEventLoop() return asyncio.ProactorEventLoop()
...@@ -1691,7 +1691,7 @@ else: ...@@ -1691,7 +1691,7 @@ else:
if hasattr(selectors, 'KqueueSelector'): if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(UnixEventLoopTestsMixin, class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop( return asyncio.SelectorEventLoop(
...@@ -1716,7 +1716,7 @@ else: ...@@ -1716,7 +1716,7 @@ else:
if hasattr(selectors, 'EpollSelector'): if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(UnixEventLoopTestsMixin, class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector()) return asyncio.SelectorEventLoop(selectors.EpollSelector())
...@@ -1724,7 +1724,7 @@ else: ...@@ -1724,7 +1724,7 @@ else:
if hasattr(selectors, 'PollSelector'): if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(UnixEventLoopTestsMixin, class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector()) return asyncio.SelectorEventLoop(selectors.PollSelector())
...@@ -1732,7 +1732,7 @@ else: ...@@ -1732,7 +1732,7 @@ else:
# Should always exist. # Should always exist.
class SelectEventLoopTests(UnixEventLoopTestsMixin, class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector()) return asyncio.SelectorEventLoop(selectors.SelectSelector())
......
...@@ -13,14 +13,10 @@ def _fakefunc(f): ...@@ -13,14 +13,10 @@ def _fakefunc(f):
return f return f
class FutureTests(unittest.TestCase): class FutureTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_initial_state(self): def test_initial_state(self):
f = asyncio.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
...@@ -30,12 +26,9 @@ class FutureTests(unittest.TestCase): ...@@ -30,12 +26,9 @@ class FutureTests(unittest.TestCase):
self.assertTrue(f.cancelled()) self.assertTrue(f.cancelled())
def test_init_constructor_default_loop(self): def test_init_constructor_default_loop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) f = asyncio.Future()
f = asyncio.Future() self.assertIs(f._loop, self.loop)
self.assertIs(f._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_constructor_positional(self): def test_constructor_positional(self):
# Make sure Future doesn't accept a positional argument # Make sure Future doesn't accept a positional argument
...@@ -264,14 +257,10 @@ class FutureTests(unittest.TestCase): ...@@ -264,14 +257,10 @@ class FutureTests(unittest.TestCase):
self.assertTrue(f2.cancelled()) self.assertTrue(f2.cancelled())
class FutureDoneCallbackTests(unittest.TestCase): class FutureDoneCallbackTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def run_briefly(self): def run_briefly(self):
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
......
...@@ -17,14 +17,10 @@ STR_RGX_REPR = ( ...@@ -17,14 +17,10 @@ STR_RGX_REPR = (
RGX_REPR = re.compile(STR_RGX_REPR) RGX_REPR = re.compile(STR_RGX_REPR)
class LockTests(unittest.TestCase): class LockTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
...@@ -35,12 +31,9 @@ class LockTests(unittest.TestCase): ...@@ -35,12 +31,9 @@ class LockTests(unittest.TestCase):
self.assertIs(lock._loop, self.loop) self.assertIs(lock._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) lock = asyncio.Lock()
lock = asyncio.Lock() self.assertIs(lock._loop, self.loop)
self.assertIs(lock._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
lock = asyncio.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
...@@ -240,14 +233,10 @@ class LockTests(unittest.TestCase): ...@@ -240,14 +233,10 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
class EventTests(unittest.TestCase): class EventTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
...@@ -258,12 +247,9 @@ class EventTests(unittest.TestCase): ...@@ -258,12 +247,9 @@ class EventTests(unittest.TestCase):
self.assertIs(ev._loop, self.loop) self.assertIs(ev._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) ev = asyncio.Event()
ev = asyncio.Event() self.assertIs(ev._loop, self.loop)
self.assertIs(ev._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
ev = asyncio.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
...@@ -376,14 +362,10 @@ class EventTests(unittest.TestCase): ...@@ -376,14 +362,10 @@ class EventTests(unittest.TestCase):
self.assertTrue(t.result()) self.assertTrue(t.result())
class ConditionTests(unittest.TestCase): class ConditionTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
...@@ -394,12 +376,9 @@ class ConditionTests(unittest.TestCase): ...@@ -394,12 +376,9 @@ class ConditionTests(unittest.TestCase):
self.assertIs(cond._loop, self.loop) self.assertIs(cond._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) cond = asyncio.Condition()
cond = asyncio.Condition() self.assertIs(cond._loop, self.loop)
self.assertIs(cond._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_wait(self): def test_wait(self):
cond = asyncio.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
...@@ -678,14 +657,10 @@ class ConditionTests(unittest.TestCase): ...@@ -678,14 +657,10 @@ class ConditionTests(unittest.TestCase):
self.assertFalse(cond.locked()) self.assertFalse(cond.locked())
class SemaphoreTests(unittest.TestCase): class SemaphoreTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = mock.Mock() loop = mock.Mock()
...@@ -696,12 +671,9 @@ class SemaphoreTests(unittest.TestCase): ...@@ -696,12 +671,9 @@ class SemaphoreTests(unittest.TestCase):
self.assertIs(sem._loop, self.loop) self.assertIs(sem._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) sem = asyncio.Semaphore()
sem = asyncio.Semaphore() self.assertIs(sem._loop, self.loop)
self.assertIs(sem._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_initial_value_zero(self): def test_initial_value_zero(self):
sem = asyncio.Semaphore(0, loop=self.loop) sem = asyncio.Semaphore(0, loop=self.loop)
......
...@@ -12,10 +12,10 @@ from asyncio.proactor_events import _ProactorDuplexPipeTransport ...@@ -12,10 +12,10 @@ from asyncio.proactor_events import _ProactorDuplexPipeTransport
from asyncio import test_utils from asyncio import test_utils
class ProactorSocketTransportTests(unittest.TestCase): class ProactorSocketTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.proactor = mock.Mock() self.proactor = mock.Mock()
self.loop._proactor = self.proactor self.loop._proactor = self.proactor
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
...@@ -343,7 +343,7 @@ class ProactorSocketTransportTests(unittest.TestCase): ...@@ -343,7 +343,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
tr.close() tr.close()
class BaseProactorEventLoopTests(unittest.TestCase): class BaseProactorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
...@@ -356,6 +356,7 @@ class BaseProactorEventLoopTests(unittest.TestCase): ...@@ -356,6 +356,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
return (self.ssock, self.csock) return (self.ssock, self.csock)
self.loop = EventLoop(self.proactor) self.loop = EventLoop(self.proactor)
self.set_event_loop(self.loop, cleanup=False)
@mock.patch.object(BaseProactorEventLoop, 'call_soon') @mock.patch.object(BaseProactorEventLoop, 'call_soon')
@mock.patch.object(BaseProactorEventLoop, '_socketpair') @mock.patch.object(BaseProactorEventLoop, '_socketpair')
......
...@@ -7,14 +7,10 @@ import asyncio ...@@ -7,14 +7,10 @@ import asyncio
from asyncio import test_utils from asyncio import test_utils
class _QueueTestBase(unittest.TestCase): class _QueueTestBase(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
class QueueBasicTests(_QueueTestBase): class QueueBasicTests(_QueueTestBase):
...@@ -32,8 +28,7 @@ class QueueBasicTests(_QueueTestBase): ...@@ -32,8 +28,7 @@ class QueueBasicTests(_QueueTestBase):
self.assertAlmostEqual(0.2, when) self.assertAlmostEqual(0.2, when)
yield 0.1 yield 0.1
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
self.assertTrue(fn(q).startswith('<Queue'), fn(q)) self.assertTrue(fn(q).startswith('<Queue'), fn(q))
...@@ -80,12 +75,9 @@ class QueueBasicTests(_QueueTestBase): ...@@ -80,12 +75,9 @@ class QueueBasicTests(_QueueTestBase):
self.assertIs(q._loop, self.loop) self.assertIs(q._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: asyncio.set_event_loop(self.loop)
asyncio.set_event_loop(self.loop) q = asyncio.Queue()
q = asyncio.Queue() self.assertIs(q._loop, self.loop)
self.assertIs(q._loop, self.loop)
finally:
asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
self._test_repr_or_str(repr, True) self._test_repr_or_str(repr, True)
...@@ -126,8 +118,7 @@ class QueueBasicTests(_QueueTestBase): ...@@ -126,8 +118,7 @@ class QueueBasicTests(_QueueTestBase):
self.assertAlmostEqual(0.02, when) self.assertAlmostEqual(0.02, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(maxsize=2, loop=loop) q = asyncio.Queue(maxsize=2, loop=loop)
self.assertEqual(2, q.maxsize) self.assertEqual(2, q.maxsize)
...@@ -194,8 +185,7 @@ class QueueGetTests(_QueueTestBase): ...@@ -194,8 +185,7 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
started = asyncio.Event(loop=loop) started = asyncio.Event(loop=loop)
...@@ -241,8 +231,7 @@ class QueueGetTests(_QueueTestBase): ...@@ -241,8 +231,7 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.061, when) self.assertAlmostEqual(0.061, when)
yield 0.05 yield 0.05
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(loop=loop) q = asyncio.Queue(loop=loop)
...@@ -302,8 +291,7 @@ class QueuePutTests(_QueueTestBase): ...@@ -302,8 +291,7 @@ class QueuePutTests(_QueueTestBase):
self.assertAlmostEqual(0.01, when) self.assertAlmostEqual(0.01, when)
yield 0.01 yield 0.01
loop = test_utils.TestLoop(gen) loop = self.new_test_loop(gen)
self.addCleanup(loop.close)
q = asyncio.Queue(maxsize=1, loop=loop) q = asyncio.Queue(maxsize=1, loop=loop)
started = asyncio.Event(loop=loop) started = asyncio.Event(loop=loop)
......
...@@ -37,11 +37,12 @@ def list_to_buffer(l=()): ...@@ -37,11 +37,12 @@ def list_to_buffer(l=()):
return bytearray().join(l) return bytearray().join(l)
class BaseSelectorEventLoopTests(unittest.TestCase): class BaseSelectorEventLoopTests(test_utils.TestCase):
def setUp(self): def setUp(self):
selector = mock.Mock() selector = mock.Mock()
self.loop = TestBaseSelectorEventLoop(selector) self.loop = TestBaseSelectorEventLoop(selector)
self.set_event_loop(self.loop, cleanup=False)
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = mock.Mock() m = mock.Mock()
...@@ -597,10 +598,10 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -597,10 +598,10 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_writer.assert_called_with(1) self.loop.remove_writer.assert_called_with(1)
class SelectorTransportTests(unittest.TestCase): class SelectorTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
...@@ -684,14 +685,14 @@ class SelectorTransportTests(unittest.TestCase): ...@@ -684,14 +685,14 @@ class SelectorTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(2, sys.getrefcount(self.loop), self.assertEqual(3, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
class SelectorSocketTransportTests(unittest.TestCase): class SelectorSocketTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7 self.sock_fd = self.sock.fileno.return_value = 7
...@@ -1061,10 +1062,10 @@ class SelectorSocketTransportTests(unittest.TestCase): ...@@ -1061,10 +1062,10 @@ class SelectorSocketTransportTests(unittest.TestCase):
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
class SelectorSslTransportTests(unittest.TestCase): class SelectorSslTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = mock.Mock(socket.socket) self.sock = mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
...@@ -1396,10 +1397,10 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase): ...@@ -1396,10 +1397,10 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase):
_SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
class SelectorDatagramTransportTests(unittest.TestCase): class SelectorDatagramTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = mock.Mock(spec_set=socket.socket) self.sock = mock.Mock(spec_set=socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
......
...@@ -15,13 +15,13 @@ import asyncio ...@@ -15,13 +15,13 @@ import asyncio
from asyncio import test_utils from asyncio import test_utils
class StreamReaderTests(unittest.TestCase): class StreamReaderTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n' DATA = b'line1\nline2\nline3\n'
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
...@@ -29,6 +29,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -29,6 +29,7 @@ class StreamReaderTests(unittest.TestCase):
self.loop.close() self.loop.close()
gc.collect() gc.collect()
super().tearDown()
@mock.patch('asyncio.streams.events') @mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events): def test_ctor_global_loop(self, m_events):
......
from asyncio import subprocess from asyncio import subprocess
from asyncio import test_utils
import asyncio import asyncio
import signal import signal
import sys import sys
...@@ -151,21 +152,21 @@ if sys.platform != 'win32': ...@@ -151,21 +152,21 @@ if sys.platform != 'win32':
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
policy.set_child_watcher(None) policy.set_child_watcher(None)
self.loop.close() self.loop.close()
policy.set_event_loop(None) super().tearDown()
class SubprocessSafeWatcherTests(SubprocessWatcherMixin, class SubprocessSafeWatcherTests(SubprocessWatcherMixin,
unittest.TestCase): test_utils.TestCase):
Watcher = unix_events.SafeChildWatcher Watcher = unix_events.SafeChildWatcher
class SubprocessFastWatcherTests(SubprocessWatcherMixin, class SubprocessFastWatcherTests(SubprocessWatcherMixin,
unittest.TestCase): test_utils.TestCase):
Watcher = unix_events.FastChildWatcher Watcher = unix_events.FastChildWatcher
else: else:
# Windows # Windows
class SubprocessProactorTests(SubprocessMixin, unittest.TestCase): class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
def setUp(self): def setUp(self):
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
...@@ -178,6 +179,7 @@ else: ...@@ -178,6 +179,7 @@ else:
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
self.loop.close() self.loop.close()
policy.set_event_loop(None) policy.set_event_loop(None)
super().tearDown()
if __name__ == '__main__': if __name__ == '__main__':
......
This diff is collapsed.
...@@ -29,14 +29,11 @@ MOCK_ANY = mock.ANY ...@@ -29,14 +29,11 @@ MOCK_ANY = mock.ANY
@unittest.skipUnless(signal, 'Signals are not supported') @unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(unittest.TestCase): class SelectorEventLoopSignalTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
def test_check_signal(self): def test_check_signal(self):
self.assertRaises( self.assertRaises(
...@@ -208,14 +205,11 @@ class SelectorEventLoopSignalTests(unittest.TestCase): ...@@ -208,14 +205,11 @@ class SelectorEventLoopSignalTests(unittest.TestCase):
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), @unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported') 'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase): class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self): def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path: with test_utils.unix_socket_path() as path:
...@@ -304,10 +298,10 @@ class SelectorEventLoopUnixSocketTests(unittest.TestCase): ...@@ -304,10 +298,10 @@ class SelectorEventLoopUnixSocketTests(unittest.TestCase):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase): class UnixReadPipeTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
...@@ -451,7 +445,7 @@ class UnixReadPipeTransportTests(unittest.TestCase): ...@@ -451,7 +445,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self): def test__call_connection_lost_with_err(self):
...@@ -468,14 +462,14 @@ class UnixReadPipeTransportTests(unittest.TestCase): ...@@ -468,14 +462,14 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
class UnixWritePipeTransportTests(unittest.TestCase): class UnixWritePipeTransportTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe = mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
...@@ -737,7 +731,7 @@ class UnixWritePipeTransportTests(unittest.TestCase): ...@@ -737,7 +731,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test__call_connection_lost_with_err(self): def test__call_connection_lost_with_err(self):
...@@ -753,7 +747,7 @@ class UnixWritePipeTransportTests(unittest.TestCase): ...@@ -753,7 +747,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.assertEqual(2, sys.getrefcount(self.protocol), self.assertEqual(2, sys.getrefcount(self.protocol),
pprint.pformat(gc.get_referrers(self.protocol))) pprint.pformat(gc.get_referrers(self.protocol)))
self.assertIsNone(tr._loop) self.assertIsNone(tr._loop)
self.assertEqual(4, sys.getrefcount(self.loop), self.assertEqual(5, sys.getrefcount(self.loop),
pprint.pformat(gc.get_referrers(self.loop))) pprint.pformat(gc.get_referrers(self.loop)))
def test_close(self): def test_close(self):
...@@ -834,7 +828,7 @@ class ChildWatcherTestsMixin: ...@@ -834,7 +828,7 @@ class ChildWatcherTestsMixin:
ignore_warnings = mock.patch.object(log.logger, "warning") ignore_warnings = mock.patch.object(log.logger, "warning")
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
self.running = False self.running = False
self.zombies = {} self.zombies = {}
...@@ -1392,7 +1386,7 @@ class ChildWatcherTestsMixin: ...@@ -1392,7 +1386,7 @@ class ChildWatcherTestsMixin:
# attach a new loop # attach a new loop
old_loop = self.loop old_loop = self.loop
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
patch = mock.patch.object patch = mock.patch.object
with patch(old_loop, "remove_signal_handler") as m_old_remove, \ with patch(old_loop, "remove_signal_handler") as m_old_remove, \
...@@ -1447,7 +1441,7 @@ class ChildWatcherTestsMixin: ...@@ -1447,7 +1441,7 @@ class ChildWatcherTestsMixin:
self.assertFalse(callback3.called) self.assertFalse(callback3.called)
# attach a new loop # attach a new loop
self.loop = test_utils.TestLoop() self.loop = self.new_test_loop()
with mock.patch.object( with mock.patch.object(
self.loop, "add_signal_handler") as m_add_signal_handler: self.loop, "add_signal_handler") as m_add_signal_handler:
...@@ -1505,12 +1499,12 @@ class ChildWatcherTestsMixin: ...@@ -1505,12 +1499,12 @@ class ChildWatcherTestsMixin:
self.assertFalse(self.watcher._zombies) self.assertFalse(self.watcher._zombies)
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self): def create_watcher(self):
return asyncio.SafeChildWatcher() return asyncio.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
def create_watcher(self): def create_watcher(self):
return asyncio.FastChildWatcher() return asyncio.FastChildWatcher()
......
...@@ -26,15 +26,11 @@ class UpperProto(asyncio.Protocol): ...@@ -26,15 +26,11 @@ class UpperProto(asyncio.Protocol):
self.trans.close() self.trans.close()
class ProactorTests(unittest.TestCase): class ProactorTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.ProactorEventLoop() self.loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(None) self.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
self.loop = None
def test_close(self): def test_close(self):
a, b = self.loop._socketpair() a, b = self.loop._socketpair()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment