Commit 8dffc456 authored by Victor Stinner's avatar Victor Stinner

Update asyncio from the Tulip project

Major changes:

- StreamReader.readexactly() now raises an IncompleteReadError if the
  end of stream is reached before we received enough bytes, instead of
  returning less bytes than requested.

- Unit tests use the main asyncio module instead of submodules like events

- _UnixWritePipeTransport now also supports character devices, as
  _UnixReadPipeTransport. Patch written by Jonathan Slenders.

- Export more symbols: BaseEventLoop, BaseProactorEventLoop,
  BaseSelectorEventLoop, Queue and Queue sublasses, Empty, Full
parent 75a5ec88
...@@ -18,13 +18,17 @@ if sys.platform == 'win32': ...@@ -18,13 +18,17 @@ if sys.platform == 'win32':
import _overlapped # Will also be exported. import _overlapped # Will also be exported.
# This relies on each of the submodules having an __all__ variable. # This relies on each of the submodules having an __all__ variable.
from .futures import * from .base_events import *
from .events import * from .events import *
from .futures import *
from .locks import * from .locks import *
from .transports import * from .proactor_events import *
from .protocols import * from .protocols import *
from .queues import *
from .selector_events import *
from .streams import * from .streams import *
from .tasks import * from .tasks import *
from .transports import *
if sys.platform == 'win32': # pragma: no cover if sys.platform == 'win32': # pragma: no cover
from .windows_events import * from .windows_events import *
...@@ -32,10 +36,14 @@ else: ...@@ -32,10 +36,14 @@ else:
from .unix_events import * # pragma: no cover from .unix_events import * # pragma: no cover
__all__ = (futures.__all__ + __all__ = (base_events.__all__ +
events.__all__ + events.__all__ +
futures.__all__ +
locks.__all__ + locks.__all__ +
transports.__all__ + proactor_events.__all__ +
protocols.__all__ + protocols.__all__ +
queues.__all__ +
selector_events.__all__ +
streams.__all__ + streams.__all__ +
tasks.__all__) tasks.__all__ +
transports.__all__)
...@@ -4,6 +4,8 @@ A proactor is a "notify-on-completion" multiplexer. Currently a ...@@ -4,6 +4,8 @@ A proactor is a "notify-on-completion" multiplexer. Currently a
proactor is only implemented on Windows with IOCP. proactor is only implemented on Windows with IOCP.
""" """
__all__ = ['BaseProactorEventLoop']
import socket import socket
from . import base_events from . import base_events
......
...@@ -4,6 +4,8 @@ A selector is a "notify-when-ready" multiplexer. For a subclass which ...@@ -4,6 +4,8 @@ A selector is a "notify-when-ready" multiplexer. For a subclass which
also includes support for signal handling, see the unix_events sub-module. also includes support for signal handling, see the unix_events sub-module.
""" """
__all__ = ['BaseSelectorEventLoop']
import collections import collections
import errno import errno
import socket import socket
......
"""Stream-related things.""" """Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'open_connection', 'start_server', 'IncompleteReadError',
] ]
import collections import collections
...@@ -14,6 +14,19 @@ from . import tasks ...@@ -14,6 +14,19 @@ from . import tasks
_DEFAULT_LIMIT = 2**16 _DEFAULT_LIMIT = 2**16
class IncompleteReadError(EOFError):
"""
Incomplete read error. Attributes:
- partial: read bytes string before the end of stream was reached
- expected: total number of expected bytes
"""
def __init__(self, partial, expected):
EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
% (len(partial), expected))
self.partial = partial
self.expected = expected
@tasks.coroutine @tasks.coroutine
def open_connection(host=None, port=None, *, def open_connection(host=None, port=None, *,
...@@ -403,12 +416,9 @@ class StreamReader: ...@@ -403,12 +416,9 @@ class StreamReader:
while n > 0: while n > 0:
block = yield from self.read(n) block = yield from self.read(n)
if not block: if not block:
break partial = b''.join(blocks)
raise IncompleteReadError(partial, len(partial) + n)
blocks.append(block) blocks.append(block)
n -= len(block) n -= len(block)
# TODO: Raise EOFError if we break before n == 0? (That would
# be a change in specification, but I've always had to add an
# explicit size check to the caller.)
return b''.join(blocks) return b''.join(blocks)
...@@ -259,9 +259,11 @@ class _UnixWritePipeTransport(transports.WriteTransport): ...@@ -259,9 +259,11 @@ class _UnixWritePipeTransport(transports.WriteTransport):
self._fileno = pipe.fileno() self._fileno = pipe.fileno()
mode = os.fstat(self._fileno).st_mode mode = os.fstat(self._fileno).st_mode
is_socket = stat.S_ISSOCK(mode) is_socket = stat.S_ISSOCK(mode)
is_pipe = stat.S_ISFIFO(mode) if not (is_socket or
if not (is_socket or is_pipe): stat.S_ISFIFO(mode) or
raise ValueError("Pipe transport is for pipes/sockets only.") stat.S_ISCHR(mode)):
raise ValueError("Pipe transport is only for "
"pipes, sockets and character devices")
_set_nonblocking(self._fileno) _set_nonblocking(self._fileno)
self._protocol = protocol self._protocol = protocol
self._buffer = [] self._buffer = []
......
...@@ -8,21 +8,17 @@ import unittest ...@@ -8,21 +8,17 @@ import unittest
import unittest.mock import unittest.mock
from test.support import find_unused_port, IPV6_ENABLED from test.support import find_unused_port, IPV6_ENABLED
from asyncio import base_events import asyncio
from asyncio import constants from asyncio import constants
from asyncio import events
from asyncio import futures
from asyncio import protocols
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
class BaseEventLoopTests(unittest.TestCase): class BaseEventLoopTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = base_events.BaseEventLoop() self.loop = asyncio.BaseEventLoop()
self.loop._selector = unittest.mock.Mock() self.loop._selector = unittest.mock.Mock()
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_not_implemented(self): def test_not_implemented(self):
m = unittest.mock.Mock() m = unittest.mock.Mock()
...@@ -51,20 +47,20 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -51,20 +47,20 @@ class BaseEventLoopTests(unittest.TestCase):
self.assertRaises(NotImplementedError, next, iter(gen)) self.assertRaises(NotImplementedError, next, iter(gen))
def test__add_callback_handle(self): def test__add_callback_handle(self):
h = events.Handle(lambda: False, ()) h = asyncio.Handle(lambda: False, ())
self.loop._add_callback(h) self.loop._add_callback(h)
self.assertFalse(self.loop._scheduled) self.assertFalse(self.loop._scheduled)
self.assertIn(h, self.loop._ready) self.assertIn(h, self.loop._ready)
def test__add_callback_timer(self): def test__add_callback_timer(self):
h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, ())
self.loop._add_callback(h) self.loop._add_callback(h)
self.assertIn(h, self.loop._scheduled) self.assertIn(h, self.loop._scheduled)
def test__add_callback_cancelled_handle(self): def test__add_callback_cancelled_handle(self):
h = events.Handle(lambda: False, ()) h = asyncio.Handle(lambda: False, ())
h.cancel() h.cancel()
self.loop._add_callback(h) self.loop._add_callback(h)
...@@ -90,7 +86,7 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -90,7 +86,7 @@ class BaseEventLoopTests(unittest.TestCase):
h = self.loop.call_soon(cb) h = self.loop.call_soon(cb)
self.assertEqual(h._callback, cb) self.assertEqual(h._callback, cb)
self.assertIsInstance(h, events.Handle) self.assertIsInstance(h, asyncio.Handle)
self.assertIn(h, self.loop._ready) self.assertIn(h, self.loop._ready)
def test_call_later(self): def test_call_later(self):
...@@ -98,7 +94,7 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -98,7 +94,7 @@ class BaseEventLoopTests(unittest.TestCase):
pass pass
h = self.loop.call_later(10.0, cb) h = self.loop.call_later(10.0, cb)
self.assertIsInstance(h, events.TimerHandle) self.assertIsInstance(h, asyncio.TimerHandle)
self.assertIn(h, self.loop._scheduled) self.assertIn(h, self.loop._scheduled)
self.assertNotIn(h, self.loop._ready) self.assertNotIn(h, self.loop._ready)
...@@ -132,27 +128,27 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -132,27 +128,27 @@ class BaseEventLoopTests(unittest.TestCase):
self.assertRaises( self.assertRaises(
AssertionError, self.loop.run_in_executor, AssertionError, self.loop.run_in_executor,
None, events.Handle(cb, ()), ('',)) None, asyncio.Handle(cb, ()), ('',))
self.assertRaises( self.assertRaises(
AssertionError, self.loop.run_in_executor, AssertionError, self.loop.run_in_executor,
None, events.TimerHandle(10, cb, ())) None, asyncio.TimerHandle(10, cb, ()))
def test_run_once_in_executor_cancelled(self): def test_run_once_in_executor_cancelled(self):
def cb(): def cb():
pass pass
h = events.Handle(cb, ()) h = asyncio.Handle(cb, ())
h.cancel() h.cancel()
f = self.loop.run_in_executor(None, h) f = self.loop.run_in_executor(None, h)
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertIsNone(f.result()) self.assertIsNone(f.result())
def test_run_once_in_executor_plain(self): def test_run_once_in_executor_plain(self):
def cb(): def cb():
pass pass
h = events.Handle(cb, ()) h = asyncio.Handle(cb, ())
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
executor = unittest.mock.Mock() executor = unittest.mock.Mock()
executor.submit.return_value = f executor.submit.return_value = f
...@@ -170,8 +166,8 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -170,8 +166,8 @@ class BaseEventLoopTests(unittest.TestCase):
f.cancel() # Don't complain about abandoned Future. f.cancel() # Don't complain about abandoned Future.
def test__run_once(self): def test__run_once(self):
h1 = events.TimerHandle(time.monotonic() + 5.0, lambda: True, ()) h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, ())
h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, ())
h1.cancel() h1.cancel()
...@@ -202,14 +198,14 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -202,14 +198,14 @@ class BaseEventLoopTests(unittest.TestCase):
m_logging.DEBUG = logging.DEBUG m_logging.DEBUG = logging.DEBUG
self.loop._scheduled.append( self.loop._scheduled.append(
events.TimerHandle(11.0, lambda: True, ())) asyncio.TimerHandle(11.0, lambda: True, ()))
self.loop._process_events = unittest.mock.Mock() self.loop._process_events = unittest.mock.Mock()
self.loop._run_once() self.loop._run_once()
self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) self.assertEqual(logging.INFO, m_logging.log.call_args[0][0])
idx = -1 idx = -1
data = [10.0, 10.0, 10.3, 13.0] data = [10.0, 10.0, 10.3, 13.0]
self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
self.loop._run_once() self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0])
...@@ -222,7 +218,7 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -222,7 +218,7 @@ class BaseEventLoopTests(unittest.TestCase):
processed = True processed = True
handle = loop.call_soon(lambda: True) handle = loop.call_soon(lambda: True)
h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,))
self.loop._process_events = unittest.mock.Mock() self.loop._process_events = unittest.mock.Mock()
self.loop._scheduled.append(h) self.loop._scheduled.append(h)
...@@ -236,14 +232,14 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -236,14 +232,14 @@ class BaseEventLoopTests(unittest.TestCase):
TypeError, self.loop.run_until_complete, 'blah') TypeError, self.loop.run_until_complete, 'blah')
class MyProto(protocols.Protocol): class MyProto(asyncio.Protocol):
done = None done = None
def __init__(self, create_future=False): def __init__(self, create_future=False):
self.state = 'INITIAL' self.state = 'INITIAL'
self.nbytes = 0 self.nbytes = 0
if create_future: if create_future:
self.done = futures.Future() self.done = asyncio.Future()
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -266,14 +262,14 @@ class MyProto(protocols.Protocol): ...@@ -266,14 +262,14 @@ class MyProto(protocols.Protocol):
self.done.set_result(None) self.done.set_result(None)
class MyDatagramProto(protocols.DatagramProtocol): class MyDatagramProto(asyncio.DatagramProtocol):
done = None done = None
def __init__(self, create_future=False): def __init__(self, create_future=False):
self.state = 'INITIAL' self.state = 'INITIAL'
self.nbytes = 0 self.nbytes = 0
if create_future: if create_future:
self.done = futures.Future() self.done = asyncio.Future()
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -297,8 +293,8 @@ class MyDatagramProto(protocols.DatagramProtocol): ...@@ -297,8 +293,8 @@ class MyDatagramProto(protocols.DatagramProtocol):
class BaseEventLoopWithSelectorTests(unittest.TestCase): class BaseEventLoopWithSelectorTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = events.new_event_loop() self.loop = asyncio.new_event_loop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
...@@ -306,17 +302,17 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -306,17 +302,17 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
@unittest.mock.patch('asyncio.base_events.socket') @unittest.mock.patch('asyncio.base_events.socket')
def test_create_connection_multiple_errors(self, m_socket): def test_create_connection_multiple_errors(self, m_socket):
class MyProto(protocols.Protocol): class MyProto(asyncio.Protocol):
pass pass
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
yield from [] yield from []
return [(2, 1, 6, '', ('107.6.106.82', 80)), return [(2, 1, 6, '', ('107.6.106.82', 80)),
(2, 1, 6, '', ('107.6.106.82', 80))] (2, 1, 6, '', ('107.6.106.82', 80))]
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
idx = -1 idx = -1
errors = ['err1', 'err2'] errors = ['err1', 'err2']
...@@ -346,12 +342,12 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -346,12 +342,12 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_create_connection_no_getaddrinfo(self): def test_create_connection_no_getaddrinfo(self):
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
yield from [] yield from []
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
coro = self.loop.create_connection(MyProto, 'example.com', 80) coro = self.loop.create_connection(MyProto, 'example.com', 80)
...@@ -359,13 +355,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -359,13 +355,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
def test_create_connection_connect_err(self): def test_create_connection_connect_err(self):
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
yield from [] yield from []
return [(2, 1, 6, '', ('107.6.106.82', 80))] return [(2, 1, 6, '', ('107.6.106.82', 80))]
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect = unittest.mock.Mock()
...@@ -376,13 +372,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -376,13 +372,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
def test_create_connection_multiple(self): def test_create_connection_multiple(self):
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
return [(2, 1, 6, '', ('0.0.0.1', 80)), return [(2, 1, 6, '', ('0.0.0.1', 80)),
(2, 1, 6, '', ('0.0.0.2', 80))] (2, 1, 6, '', ('0.0.0.2', 80))]
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect = unittest.mock.Mock()
...@@ -404,13 +400,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -404,13 +400,13 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.socket.return_value.bind = bind m_socket.socket.return_value.bind = bind
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
return [(2, 1, 6, '', ('0.0.0.1', 80)), return [(2, 1, 6, '', ('0.0.0.1', 80)),
(2, 1, 6, '', ('0.0.0.2', 80))] (2, 1, 6, '', ('0.0.0.2', 80))]
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect = unittest.mock.Mock()
...@@ -426,7 +422,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -426,7 +422,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.assertTrue(m_socket.socket.return_value.close.called) self.assertTrue(m_socket.socket.return_value.close.called)
def test_create_connection_no_local_addr(self): def test_create_connection_no_local_addr(self):
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(host, *args, **kw): def getaddrinfo(host, *args, **kw):
if host == 'example.com': if host == 'example.com':
return [(2, 1, 6, '', ('107.6.106.82', 80)), return [(2, 1, 6, '', ('107.6.106.82', 80)),
...@@ -435,7 +431,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -435,7 +431,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
return [] return []
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
coro = self.loop.create_connection( coro = self.loop.create_connection(
...@@ -448,7 +444,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -448,7 +444,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.loop.getaddrinfo = unittest.mock.Mock() self.loop.getaddrinfo = unittest.mock.Mock()
def mock_getaddrinfo(*args, **kwds): def mock_getaddrinfo(*args, **kwds):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.set_result([(socket.AF_INET, socket.SOCK_STREAM, f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
socket.SOL_TCP, '', ('1.2.3.4', 80))]) socket.SOL_TCP, '', ('1.2.3.4', 80))])
return f return f
...@@ -527,14 +523,14 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -527,14 +523,14 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
# if host is empty string use None instead # if host is empty string use None instead
host = object() host = object()
@tasks.coroutine @asyncio.coroutine
def getaddrinfo(*args, **kw): def getaddrinfo(*args, **kw):
nonlocal host nonlocal host
host = args[0] host = args[0]
yield from [] yield from []
def getaddrinfo_task(*args, **kwds): def getaddrinfo_task(*args, **kwds):
return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop)
self.loop.getaddrinfo = getaddrinfo_task self.loop.getaddrinfo = getaddrinfo_task
fut = self.loop.create_server(MyProto, '', 0) fut = self.loop.create_server(MyProto, '', 0)
...@@ -596,7 +592,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -596,7 +592,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
self.loop.sock_connect.side_effect = OSError self.loop.sock_connect.side_effect = OSError
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0))
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
...@@ -606,19 +602,19 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -606,19 +602,19 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.socket.side_effect = OSError m_socket.socket.side_effect = OSError
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol, family=socket.AF_INET) asyncio.DatagramProtocol, family=socket.AF_INET)
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0))
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
@unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
def test_create_datagram_endpoint_no_matching_family(self): def test_create_datagram_endpoint_no_matching_family(self):
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol, asyncio.DatagramProtocol,
remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, coro) ValueError, self.loop.run_until_complete, coro)
...@@ -628,7 +624,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -628,7 +624,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.socket.return_value.setblocking.side_effect = OSError m_socket.socket.return_value.setblocking.side_effect = OSError
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol, family=socket.AF_INET) asyncio.DatagramProtocol, family=socket.AF_INET)
self.assertRaises( self.assertRaises(
OSError, self.loop.run_until_complete, coro) OSError, self.loop.run_until_complete, coro)
self.assertTrue( self.assertTrue(
...@@ -636,7 +632,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -636,7 +632,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
def test_create_datagram_endpoint_noaddr_nofamily(self): def test_create_datagram_endpoint_noaddr_nofamily(self):
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
protocols.DatagramProtocol) asyncio.DatagramProtocol)
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@unittest.mock.patch('asyncio.base_events.socket') @unittest.mock.patch('asyncio.base_events.socket')
......
...@@ -23,14 +23,9 @@ import unittest.mock ...@@ -23,14 +23,9 @@ import unittest.mock
from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR
from asyncio import futures import asyncio
from asyncio import events from asyncio import events
from asyncio import transports
from asyncio import protocols
from asyncio import selector_events
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
from asyncio import locks
def data_file(filename): def data_file(filename):
...@@ -49,7 +44,7 @@ SIGNED_CERTFILE = data_file('keycert3.pem') ...@@ -49,7 +44,7 @@ SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem') SIGNING_CA = data_file('pycacert.pem')
class MyProto(protocols.Protocol): class MyProto(asyncio.Protocol):
done = None done = None
def __init__(self, loop=None): def __init__(self, loop=None):
...@@ -57,7 +52,7 @@ class MyProto(protocols.Protocol): ...@@ -57,7 +52,7 @@ class MyProto(protocols.Protocol):
self.state = 'INITIAL' self.state = 'INITIAL'
self.nbytes = 0 self.nbytes = 0
if loop is not None: if loop is not None:
self.done = futures.Future(loop=loop) self.done = asyncio.Future(loop=loop)
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -80,14 +75,14 @@ class MyProto(protocols.Protocol): ...@@ -80,14 +75,14 @@ class MyProto(protocols.Protocol):
self.done.set_result(None) self.done.set_result(None)
class MyDatagramProto(protocols.DatagramProtocol): class MyDatagramProto(asyncio.DatagramProtocol):
done = None done = None
def __init__(self, loop=None): def __init__(self, loop=None):
self.state = 'INITIAL' self.state = 'INITIAL'
self.nbytes = 0 self.nbytes = 0
if loop is not None: if loop is not None:
self.done = futures.Future(loop=loop) self.done = asyncio.Future(loop=loop)
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -108,7 +103,7 @@ class MyDatagramProto(protocols.DatagramProtocol): ...@@ -108,7 +103,7 @@ class MyDatagramProto(protocols.DatagramProtocol):
self.done.set_result(None) self.done.set_result(None)
class MyReadPipeProto(protocols.Protocol): class MyReadPipeProto(asyncio.Protocol):
done = None done = None
def __init__(self, loop=None): def __init__(self, loop=None):
...@@ -116,7 +111,7 @@ class MyReadPipeProto(protocols.Protocol): ...@@ -116,7 +111,7 @@ class MyReadPipeProto(protocols.Protocol):
self.nbytes = 0 self.nbytes = 0
self.transport = None self.transport = None
if loop is not None: if loop is not None:
self.done = futures.Future(loop=loop) self.done = asyncio.Future(loop=loop)
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -140,14 +135,14 @@ class MyReadPipeProto(protocols.Protocol): ...@@ -140,14 +135,14 @@ class MyReadPipeProto(protocols.Protocol):
self.done.set_result(None) self.done.set_result(None)
class MyWritePipeProto(protocols.BaseProtocol): class MyWritePipeProto(asyncio.BaseProtocol):
done = None done = None
def __init__(self, loop=None): def __init__(self, loop=None):
self.state = 'INITIAL' self.state = 'INITIAL'
self.transport = None self.transport = None
if loop is not None: if loop is not None:
self.done = futures.Future(loop=loop) self.done = asyncio.Future(loop=loop)
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -161,18 +156,18 @@ class MyWritePipeProto(protocols.BaseProtocol): ...@@ -161,18 +156,18 @@ class MyWritePipeProto(protocols.BaseProtocol):
self.done.set_result(None) self.done.set_result(None)
class MySubprocessProtocol(protocols.SubprocessProtocol): class MySubprocessProtocol(asyncio.SubprocessProtocol):
def __init__(self, loop): def __init__(self, loop):
self.state = 'INITIAL' self.state = 'INITIAL'
self.transport = None self.transport = None
self.connected = futures.Future(loop=loop) self.connected = asyncio.Future(loop=loop)
self.completed = futures.Future(loop=loop) self.completed = asyncio.Future(loop=loop)
self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)}
self.data = {1: b'', 2: b''} self.data = {1: b'', 2: b''}
self.returncode = None self.returncode = None
self.got_data = {1: locks.Event(loop=loop), self.got_data = {1: asyncio.Event(loop=loop),
2: locks.Event(loop=loop)} 2: asyncio.Event(loop=loop)}
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
...@@ -207,7 +202,7 @@ class EventLoopTestsMixin: ...@@ -207,7 +202,7 @@ class EventLoopTestsMixin:
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.loop = self.create_event_loop() self.loop = self.create_event_loop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
...@@ -218,11 +213,11 @@ class EventLoopTestsMixin: ...@@ -218,11 +213,11 @@ class EventLoopTestsMixin:
super().tearDown() super().tearDown()
def test_run_until_complete_nesting(self): def test_run_until_complete_nesting(self):
@tasks.coroutine @asyncio.coroutine
def coro1(): def coro1():
yield yield
@tasks.coroutine @asyncio.coroutine
def coro2(): def coro2():
self.assertTrue(self.loop.is_running()) self.assertTrue(self.loop.is_running())
self.loop.run_until_complete(coro1()) self.loop.run_until_complete(coro1())
...@@ -235,15 +230,15 @@ class EventLoopTestsMixin: ...@@ -235,15 +230,15 @@ class EventLoopTestsMixin:
def test_run_until_complete(self): def test_run_until_complete(self):
t0 = self.loop.time() t0 = self.loop.time()
self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop))
t1 = self.loop.time() t1 = self.loop.time()
self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
def test_run_until_complete_stopped(self): def test_run_until_complete_stopped(self):
@tasks.coroutine @asyncio.coroutine
def cb(): def cb():
self.loop.stop() self.loop.stop()
yield from tasks.sleep(0.1, loop=self.loop) yield from asyncio.sleep(0.1, loop=self.loop)
task = cb() task = cb()
self.assertRaises(RuntimeError, self.assertRaises(RuntimeError,
self.loop.run_until_complete, task) self.loop.run_until_complete, task)
...@@ -494,8 +489,8 @@ class EventLoopTestsMixin: ...@@ -494,8 +489,8 @@ class EventLoopTestsMixin:
f = self.loop.create_connection( f = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address) lambda: MyProto(loop=self.loop), *httpd.address)
tr, pr = self.loop.run_until_complete(f) tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, transports.Transport) self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, protocols.Protocol) self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done) self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0) self.assertGreater(pr.nbytes, 0)
tr.close() tr.close()
...@@ -522,8 +517,8 @@ class EventLoopTestsMixin: ...@@ -522,8 +517,8 @@ class EventLoopTestsMixin:
f = self.loop.create_connection( f = self.loop.create_connection(
lambda: MyProto(loop=self.loop), sock=sock) lambda: MyProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f) tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, transports.Transport) self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, protocols.Protocol) self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done) self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0) self.assertGreater(pr.nbytes, 0)
tr.close() tr.close()
...@@ -535,8 +530,8 @@ class EventLoopTestsMixin: ...@@ -535,8 +530,8 @@ class EventLoopTestsMixin:
lambda: MyProto(loop=self.loop), *httpd.address, lambda: MyProto(loop=self.loop), *httpd.address,
ssl=test_utils.dummy_ssl_context()) ssl=test_utils.dummy_ssl_context())
tr, pr = self.loop.run_until_complete(f) tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, transports.Transport) self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, protocols.Protocol) self.assertIsInstance(pr, asyncio.Protocol)
self.assertTrue('ssl' in tr.__class__.__name__.lower()) self.assertTrue('ssl' in tr.__class__.__name__.lower())
self.assertIsNotNone(tr.get_extra_info('sockname')) self.assertIsNotNone(tr.get_extra_info('sockname'))
self.loop.run_until_complete(pr.done) self.loop.run_until_complete(pr.done)
...@@ -762,7 +757,7 @@ class EventLoopTestsMixin: ...@@ -762,7 +757,7 @@ class EventLoopTestsMixin:
server.close() server.close()
def test_create_server_sock(self): def test_create_server_sock(self):
proto = futures.Future(loop=self.loop) proto = asyncio.Future(loop=self.loop)
class TestMyProto(MyProto): class TestMyProto(MyProto):
def connection_made(self, transport): def connection_made(self, transport):
...@@ -805,7 +800,7 @@ class EventLoopTestsMixin: ...@@ -805,7 +800,7 @@ class EventLoopTestsMixin:
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled')
def test_create_server_dual_stack(self): def test_create_server_dual_stack(self):
f_proto = futures.Future(loop=self.loop) f_proto = asyncio.Future(loop=self.loop)
class TestMyProto(MyProto): class TestMyProto(MyProto):
def connection_made(self, transport): def connection_made(self, transport):
...@@ -834,7 +829,7 @@ class EventLoopTestsMixin: ...@@ -834,7 +829,7 @@ class EventLoopTestsMixin:
proto.transport.close() proto.transport.close()
client.close() client.close()
f_proto = futures.Future(loop=self.loop) f_proto = asyncio.Future(loop=self.loop)
client = socket.socket(socket.AF_INET6) client = socket.socket(socket.AF_INET6)
client.connect(('::1', port)) client.connect(('::1', port))
client.send(b'xxx') client.send(b'xxx')
...@@ -907,7 +902,7 @@ class EventLoopTestsMixin: ...@@ -907,7 +902,7 @@ class EventLoopTestsMixin:
def test_internal_fds(self): def test_internal_fds(self):
loop = self.create_event_loop() loop = self.create_event_loop()
if not isinstance(loop, selector_events.BaseSelectorEventLoop): if not isinstance(loop, asyncio.BaseSelectorEventLoop):
self.skipTest('loop is not a BaseSelectorEventLoop') self.skipTest('loop is not a BaseSelectorEventLoop')
self.assertEqual(1, loop._internal_fds) self.assertEqual(1, loop._internal_fds)
...@@ -929,7 +924,7 @@ class EventLoopTestsMixin: ...@@ -929,7 +924,7 @@ class EventLoopTestsMixin:
rpipe, wpipe = os.pipe() rpipe, wpipe = os.pipe()
pipeobj = io.open(rpipe, 'rb', 1024) pipeobj = io.open(rpipe, 'rb', 1024)
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
self.assertIs(p, proto) self.assertIs(p, proto)
...@@ -957,9 +952,6 @@ class EventLoopTestsMixin: ...@@ -957,9 +952,6 @@ class EventLoopTestsMixin:
@unittest.skipUnless(sys.platform != 'win32', @unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows") "Don't support pipes for Windows")
# kqueue doesn't support character devices (PTY) on Mac OS X older
# than 10.9 (Maverick)
@support.requires_mac_ver(10, 9)
def test_read_pty_output(self): def test_read_pty_output(self):
proto = None proto = None
...@@ -971,7 +963,7 @@ class EventLoopTestsMixin: ...@@ -971,7 +963,7 @@ class EventLoopTestsMixin:
master, slave = os.openpty() master, slave = os.openpty()
master_read_obj = io.open(master, 'rb', 0) master_read_obj = io.open(master, 'rb', 0)
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
t, p = yield from self.loop.connect_read_pipe(factory, t, p = yield from self.loop.connect_read_pipe(factory,
master_read_obj) master_read_obj)
...@@ -1012,7 +1004,7 @@ class EventLoopTestsMixin: ...@@ -1012,7 +1004,7 @@ class EventLoopTestsMixin:
rpipe, wpipe = os.pipe() rpipe, wpipe = os.pipe()
pipeobj = io.open(wpipe, 'wb', 1024) pipeobj = io.open(wpipe, 'wb', 1024)
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal transport nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
...@@ -1058,7 +1050,7 @@ class EventLoopTestsMixin: ...@@ -1058,7 +1050,7 @@ class EventLoopTestsMixin:
rsock, wsock = test_utils.socketpair() rsock, wsock = test_utils.socketpair()
pipeobj = io.open(wsock.detach(), 'wb', 1024) pipeobj = io.open(wsock.detach(), 'wb', 1024)
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal transport nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory, t, p = yield from self.loop.connect_write_pipe(factory,
...@@ -1080,6 +1072,53 @@ class EventLoopTestsMixin: ...@@ -1080,6 +1072,53 @@ class EventLoopTestsMixin:
self.loop.run_until_complete(proto.done) self.loop.run_until_complete(proto.done)
self.assertEqual('CLOSED', proto.state) self.assertEqual('CLOSED', proto.state)
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_write_pty(self):
proto = None
transport = None
def factory():
nonlocal proto
proto = MyWritePipeProto(loop=self.loop)
return proto
master, slave = os.openpty()
slave_write_obj = io.open(slave, 'wb', 0)
@asyncio.coroutine
def connect():
nonlocal transport
t, p = yield from self.loop.connect_write_pipe(factory,
slave_write_obj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.assertEqual('CONNECTED', proto.state)
transport = t
self.loop.run_until_complete(connect())
transport.write(b'1')
test_utils.run_briefly(self.loop)
data = os.read(master, 1024)
self.assertEqual(b'1', data)
transport.write(b'2345')
test_utils.run_briefly(self.loop)
data = os.read(master, 1024)
self.assertEqual(b'2345', data)
self.assertEqual('CONNECTED', proto.state)
os.close(master)
# extra info is available
self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
# close connection
proto.transport.close()
self.loop.run_until_complete(proto.done)
self.assertEqual('CLOSED', proto.state)
def test_prompt_cancellation(self): def test_prompt_cancellation(self):
r, w = test_utils.socketpair() r, w = test_utils.socketpair()
r.setblocking(False) r.setblocking(False)
...@@ -1088,12 +1127,12 @@ class EventLoopTestsMixin: ...@@ -1088,12 +1127,12 @@ class EventLoopTestsMixin:
if ov is not None: if ov is not None:
self.assertTrue(ov.pending) self.assertTrue(ov.pending)
@tasks.coroutine @asyncio.coroutine
def main(): def main():
try: try:
self.loop.call_soon(f.cancel) self.loop.call_soon(f.cancel)
yield from f yield from f
except futures.CancelledError: except asyncio.CancelledError:
res = 'cancelled' res = 'cancelled'
else: else:
res = None res = None
...@@ -1102,13 +1141,13 @@ class EventLoopTestsMixin: ...@@ -1102,13 +1141,13 @@ class EventLoopTestsMixin:
return res return res
start = time.monotonic() start = time.monotonic()
t = tasks.Task(main(), loop=self.loop) t = asyncio.Task(main(), loop=self.loop)
self.loop.run_forever() self.loop.run_forever()
elapsed = time.monotonic() - start elapsed = time.monotonic() - start
self.assertLess(elapsed, 0.1) self.assertLess(elapsed, 0.1)
self.assertEqual(t.result(), 'cancelled') self.assertEqual(t.result(), 'cancelled')
self.assertRaises(futures.CancelledError, f.result) self.assertRaises(asyncio.CancelledError, f.result)
if ov is not None: if ov is not None:
self.assertFalse(ov.pending) self.assertFalse(ov.pending)
self.loop._stop_serving(r) self.loop._stop_serving(r)
...@@ -1126,13 +1165,13 @@ class EventLoopTestsMixin: ...@@ -1126,13 +1165,13 @@ class EventLoopTestsMixin:
self.loop._run_once = _run_once self.loop._run_once = _run_once
calls = [] calls = []
@tasks.coroutine @asyncio.coroutine
def wait(): def wait():
loop = self.loop loop = self.loop
calls.append(loop._run_once_counter) calls.append(loop._run_once_counter)
yield from tasks.sleep(loop.granularity * 10, loop=loop) yield from asyncio.sleep(loop.granularity * 10, loop=loop)
calls.append(loop._run_once_counter) calls.append(loop._run_once_counter)
yield from tasks.sleep(loop.granularity / 10, loop=loop) yield from asyncio.sleep(loop.granularity / 10, loop=loop)
calls.append(loop._run_once_counter) calls.append(loop._run_once_counter)
self.loop.run_until_complete(wait()) self.loop.run_until_complete(wait())
...@@ -1162,7 +1201,7 @@ class SubprocessTestsMixin: ...@@ -1162,7 +1201,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo.py') prog = os.path.join(os.path.dirname(__file__), 'echo.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1188,7 +1227,7 @@ class SubprocessTestsMixin: ...@@ -1188,7 +1227,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo.py') prog = os.path.join(os.path.dirname(__file__), 'echo.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1220,7 +1259,7 @@ class SubprocessTestsMixin: ...@@ -1220,7 +1259,7 @@ class SubprocessTestsMixin:
proto = None proto = None
transp = None transp = None
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_shell( transp, proto = yield from self.loop.subprocess_shell(
...@@ -1241,7 +1280,7 @@ class SubprocessTestsMixin: ...@@ -1241,7 +1280,7 @@ class SubprocessTestsMixin:
def test_subprocess_exitcode(self): def test_subprocess_exitcode(self):
proto = None proto = None
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto nonlocal proto
transp, proto = yield from self.loop.subprocess_shell( transp, proto = yield from self.loop.subprocess_shell(
...@@ -1257,7 +1296,7 @@ class SubprocessTestsMixin: ...@@ -1257,7 +1296,7 @@ class SubprocessTestsMixin:
proto = None proto = None
transp = None transp = None
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_shell( transp, proto = yield from self.loop.subprocess_shell(
...@@ -1279,7 +1318,7 @@ class SubprocessTestsMixin: ...@@ -1279,7 +1318,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo.py') prog = os.path.join(os.path.dirname(__file__), 'echo.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1300,7 +1339,7 @@ class SubprocessTestsMixin: ...@@ -1300,7 +1339,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo.py') prog = os.path.join(os.path.dirname(__file__), 'echo.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1322,7 +1361,7 @@ class SubprocessTestsMixin: ...@@ -1322,7 +1361,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo.py') prog = os.path.join(os.path.dirname(__file__), 'echo.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1343,7 +1382,7 @@ class SubprocessTestsMixin: ...@@ -1343,7 +1382,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo2.py') prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1370,7 +1409,7 @@ class SubprocessTestsMixin: ...@@ -1370,7 +1409,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo2.py') prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1400,7 +1439,7 @@ class SubprocessTestsMixin: ...@@ -1400,7 +1439,7 @@ class SubprocessTestsMixin:
prog = os.path.join(os.path.dirname(__file__), 'echo3.py') prog = os.path.join(os.path.dirname(__file__), 'echo3.py')
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto, transp nonlocal proto, transp
transp, proto = yield from self.loop.subprocess_exec( transp, proto = yield from self.loop.subprocess_exec(
...@@ -1437,7 +1476,7 @@ class SubprocessTestsMixin: ...@@ -1437,7 +1476,7 @@ class SubprocessTestsMixin:
proto = None proto = None
transp = None transp = None
@tasks.coroutine @asyncio.coroutine
def connect(): def connect():
nonlocal proto nonlocal proto
# start the new process in a new session # start the new process in a new session
...@@ -1453,19 +1492,18 @@ class SubprocessTestsMixin: ...@@ -1453,19 +1492,18 @@ class SubprocessTestsMixin:
if sys.platform == 'win32': if sys.platform == 'win32':
from asyncio import windows_events
class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return windows_events.SelectorEventLoop() return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin, class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return windows_events.ProactorEventLoop() return asyncio.ProactorEventLoop()
def test_create_ssl_connection(self): def test_create_ssl_connection(self):
raise unittest.SkipTest("IocpEventLoop incompatible with SSL") raise unittest.SkipTest("IocpEventLoop incompatible with SSL")
...@@ -1499,17 +1537,16 @@ if sys.platform == 'win32': ...@@ -1499,17 +1537,16 @@ if sys.platform == 'win32':
"IocpEventLoop does not have create_datagram_endpoint()") "IocpEventLoop does not have create_datagram_endpoint()")
else: else:
from asyncio import selectors from asyncio import selectors
from asyncio import unix_events
class UnixEventLoopTestsMixin(EventLoopTestsMixin): class UnixEventLoopTestsMixin(EventLoopTestsMixin):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
watcher = unix_events.SafeChildWatcher() watcher = asyncio.SafeChildWatcher()
watcher.attach_loop(self.loop) watcher.attach_loop(self.loop)
events.set_child_watcher(watcher) asyncio.set_child_watcher(watcher)
def tearDown(self): def tearDown(self):
events.set_child_watcher(None) asyncio.set_child_watcher(None)
super().tearDown() super().tearDown()
if hasattr(selectors, 'KqueueSelector'): if hasattr(selectors, 'KqueueSelector'):
...@@ -1518,16 +1555,28 @@ else: ...@@ -1518,16 +1555,28 @@ else:
unittest.TestCase): unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return unix_events.SelectorEventLoop( return asyncio.SelectorEventLoop(
selectors.KqueueSelector()) selectors.KqueueSelector())
# kqueue doesn't support character devices (PTY) on Mac OS X older
# than 10.9 (Maverick)
@support.requires_mac_ver(10, 9)
def test_read_pty_output(self):
super().test_read_pty_output()
# kqueue doesn't support character devices (PTY) on Mac OS X older
# than 10.9 (Maverick)
@support.requires_mac_ver(10, 9)
def test_write_pty(self):
super().test_write_pty()
if hasattr(selectors, 'EpollSelector'): if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(UnixEventLoopTestsMixin, class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
unittest.TestCase): unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return unix_events.SelectorEventLoop(selectors.EpollSelector()) return asyncio.SelectorEventLoop(selectors.EpollSelector())
if hasattr(selectors, 'PollSelector'): if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(UnixEventLoopTestsMixin, class PollEventLoopTests(UnixEventLoopTestsMixin,
...@@ -1535,7 +1584,7 @@ else: ...@@ -1535,7 +1584,7 @@ else:
unittest.TestCase): unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return unix_events.SelectorEventLoop(selectors.PollSelector()) return asyncio.SelectorEventLoop(selectors.PollSelector())
# Should always exist. # Should always exist.
class SelectEventLoopTests(UnixEventLoopTestsMixin, class SelectEventLoopTests(UnixEventLoopTestsMixin,
...@@ -1543,7 +1592,7 @@ else: ...@@ -1543,7 +1592,7 @@ else:
unittest.TestCase): unittest.TestCase):
def create_event_loop(self): def create_event_loop(self):
return unix_events.SelectorEventLoop(selectors.SelectSelector()) return asyncio.SelectorEventLoop(selectors.SelectSelector())
class HandleTests(unittest.TestCase): class HandleTests(unittest.TestCase):
...@@ -1553,7 +1602,7 @@ class HandleTests(unittest.TestCase): ...@@ -1553,7 +1602,7 @@ class HandleTests(unittest.TestCase):
return args return args
args = () args = ()
h = events.Handle(callback, args) h = asyncio.Handle(callback, args)
self.assertIs(h._callback, callback) self.assertIs(h._callback, callback)
self.assertIs(h._args, args) self.assertIs(h._args, args)
self.assertFalse(h._cancelled) self.assertFalse(h._cancelled)
...@@ -1576,16 +1625,16 @@ class HandleTests(unittest.TestCase): ...@@ -1576,16 +1625,16 @@ class HandleTests(unittest.TestCase):
def test_make_handle(self): def test_make_handle(self):
def callback(*args): def callback(*args):
return args return args
h1 = events.Handle(callback, ()) h1 = asyncio.Handle(callback, ())
self.assertRaises( self.assertRaises(
AssertionError, events.make_handle, h1, ()) AssertionError, asyncio.events.make_handle, h1, ())
@unittest.mock.patch('asyncio.events.logger') @unittest.mock.patch('asyncio.events.logger')
def test_callback_with_exception(self, log): def test_callback_with_exception(self, log):
def callback(): def callback():
raise ValueError() raise ValueError()
h = events.Handle(callback, ()) h = asyncio.Handle(callback, ())
h._run() h._run()
self.assertTrue(log.exception.called) self.assertTrue(log.exception.called)
...@@ -1594,7 +1643,7 @@ class TimerTests(unittest.TestCase): ...@@ -1594,7 +1643,7 @@ class TimerTests(unittest.TestCase):
def test_hash(self): def test_hash(self):
when = time.monotonic() when = time.monotonic()
h = events.TimerHandle(when, lambda: False, ()) h = asyncio.TimerHandle(when, lambda: False, ())
self.assertEqual(hash(h), hash(when)) self.assertEqual(hash(h), hash(when))
def test_timer(self): def test_timer(self):
...@@ -1603,7 +1652,7 @@ class TimerTests(unittest.TestCase): ...@@ -1603,7 +1652,7 @@ class TimerTests(unittest.TestCase):
args = () args = ()
when = time.monotonic() when = time.monotonic()
h = events.TimerHandle(when, callback, args) h = asyncio.TimerHandle(when, callback, args)
self.assertIs(h._callback, callback) self.assertIs(h._callback, callback)
self.assertIs(h._args, args) self.assertIs(h._args, args)
self.assertFalse(h._cancelled) self.assertFalse(h._cancelled)
...@@ -1618,7 +1667,7 @@ class TimerTests(unittest.TestCase): ...@@ -1618,7 +1667,7 @@ class TimerTests(unittest.TestCase):
self.assertTrue(r.endswith('())<cancelled>'), r) self.assertTrue(r.endswith('())<cancelled>'), r)
self.assertRaises(AssertionError, self.assertRaises(AssertionError,
events.TimerHandle, None, callback, args) asyncio.TimerHandle, None, callback, args)
def test_timer_comparison(self): def test_timer_comparison(self):
def callback(*args): def callback(*args):
...@@ -1626,8 +1675,8 @@ class TimerTests(unittest.TestCase): ...@@ -1626,8 +1675,8 @@ class TimerTests(unittest.TestCase):
when = time.monotonic() when = time.monotonic()
h1 = events.TimerHandle(when, callback, ()) h1 = asyncio.TimerHandle(when, callback, ())
h2 = events.TimerHandle(when, callback, ()) h2 = asyncio.TimerHandle(when, callback, ())
# TODO: Use assertLess etc. # TODO: Use assertLess etc.
self.assertFalse(h1 < h2) self.assertFalse(h1 < h2)
self.assertFalse(h2 < h1) self.assertFalse(h2 < h1)
...@@ -1643,8 +1692,8 @@ class TimerTests(unittest.TestCase): ...@@ -1643,8 +1692,8 @@ class TimerTests(unittest.TestCase):
h2.cancel() h2.cancel()
self.assertFalse(h1 == h2) self.assertFalse(h1 == h2)
h1 = events.TimerHandle(when, callback, ()) h1 = asyncio.TimerHandle(when, callback, ())
h2 = events.TimerHandle(when + 10.0, callback, ()) h2 = asyncio.TimerHandle(when + 10.0, callback, ())
self.assertTrue(h1 < h2) self.assertTrue(h1 < h2)
self.assertFalse(h2 < h1) self.assertFalse(h2 < h1)
self.assertTrue(h1 <= h2) self.assertTrue(h1 <= h2)
...@@ -1656,7 +1705,7 @@ class TimerTests(unittest.TestCase): ...@@ -1656,7 +1705,7 @@ class TimerTests(unittest.TestCase):
self.assertFalse(h1 == h2) self.assertFalse(h1 == h2)
self.assertTrue(h1 != h2) self.assertTrue(h1 != h2)
h3 = events.Handle(callback, ()) h3 = asyncio.Handle(callback, ())
self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3))
...@@ -1665,7 +1714,7 @@ class AbstractEventLoopTests(unittest.TestCase): ...@@ -1665,7 +1714,7 @@ class AbstractEventLoopTests(unittest.TestCase):
def test_not_implemented(self): def test_not_implemented(self):
f = unittest.mock.Mock() f = unittest.mock.Mock()
loop = events.AbstractEventLoop() loop = asyncio.AbstractEventLoop()
self.assertRaises( self.assertRaises(
NotImplementedError, loop.run_forever) NotImplementedError, loop.run_forever)
self.assertRaises( self.assertRaises(
...@@ -1739,19 +1788,19 @@ class ProtocolsAbsTests(unittest.TestCase): ...@@ -1739,19 +1788,19 @@ class ProtocolsAbsTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
f = unittest.mock.Mock() f = unittest.mock.Mock()
p = protocols.Protocol() p = asyncio.Protocol()
self.assertIsNone(p.connection_made(f)) self.assertIsNone(p.connection_made(f))
self.assertIsNone(p.connection_lost(f)) self.assertIsNone(p.connection_lost(f))
self.assertIsNone(p.data_received(f)) self.assertIsNone(p.data_received(f))
self.assertIsNone(p.eof_received()) self.assertIsNone(p.eof_received())
dp = protocols.DatagramProtocol() dp = asyncio.DatagramProtocol()
self.assertIsNone(dp.connection_made(f)) self.assertIsNone(dp.connection_made(f))
self.assertIsNone(dp.connection_lost(f)) self.assertIsNone(dp.connection_lost(f))
self.assertIsNone(dp.error_received(f)) self.assertIsNone(dp.error_received(f))
self.assertIsNone(dp.datagram_received(f, f)) self.assertIsNone(dp.datagram_received(f, f))
sp = protocols.SubprocessProtocol() sp = asyncio.SubprocessProtocol()
self.assertIsNone(sp.connection_made(f)) self.assertIsNone(sp.connection_made(f))
self.assertIsNone(sp.connection_lost(f)) self.assertIsNone(sp.connection_lost(f))
self.assertIsNone(sp.pipe_data_received(1, f)) self.assertIsNone(sp.pipe_data_received(1, f))
...@@ -1761,16 +1810,8 @@ class ProtocolsAbsTests(unittest.TestCase): ...@@ -1761,16 +1810,8 @@ class ProtocolsAbsTests(unittest.TestCase):
class PolicyTests(unittest.TestCase): class PolicyTests(unittest.TestCase):
def create_policy(self):
if sys.platform == "win32":
from asyncio import windows_events
return windows_events.DefaultEventLoopPolicy()
else:
from asyncio import unix_events
return unix_events.DefaultEventLoopPolicy()
def test_event_loop_policy(self): def test_event_loop_policy(self):
policy = events.AbstractEventLoopPolicy() policy = asyncio.AbstractEventLoopPolicy()
self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.get_event_loop)
self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.set_event_loop, object())
self.assertRaises(NotImplementedError, policy.new_event_loop) self.assertRaises(NotImplementedError, policy.new_event_loop)
...@@ -1779,18 +1820,18 @@ class PolicyTests(unittest.TestCase): ...@@ -1779,18 +1820,18 @@ class PolicyTests(unittest.TestCase):
object()) object())
def test_get_event_loop(self): def test_get_event_loop(self):
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
self.assertIsNone(policy._local._loop) self.assertIsNone(policy._local._loop)
loop = policy.get_event_loop() loop = policy.get_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, asyncio.AbstractEventLoop)
self.assertIs(policy._local._loop, loop) self.assertIs(policy._local._loop, loop)
self.assertIs(loop, policy.get_event_loop()) self.assertIs(loop, policy.get_event_loop())
loop.close() loop.close()
def test_get_event_loop_calls_set_event_loop(self): def test_get_event_loop_calls_set_event_loop(self):
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
with unittest.mock.patch.object( with unittest.mock.patch.object(
policy, "set_event_loop", policy, "set_event_loop",
...@@ -1806,7 +1847,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1806,7 +1847,7 @@ class PolicyTests(unittest.TestCase):
loop.close() loop.close()
def test_get_event_loop_after_set_none(self): def test_get_event_loop_after_set_none(self):
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
policy.set_event_loop(None) policy.set_event_loop(None)
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
...@@ -1814,7 +1855,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1814,7 +1855,7 @@ class PolicyTests(unittest.TestCase):
def test_get_event_loop_thread(self, m_current_thread): def test_get_event_loop_thread(self, m_current_thread):
def f(): def f():
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
self.assertRaises(AssertionError, policy.get_event_loop) self.assertRaises(AssertionError, policy.get_event_loop)
th = threading.Thread(target=f) th = threading.Thread(target=f)
...@@ -1822,14 +1863,14 @@ class PolicyTests(unittest.TestCase): ...@@ -1822,14 +1863,14 @@ class PolicyTests(unittest.TestCase):
th.join() th.join()
def test_new_event_loop(self): def test_new_event_loop(self):
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
loop = policy.new_event_loop() loop = policy.new_event_loop()
self.assertIsInstance(loop, events.AbstractEventLoop) self.assertIsInstance(loop, asyncio.AbstractEventLoop)
loop.close() loop.close()
def test_set_event_loop(self): def test_set_event_loop(self):
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
old_loop = policy.get_event_loop() old_loop = policy.get_event_loop()
self.assertRaises(AssertionError, policy.set_event_loop, object()) self.assertRaises(AssertionError, policy.set_event_loop, object())
...@@ -1842,19 +1883,19 @@ class PolicyTests(unittest.TestCase): ...@@ -1842,19 +1883,19 @@ class PolicyTests(unittest.TestCase):
old_loop.close() old_loop.close()
def test_get_event_loop_policy(self): def test_get_event_loop_policy(self):
policy = events.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
self.assertIsInstance(policy, events.AbstractEventLoopPolicy) self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy)
self.assertIs(policy, events.get_event_loop_policy()) self.assertIs(policy, asyncio.get_event_loop_policy())
def test_set_event_loop_policy(self): def test_set_event_loop_policy(self):
self.assertRaises( self.assertRaises(
AssertionError, events.set_event_loop_policy, object()) AssertionError, asyncio.set_event_loop_policy, object())
old_policy = events.get_event_loop_policy() old_policy = asyncio.get_event_loop_policy()
policy = self.create_policy() policy = asyncio.DefaultEventLoopPolicy()
events.set_event_loop_policy(policy) asyncio.set_event_loop_policy(policy)
self.assertIs(policy, events.get_event_loop_policy()) self.assertIs(policy, asyncio.get_event_loop_policy())
self.assertIsNot(policy, old_policy) self.assertIsNot(policy, old_policy)
......
...@@ -5,8 +5,7 @@ import threading ...@@ -5,8 +5,7 @@ import threading
import unittest import unittest
import unittest.mock import unittest.mock
from asyncio import events import asyncio
from asyncio import futures
from asyncio import test_utils from asyncio import test_utils
...@@ -18,13 +17,13 @@ class FutureTests(unittest.TestCase): ...@@ -18,13 +17,13 @@ class FutureTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
def test_initial_state(self): def test_initial_state(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
self.assertFalse(f.cancelled()) self.assertFalse(f.cancelled())
self.assertFalse(f.done()) self.assertFalse(f.done())
f.cancel() f.cancel()
...@@ -32,56 +31,56 @@ class FutureTests(unittest.TestCase): ...@@ -32,56 +31,56 @@ class FutureTests(unittest.TestCase):
def test_init_constructor_default_loop(self): def test_init_constructor_default_loop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
f = futures.Future() f = asyncio.Future()
self.assertIs(f._loop, self.loop) self.assertIs(f._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_constructor_positional(self): def test_constructor_positional(self):
# Make sure Future does't accept a positional argument # Make sure Future does't accept a positional argument
self.assertRaises(TypeError, futures.Future, 42) self.assertRaises(TypeError, asyncio.Future, 42)
def test_cancel(self): def test_cancel(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
self.assertTrue(f.cancel()) self.assertTrue(f.cancel())
self.assertTrue(f.cancelled()) self.assertTrue(f.cancelled())
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertRaises(futures.CancelledError, f.result) self.assertRaises(asyncio.CancelledError, f.result)
self.assertRaises(futures.CancelledError, f.exception) self.assertRaises(asyncio.CancelledError, f.exception)
self.assertRaises(futures.InvalidStateError, f.set_result, None) self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
self.assertFalse(f.cancel()) self.assertFalse(f.cancel())
def test_result(self): def test_result(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
self.assertRaises(futures.InvalidStateError, f.result) self.assertRaises(asyncio.InvalidStateError, f.result)
f.set_result(42) f.set_result(42)
self.assertFalse(f.cancelled()) self.assertFalse(f.cancelled())
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertEqual(f.result(), 42) self.assertEqual(f.result(), 42)
self.assertEqual(f.exception(), None) self.assertEqual(f.exception(), None)
self.assertRaises(futures.InvalidStateError, f.set_result, None) self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
self.assertFalse(f.cancel()) self.assertFalse(f.cancel())
def test_exception(self): def test_exception(self):
exc = RuntimeError() exc = RuntimeError()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
self.assertRaises(futures.InvalidStateError, f.exception) self.assertRaises(asyncio.InvalidStateError, f.exception)
f.set_exception(exc) f.set_exception(exc)
self.assertFalse(f.cancelled()) self.assertFalse(f.cancelled())
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertRaises(RuntimeError, f.result) self.assertRaises(RuntimeError, f.result)
self.assertEqual(f.exception(), exc) self.assertEqual(f.exception(), exc)
self.assertRaises(futures.InvalidStateError, f.set_result, None) self.assertRaises(asyncio.InvalidStateError, f.set_result, None)
self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertRaises(asyncio.InvalidStateError, f.set_exception, None)
self.assertFalse(f.cancel()) self.assertFalse(f.cancel())
def test_yield_from_twice(self): def test_yield_from_twice(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
def fixture(): def fixture():
yield 'A' yield 'A'
...@@ -99,32 +98,32 @@ class FutureTests(unittest.TestCase): ...@@ -99,32 +98,32 @@ class FutureTests(unittest.TestCase):
self.assertEqual(next(g), ('C', 42)) # yield 'C', y. self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
def test_repr(self): def test_repr(self):
f_pending = futures.Future(loop=self.loop) f_pending = asyncio.Future(loop=self.loop)
self.assertEqual(repr(f_pending), 'Future<PENDING>') self.assertEqual(repr(f_pending), 'Future<PENDING>')
f_pending.cancel() f_pending.cancel()
f_cancelled = futures.Future(loop=self.loop) f_cancelled = asyncio.Future(loop=self.loop)
f_cancelled.cancel() f_cancelled.cancel()
self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>') self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>')
f_result = futures.Future(loop=self.loop) f_result = asyncio.Future(loop=self.loop)
f_result.set_result(4) f_result.set_result(4)
self.assertEqual(repr(f_result), 'Future<result=4>') self.assertEqual(repr(f_result), 'Future<result=4>')
self.assertEqual(f_result.result(), 4) self.assertEqual(f_result.result(), 4)
exc = RuntimeError() exc = RuntimeError()
f_exception = futures.Future(loop=self.loop) f_exception = asyncio.Future(loop=self.loop)
f_exception.set_exception(exc) f_exception.set_exception(exc)
self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>') self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>')
self.assertIs(f_exception.exception(), exc) self.assertIs(f_exception.exception(), exc)
f_few_callbacks = futures.Future(loop=self.loop) f_few_callbacks = asyncio.Future(loop=self.loop)
f_few_callbacks.add_done_callback(_fakefunc) f_few_callbacks.add_done_callback(_fakefunc)
self.assertIn('Future<PENDING, [<function _fakefunc', self.assertIn('Future<PENDING, [<function _fakefunc',
repr(f_few_callbacks)) repr(f_few_callbacks))
f_few_callbacks.cancel() f_few_callbacks.cancel()
f_many_callbacks = futures.Future(loop=self.loop) f_many_callbacks = asyncio.Future(loop=self.loop)
for i in range(20): for i in range(20):
f_many_callbacks.add_done_callback(_fakefunc) f_many_callbacks.add_done_callback(_fakefunc)
r = repr(f_many_callbacks) r = repr(f_many_callbacks)
...@@ -135,31 +134,31 @@ class FutureTests(unittest.TestCase): ...@@ -135,31 +134,31 @@ class FutureTests(unittest.TestCase):
def test_copy_state(self): def test_copy_state(self):
# Test the internal _copy_state method since it's being directly # Test the internal _copy_state method since it's being directly
# invoked in other modules. # invoked in other modules.
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.set_result(10) f.set_result(10)
newf = futures.Future(loop=self.loop) newf = asyncio.Future(loop=self.loop)
newf._copy_state(f) newf._copy_state(f)
self.assertTrue(newf.done()) self.assertTrue(newf.done())
self.assertEqual(newf.result(), 10) self.assertEqual(newf.result(), 10)
f_exception = futures.Future(loop=self.loop) f_exception = asyncio.Future(loop=self.loop)
f_exception.set_exception(RuntimeError()) f_exception.set_exception(RuntimeError())
newf_exception = futures.Future(loop=self.loop) newf_exception = asyncio.Future(loop=self.loop)
newf_exception._copy_state(f_exception) newf_exception._copy_state(f_exception)
self.assertTrue(newf_exception.done()) self.assertTrue(newf_exception.done())
self.assertRaises(RuntimeError, newf_exception.result) self.assertRaises(RuntimeError, newf_exception.result)
f_cancelled = futures.Future(loop=self.loop) f_cancelled = asyncio.Future(loop=self.loop)
f_cancelled.cancel() f_cancelled.cancel()
newf_cancelled = futures.Future(loop=self.loop) newf_cancelled = asyncio.Future(loop=self.loop)
newf_cancelled._copy_state(f_cancelled) newf_cancelled._copy_state(f_cancelled)
self.assertTrue(newf_cancelled.cancelled()) self.assertTrue(newf_cancelled.cancelled())
def test_iter(self): def test_iter(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
def coro(): def coro():
yield from fut yield from fut
...@@ -172,20 +171,20 @@ class FutureTests(unittest.TestCase): ...@@ -172,20 +171,20 @@ class FutureTests(unittest.TestCase):
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_abandoned(self, m_log): def test_tb_logger_abandoned(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
del fut del fut
self.assertFalse(m_log.error.called) self.assertFalse(m_log.error.called)
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_result_unretrieved(self, m_log): def test_tb_logger_result_unretrieved(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_result(42) fut.set_result(42)
del fut del fut
self.assertFalse(m_log.error.called) self.assertFalse(m_log.error.called)
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_result_retrieved(self, m_log): def test_tb_logger_result_retrieved(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_result(42) fut.set_result(42)
fut.result() fut.result()
del fut del fut
...@@ -193,7 +192,7 @@ class FutureTests(unittest.TestCase): ...@@ -193,7 +192,7 @@ class FutureTests(unittest.TestCase):
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_exception_unretrieved(self, m_log): def test_tb_logger_exception_unretrieved(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_exception(RuntimeError('boom')) fut.set_exception(RuntimeError('boom'))
del fut del fut
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -201,7 +200,7 @@ class FutureTests(unittest.TestCase): ...@@ -201,7 +200,7 @@ class FutureTests(unittest.TestCase):
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_exception_retrieved(self, m_log): def test_tb_logger_exception_retrieved(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_exception(RuntimeError('boom')) fut.set_exception(RuntimeError('boom'))
fut.exception() fut.exception()
del fut del fut
...@@ -209,7 +208,7 @@ class FutureTests(unittest.TestCase): ...@@ -209,7 +208,7 @@ class FutureTests(unittest.TestCase):
@unittest.mock.patch('asyncio.futures.logger') @unittest.mock.patch('asyncio.futures.logger')
def test_tb_logger_exception_result_retrieved(self, m_log): def test_tb_logger_exception_result_retrieved(self, m_log):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_exception(RuntimeError('boom')) fut.set_exception(RuntimeError('boom'))
self.assertRaises(RuntimeError, fut.result) self.assertRaises(RuntimeError, fut.result)
del fut del fut
...@@ -221,15 +220,15 @@ class FutureTests(unittest.TestCase): ...@@ -221,15 +220,15 @@ class FutureTests(unittest.TestCase):
return (arg, threading.get_ident()) return (arg, threading.get_ident())
ex = concurrent.futures.ThreadPoolExecutor(1) ex = concurrent.futures.ThreadPoolExecutor(1)
f1 = ex.submit(run, 'oi') f1 = ex.submit(run, 'oi')
f2 = futures.wrap_future(f1, loop=self.loop) f2 = asyncio.wrap_future(f1, loop=self.loop)
res, ident = self.loop.run_until_complete(f2) res, ident = self.loop.run_until_complete(f2)
self.assertIsInstance(f2, futures.Future) self.assertIsInstance(f2, asyncio.Future)
self.assertEqual(res, 'oi') self.assertEqual(res, 'oi')
self.assertNotEqual(ident, threading.get_ident()) self.assertNotEqual(ident, threading.get_ident())
def test_wrap_future_future(self): def test_wrap_future_future(self):
f1 = futures.Future(loop=self.loop) f1 = asyncio.Future(loop=self.loop)
f2 = futures.wrap_future(f1) f2 = asyncio.wrap_future(f1)
self.assertIs(f1, f2) self.assertIs(f1, f2)
@unittest.mock.patch('asyncio.futures.events') @unittest.mock.patch('asyncio.futures.events')
...@@ -238,12 +237,12 @@ class FutureTests(unittest.TestCase): ...@@ -238,12 +237,12 @@ class FutureTests(unittest.TestCase):
return (arg, threading.get_ident()) return (arg, threading.get_ident())
ex = concurrent.futures.ThreadPoolExecutor(1) ex = concurrent.futures.ThreadPoolExecutor(1)
f1 = ex.submit(run, 'oi') f1 = ex.submit(run, 'oi')
f2 = futures.wrap_future(f1) f2 = asyncio.wrap_future(f1)
self.assertIs(m_events.get_event_loop.return_value, f2._loop) self.assertIs(m_events.get_event_loop.return_value, f2._loop)
def test_wrap_future_cancel(self): def test_wrap_future_cancel(self):
f1 = concurrent.futures.Future() f1 = concurrent.futures.Future()
f2 = futures.wrap_future(f1, loop=self.loop) f2 = asyncio.wrap_future(f1, loop=self.loop)
f2.cancel() f2.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertTrue(f1.cancelled()) self.assertTrue(f1.cancelled())
...@@ -251,7 +250,7 @@ class FutureTests(unittest.TestCase): ...@@ -251,7 +250,7 @@ class FutureTests(unittest.TestCase):
def test_wrap_future_cancel2(self): def test_wrap_future_cancel2(self):
f1 = concurrent.futures.Future() f1 = concurrent.futures.Future()
f2 = futures.wrap_future(f1, loop=self.loop) f2 = asyncio.wrap_future(f1, loop=self.loop)
f1.set_result(42) f1.set_result(42)
f2.cancel() f2.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -264,7 +263,7 @@ class FutureDoneCallbackTests(unittest.TestCase): ...@@ -264,7 +263,7 @@ class FutureDoneCallbackTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
...@@ -279,7 +278,7 @@ class FutureDoneCallbackTests(unittest.TestCase): ...@@ -279,7 +278,7 @@ class FutureDoneCallbackTests(unittest.TestCase):
return bag_appender return bag_appender
def _new_future(self): def _new_future(self):
return futures.Future(loop=self.loop) return asyncio.Future(loop=self.loop)
def test_callbacks_invoked_on_set_result(self): def test_callbacks_invoked_on_set_result(self):
bag = [] bag = []
......
...@@ -4,10 +4,7 @@ import unittest ...@@ -4,10 +4,7 @@ import unittest
import unittest.mock import unittest.mock
import re import re
from asyncio import events import asyncio
from asyncio import futures
from asyncio import locks
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
...@@ -24,33 +21,33 @@ class LockTests(unittest.TestCase): ...@@ -24,33 +21,33 @@ class LockTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = unittest.mock.Mock() loop = unittest.mock.Mock()
lock = locks.Lock(loop=loop) lock = asyncio.Lock(loop=loop)
self.assertIs(lock._loop, loop) self.assertIs(lock._loop, loop)
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
self.assertIs(lock._loop, self.loop) self.assertIs(lock._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
lock = locks.Lock() lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop) self.assertIs(lock._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
self.assertTrue(repr(lock).endswith('[unlocked]>')) self.assertTrue(repr(lock).endswith('[unlocked]>'))
self.assertTrue(RGX_REPR.match(repr(lock))) self.assertTrue(RGX_REPR.match(repr(lock)))
@tasks.coroutine @asyncio.coroutine
def acquire_lock(): def acquire_lock():
yield from lock yield from lock
...@@ -59,9 +56,9 @@ class LockTests(unittest.TestCase): ...@@ -59,9 +56,9 @@ class LockTests(unittest.TestCase):
self.assertTrue(RGX_REPR.match(repr(lock))) self.assertTrue(RGX_REPR.match(repr(lock)))
def test_lock(self): def test_lock(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def acquire_lock(): def acquire_lock():
return (yield from lock) return (yield from lock)
...@@ -74,31 +71,31 @@ class LockTests(unittest.TestCase): ...@@ -74,31 +71,31 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
def test_acquire(self): def test_acquire(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
result = [] result = []
self.assertTrue(self.loop.run_until_complete(lock.acquire())) self.assertTrue(self.loop.run_until_complete(lock.acquire()))
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
if (yield from lock.acquire()): if (yield from lock.acquire()):
result.append(1) result.append(1)
return True return True
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
if (yield from lock.acquire()): if (yield from lock.acquire()):
result.append(2) result.append(2)
return True return True
@tasks.coroutine @asyncio.coroutine
def c3(result): def c3(result):
if (yield from lock.acquire()): if (yield from lock.acquire()):
result.append(3) result.append(3)
return True return True
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -110,7 +107,7 @@ class LockTests(unittest.TestCase): ...@@ -110,7 +107,7 @@ class LockTests(unittest.TestCase):
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([1], result) self.assertEqual([1], result)
t3 = tasks.Task(c3(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop)
lock.release() lock.release()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -128,13 +125,13 @@ class LockTests(unittest.TestCase): ...@@ -128,13 +125,13 @@ class LockTests(unittest.TestCase):
self.assertTrue(t3.result()) self.assertTrue(t3.result())
def test_acquire_cancel(self): def test_acquire_cancel(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
self.assertTrue(self.loop.run_until_complete(lock.acquire())) self.assertTrue(self.loop.run_until_complete(lock.acquire()))
task = tasks.Task(lock.acquire(), loop=self.loop) task = asyncio.Task(lock.acquire(), loop=self.loop)
self.loop.call_soon(task.cancel) self.loop.call_soon(task.cancel)
self.assertRaises( self.assertRaises(
futures.CancelledError, asyncio.CancelledError,
self.loop.run_until_complete, task) self.loop.run_until_complete, task)
self.assertFalse(lock._waiters) self.assertFalse(lock._waiters)
...@@ -153,9 +150,9 @@ class LockTests(unittest.TestCase): ...@@ -153,9 +150,9 @@ class LockTests(unittest.TestCase):
# B's waiter; instead, it should move on to C's waiter. # B's waiter; instead, it should move on to C's waiter.
# Setup: A has the lock, b and c are waiting. # Setup: A has the lock, b and c are waiting.
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def lockit(name, blocker): def lockit(name, blocker):
yield from lock.acquire() yield from lock.acquire()
try: try:
...@@ -164,14 +161,14 @@ class LockTests(unittest.TestCase): ...@@ -164,14 +161,14 @@ class LockTests(unittest.TestCase):
finally: finally:
lock.release() lock.release()
fa = futures.Future(loop=self.loop) fa = asyncio.Future(loop=self.loop)
ta = tasks.Task(lockit('A', fa), loop=self.loop) ta = asyncio.Task(lockit('A', fa), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertTrue(lock.locked()) self.assertTrue(lock.locked())
tb = tasks.Task(lockit('B', None), loop=self.loop) tb = asyncio.Task(lockit('B', None), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(len(lock._waiters), 1) self.assertEqual(len(lock._waiters), 1)
tc = tasks.Task(lockit('C', None), loop=self.loop) tc = asyncio.Task(lockit('C', None), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(len(lock._waiters), 2) self.assertEqual(len(lock._waiters), 2)
...@@ -187,12 +184,12 @@ class LockTests(unittest.TestCase): ...@@ -187,12 +184,12 @@ class LockTests(unittest.TestCase):
self.assertTrue(tc.done()) self.assertTrue(tc.done())
def test_release_not_acquired(self): def test_release_not_acquired(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
self.assertRaises(RuntimeError, lock.release) self.assertRaises(RuntimeError, lock.release)
def test_release_no_waiters(self): def test_release_no_waiters(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
self.loop.run_until_complete(lock.acquire()) self.loop.run_until_complete(lock.acquire())
self.assertTrue(lock.locked()) self.assertTrue(lock.locked())
...@@ -200,9 +197,9 @@ class LockTests(unittest.TestCase): ...@@ -200,9 +197,9 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
def test_context_manager(self): def test_context_manager(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def acquire_lock(): def acquire_lock():
return (yield from lock) return (yield from lock)
...@@ -212,7 +209,7 @@ class LockTests(unittest.TestCase): ...@@ -212,7 +209,7 @@ class LockTests(unittest.TestCase):
self.assertFalse(lock.locked()) self.assertFalse(lock.locked())
def test_context_manager_no_yield(self): def test_context_manager_no_yield(self):
lock = locks.Lock(loop=self.loop) lock = asyncio.Lock(loop=self.loop)
try: try:
with lock: with lock:
...@@ -227,29 +224,29 @@ class EventTests(unittest.TestCase): ...@@ -227,29 +224,29 @@ class EventTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = unittest.mock.Mock() loop = unittest.mock.Mock()
ev = locks.Event(loop=loop) ev = asyncio.Event(loop=loop)
self.assertIs(ev._loop, loop) self.assertIs(ev._loop, loop)
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
self.assertIs(ev._loop, self.loop) self.assertIs(ev._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
ev = locks.Event() ev = asyncio.Event()
self.assertIs(ev._loop, self.loop) self.assertIs(ev._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
self.assertTrue(repr(ev).endswith('[unset]>')) self.assertTrue(repr(ev).endswith('[unset]>'))
match = RGX_REPR.match(repr(ev)) match = RGX_REPR.match(repr(ev))
self.assertEqual(match.group('extras'), 'unset') self.assertEqual(match.group('extras'), 'unset')
...@@ -263,33 +260,33 @@ class EventTests(unittest.TestCase): ...@@ -263,33 +260,33 @@ class EventTests(unittest.TestCase):
self.assertTrue(RGX_REPR.match(repr(ev))) self.assertTrue(RGX_REPR.match(repr(ev)))
def test_wait(self): def test_wait(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
self.assertFalse(ev.is_set()) self.assertFalse(ev.is_set())
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
if (yield from ev.wait()): if (yield from ev.wait()):
result.append(1) result.append(1)
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
if (yield from ev.wait()): if (yield from ev.wait()):
result.append(2) result.append(2)
@tasks.coroutine @asyncio.coroutine
def c3(result): def c3(result):
if (yield from ev.wait()): if (yield from ev.wait()):
result.append(3) result.append(3)
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
t3 = tasks.Task(c3(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop)
ev.set() ev.set()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -303,24 +300,24 @@ class EventTests(unittest.TestCase): ...@@ -303,24 +300,24 @@ class EventTests(unittest.TestCase):
self.assertIsNone(t3.result()) self.assertIsNone(t3.result())
def test_wait_on_set(self): def test_wait_on_set(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
ev.set() ev.set()
res = self.loop.run_until_complete(ev.wait()) res = self.loop.run_until_complete(ev.wait())
self.assertTrue(res) self.assertTrue(res)
def test_wait_cancel(self): def test_wait_cancel(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
wait = tasks.Task(ev.wait(), loop=self.loop) wait = asyncio.Task(ev.wait(), loop=self.loop)
self.loop.call_soon(wait.cancel) self.loop.call_soon(wait.cancel)
self.assertRaises( self.assertRaises(
futures.CancelledError, asyncio.CancelledError,
self.loop.run_until_complete, wait) self.loop.run_until_complete, wait)
self.assertFalse(ev._waiters) self.assertFalse(ev._waiters)
def test_clear(self): def test_clear(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
self.assertFalse(ev.is_set()) self.assertFalse(ev.is_set())
ev.set() ev.set()
...@@ -330,16 +327,16 @@ class EventTests(unittest.TestCase): ...@@ -330,16 +327,16 @@ class EventTests(unittest.TestCase):
self.assertFalse(ev.is_set()) self.assertFalse(ev.is_set())
def test_clear_with_waiters(self): def test_clear_with_waiters(self):
ev = locks.Event(loop=self.loop) ev = asyncio.Event(loop=self.loop)
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
if (yield from ev.wait()): if (yield from ev.wait()):
result.append(1) result.append(1)
return True return True
t = tasks.Task(c1(result), loop=self.loop) t = asyncio.Task(c1(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -363,55 +360,55 @@ class ConditionTests(unittest.TestCase): ...@@ -363,55 +360,55 @@ class ConditionTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = unittest.mock.Mock() loop = unittest.mock.Mock()
cond = locks.Condition(loop=loop) cond = asyncio.Condition(loop=loop)
self.assertIs(cond._loop, loop) self.assertIs(cond._loop, loop)
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.assertIs(cond._loop, self.loop) self.assertIs(cond._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
cond = locks.Condition() cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop) self.assertIs(cond._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_wait(self): def test_wait(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
result.append(1) result.append(1)
return True return True
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
result.append(2) result.append(2)
return True return True
@tasks.coroutine @asyncio.coroutine
def c3(result): def c3(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
result.append(3) result.append(3)
return True return True
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
t3 = tasks.Task(c3(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -451,25 +448,25 @@ class ConditionTests(unittest.TestCase): ...@@ -451,25 +448,25 @@ class ConditionTests(unittest.TestCase):
self.assertTrue(t3.result()) self.assertTrue(t3.result())
def test_wait_cancel(self): def test_wait_cancel(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.loop.run_until_complete(cond.acquire()) self.loop.run_until_complete(cond.acquire())
wait = tasks.Task(cond.wait(), loop=self.loop) wait = asyncio.Task(cond.wait(), loop=self.loop)
self.loop.call_soon(wait.cancel) self.loop.call_soon(wait.cancel)
self.assertRaises( self.assertRaises(
futures.CancelledError, asyncio.CancelledError,
self.loop.run_until_complete, wait) self.loop.run_until_complete, wait)
self.assertFalse(cond._waiters) self.assertFalse(cond._waiters)
self.assertTrue(cond.locked()) self.assertTrue(cond.locked())
def test_wait_unacquired(self): def test_wait_unacquired(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.assertRaises( self.assertRaises(
RuntimeError, RuntimeError,
self.loop.run_until_complete, cond.wait()) self.loop.run_until_complete, cond.wait())
def test_wait_for(self): def test_wait_for(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
presult = False presult = False
def predicate(): def predicate():
...@@ -477,7 +474,7 @@ class ConditionTests(unittest.TestCase): ...@@ -477,7 +474,7 @@ class ConditionTests(unittest.TestCase):
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait_for(predicate)): if (yield from cond.wait_for(predicate)):
...@@ -485,7 +482,7 @@ class ConditionTests(unittest.TestCase): ...@@ -485,7 +482,7 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
t = tasks.Task(c1(result), loop=self.loop) t = asyncio.Task(c1(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -507,7 +504,7 @@ class ConditionTests(unittest.TestCase): ...@@ -507,7 +504,7 @@ class ConditionTests(unittest.TestCase):
self.assertTrue(t.result()) self.assertTrue(t.result())
def test_wait_for_unacquired(self): def test_wait_for_unacquired(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
# predicate can return true immediately # predicate can return true immediately
res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3]))
...@@ -519,10 +516,10 @@ class ConditionTests(unittest.TestCase): ...@@ -519,10 +516,10 @@ class ConditionTests(unittest.TestCase):
cond.wait_for(lambda: False)) cond.wait_for(lambda: False))
def test_notify(self): def test_notify(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
...@@ -530,7 +527,7 @@ class ConditionTests(unittest.TestCase): ...@@ -530,7 +527,7 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
...@@ -538,7 +535,7 @@ class ConditionTests(unittest.TestCase): ...@@ -538,7 +535,7 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
@tasks.coroutine @asyncio.coroutine
def c3(result): def c3(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
...@@ -546,9 +543,9 @@ class ConditionTests(unittest.TestCase): ...@@ -546,9 +543,9 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
t3 = tasks.Task(c3(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -574,11 +571,11 @@ class ConditionTests(unittest.TestCase): ...@@ -574,11 +571,11 @@ class ConditionTests(unittest.TestCase):
self.assertTrue(t3.result()) self.assertTrue(t3.result())
def test_notify_all(self): def test_notify_all(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
result = [] result = []
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
...@@ -586,7 +583,7 @@ class ConditionTests(unittest.TestCase): ...@@ -586,7 +583,7 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
yield from cond.acquire() yield from cond.acquire()
if (yield from cond.wait()): if (yield from cond.wait()):
...@@ -594,8 +591,8 @@ class ConditionTests(unittest.TestCase): ...@@ -594,8 +591,8 @@ class ConditionTests(unittest.TestCase):
cond.release() cond.release()
return True return True
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([], result) self.assertEqual([], result)
...@@ -612,15 +609,15 @@ class ConditionTests(unittest.TestCase): ...@@ -612,15 +609,15 @@ class ConditionTests(unittest.TestCase):
self.assertTrue(t2.result()) self.assertTrue(t2.result())
def test_notify_unacquired(self): def test_notify_unacquired(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.assertRaises(RuntimeError, cond.notify) self.assertRaises(RuntimeError, cond.notify)
def test_notify_all_unacquired(self): def test_notify_all_unacquired(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.assertRaises(RuntimeError, cond.notify_all) self.assertRaises(RuntimeError, cond.notify_all)
def test_repr(self): def test_repr(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
self.assertTrue('unlocked' in repr(cond)) self.assertTrue('unlocked' in repr(cond))
self.assertTrue(RGX_REPR.match(repr(cond))) self.assertTrue(RGX_REPR.match(repr(cond)))
...@@ -636,9 +633,9 @@ class ConditionTests(unittest.TestCase): ...@@ -636,9 +633,9 @@ class ConditionTests(unittest.TestCase):
self.assertTrue(RGX_REPR.match(repr(cond))) self.assertTrue(RGX_REPR.match(repr(cond)))
def test_context_manager(self): def test_context_manager(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def acquire_cond(): def acquire_cond():
return (yield from cond) return (yield from cond)
...@@ -648,7 +645,7 @@ class ConditionTests(unittest.TestCase): ...@@ -648,7 +645,7 @@ class ConditionTests(unittest.TestCase):
self.assertFalse(cond.locked()) self.assertFalse(cond.locked())
def test_context_manager_no_yield(self): def test_context_manager_no_yield(self):
cond = locks.Condition(loop=self.loop) cond = asyncio.Condition(loop=self.loop)
try: try:
with cond: with cond:
...@@ -663,33 +660,33 @@ class SemaphoreTests(unittest.TestCase): ...@@ -663,33 +660,33 @@ class SemaphoreTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
def test_ctor_loop(self): def test_ctor_loop(self):
loop = unittest.mock.Mock() loop = unittest.mock.Mock()
sem = locks.Semaphore(loop=loop) sem = asyncio.Semaphore(loop=loop)
self.assertIs(sem._loop, loop) self.assertIs(sem._loop, loop)
sem = locks.Semaphore(loop=self.loop) sem = asyncio.Semaphore(loop=self.loop)
self.assertIs(sem._loop, self.loop) self.assertIs(sem._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
sem = locks.Semaphore() sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop) self.assertIs(sem._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_initial_value_zero(self): def test_initial_value_zero(self):
sem = locks.Semaphore(0, loop=self.loop) sem = asyncio.Semaphore(0, loop=self.loop)
self.assertTrue(sem.locked()) self.assertTrue(sem.locked())
def test_repr(self): def test_repr(self):
sem = locks.Semaphore(loop=self.loop) sem = asyncio.Semaphore(loop=self.loop)
self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
self.assertTrue(RGX_REPR.match(repr(sem))) self.assertTrue(RGX_REPR.match(repr(sem)))
...@@ -707,10 +704,10 @@ class SemaphoreTests(unittest.TestCase): ...@@ -707,10 +704,10 @@ class SemaphoreTests(unittest.TestCase):
self.assertTrue(RGX_REPR.match(repr(sem))) self.assertTrue(RGX_REPR.match(repr(sem)))
def test_semaphore(self): def test_semaphore(self):
sem = locks.Semaphore(loop=self.loop) sem = asyncio.Semaphore(loop=self.loop)
self.assertEqual(1, sem._value) self.assertEqual(1, sem._value)
@tasks.coroutine @asyncio.coroutine
def acquire_lock(): def acquire_lock():
return (yield from sem) return (yield from sem)
...@@ -725,43 +722,43 @@ class SemaphoreTests(unittest.TestCase): ...@@ -725,43 +722,43 @@ class SemaphoreTests(unittest.TestCase):
self.assertEqual(1, sem._value) self.assertEqual(1, sem._value)
def test_semaphore_value(self): def test_semaphore_value(self):
self.assertRaises(ValueError, locks.Semaphore, -1) self.assertRaises(ValueError, asyncio.Semaphore, -1)
def test_acquire(self): def test_acquire(self):
sem = locks.Semaphore(3, loop=self.loop) sem = asyncio.Semaphore(3, loop=self.loop)
result = [] result = []
self.assertTrue(self.loop.run_until_complete(sem.acquire())) self.assertTrue(self.loop.run_until_complete(sem.acquire()))
self.assertTrue(self.loop.run_until_complete(sem.acquire())) self.assertTrue(self.loop.run_until_complete(sem.acquire()))
self.assertFalse(sem.locked()) self.assertFalse(sem.locked())
@tasks.coroutine @asyncio.coroutine
def c1(result): def c1(result):
yield from sem.acquire() yield from sem.acquire()
result.append(1) result.append(1)
return True return True
@tasks.coroutine @asyncio.coroutine
def c2(result): def c2(result):
yield from sem.acquire() yield from sem.acquire()
result.append(2) result.append(2)
return True return True
@tasks.coroutine @asyncio.coroutine
def c3(result): def c3(result):
yield from sem.acquire() yield from sem.acquire()
result.append(3) result.append(3)
return True return True
@tasks.coroutine @asyncio.coroutine
def c4(result): def c4(result):
yield from sem.acquire() yield from sem.acquire()
result.append(4) result.append(4)
return True return True
t1 = tasks.Task(c1(result), loop=self.loop) t1 = asyncio.Task(c1(result), loop=self.loop)
t2 = tasks.Task(c2(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop)
t3 = tasks.Task(c3(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual([1], result) self.assertEqual([1], result)
...@@ -769,7 +766,7 @@ class SemaphoreTests(unittest.TestCase): ...@@ -769,7 +766,7 @@ class SemaphoreTests(unittest.TestCase):
self.assertEqual(2, len(sem._waiters)) self.assertEqual(2, len(sem._waiters))
self.assertEqual(0, sem._value) self.assertEqual(0, sem._value)
t4 = tasks.Task(c4(result), loop=self.loop) t4 = asyncio.Task(c4(result), loop=self.loop)
sem.release() sem.release()
sem.release() sem.release()
...@@ -794,23 +791,23 @@ class SemaphoreTests(unittest.TestCase): ...@@ -794,23 +791,23 @@ class SemaphoreTests(unittest.TestCase):
sem.release() sem.release()
def test_acquire_cancel(self): def test_acquire_cancel(self):
sem = locks.Semaphore(loop=self.loop) sem = asyncio.Semaphore(loop=self.loop)
self.loop.run_until_complete(sem.acquire()) self.loop.run_until_complete(sem.acquire())
acquire = tasks.Task(sem.acquire(), loop=self.loop) acquire = asyncio.Task(sem.acquire(), loop=self.loop)
self.loop.call_soon(acquire.cancel) self.loop.call_soon(acquire.cancel)
self.assertRaises( self.assertRaises(
futures.CancelledError, asyncio.CancelledError,
self.loop.run_until_complete, acquire) self.loop.run_until_complete, acquire)
self.assertFalse(sem._waiters) self.assertFalse(sem._waiters)
def test_release_not_acquired(self): def test_release_not_acquired(self):
sem = locks.BoundedSemaphore(loop=self.loop) sem = asyncio.BoundedSemaphore(loop=self.loop)
self.assertRaises(ValueError, sem.release) self.assertRaises(ValueError, sem.release)
def test_release_no_waiters(self): def test_release_no_waiters(self):
sem = locks.Semaphore(loop=self.loop) sem = asyncio.Semaphore(loop=self.loop)
self.loop.run_until_complete(sem.acquire()) self.loop.run_until_complete(sem.acquire())
self.assertTrue(sem.locked()) self.assertTrue(sem.locked())
...@@ -818,9 +815,9 @@ class SemaphoreTests(unittest.TestCase): ...@@ -818,9 +815,9 @@ class SemaphoreTests(unittest.TestCase):
self.assertFalse(sem.locked()) self.assertFalse(sem.locked())
def test_context_manager(self): def test_context_manager(self):
sem = locks.Semaphore(2, loop=self.loop) sem = asyncio.Semaphore(2, loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def acquire_lock(): def acquire_lock():
return (yield from sem) return (yield from sem)
......
...@@ -5,7 +5,6 @@ import unittest ...@@ -5,7 +5,6 @@ import unittest
import unittest.mock import unittest.mock
import asyncio import asyncio
from asyncio.proactor_events import BaseProactorEventLoop
from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorSocketTransport
from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorWritePipeTransport
from asyncio.proactor_events import _ProactorDuplexPipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport
...@@ -345,18 +344,18 @@ class BaseProactorEventLoopTests(unittest.TestCase): ...@@ -345,18 +344,18 @@ class BaseProactorEventLoopTests(unittest.TestCase):
self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock()
class EventLoop(BaseProactorEventLoop): class EventLoop(asyncio.BaseProactorEventLoop):
def _socketpair(s): def _socketpair(s):
return (self.ssock, self.csock) return (self.ssock, self.csock)
self.loop = EventLoop(self.proactor) self.loop = EventLoop(self.proactor)
@unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, 'call_soon')
@unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, '_socketpair')
def test_ctor(self, socketpair, call_soon): def test_ctor(self, socketpair, call_soon):
ssock, csock = socketpair.return_value = ( ssock, csock = socketpair.return_value = (
unittest.mock.Mock(), unittest.mock.Mock()) unittest.mock.Mock(), unittest.mock.Mock())
loop = BaseProactorEventLoop(self.proactor) loop = asyncio.BaseProactorEventLoop(self.proactor)
self.assertIs(loop._ssock, ssock) self.assertIs(loop._ssock, ssock)
self.assertIs(loop._csock, csock) self.assertIs(loop._csock, csock)
self.assertEqual(loop._internal_fds, 1) self.assertEqual(loop._internal_fds, 1)
...@@ -399,7 +398,7 @@ class BaseProactorEventLoopTests(unittest.TestCase): ...@@ -399,7 +398,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
def test_socketpair(self): def test_socketpair(self):
self.assertRaises( self.assertRaises(
NotImplementedError, BaseProactorEventLoop, self.proactor) NotImplementedError, asyncio.BaseProactorEventLoop, self.proactor)
def test_make_socket_transport(self): def test_make_socket_transport(self):
tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
......
...@@ -3,11 +3,7 @@ ...@@ -3,11 +3,7 @@
import unittest import unittest
import unittest.mock import unittest.mock
from asyncio import events import asyncio
from asyncio import futures
from asyncio import locks
from asyncio import queues
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
...@@ -15,7 +11,7 @@ class _QueueTestBase(unittest.TestCase): ...@@ -15,7 +11,7 @@ class _QueueTestBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
...@@ -39,57 +35,57 @@ class QueueBasicTests(_QueueTestBase): ...@@ -39,57 +35,57 @@ class QueueBasicTests(_QueueTestBase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
self.assertTrue(fn(q).startswith('<Queue'), fn(q)) self.assertTrue(fn(q).startswith('<Queue'), fn(q))
id_is_present = hex(id(q)) in fn(q) id_is_present = hex(id(q)) in fn(q)
self.assertEqual(expect_id, id_is_present) self.assertEqual(expect_id, id_is_present)
@tasks.coroutine @asyncio.coroutine
def add_getter(): def add_getter():
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
# Start a task that waits to get. # Start a task that waits to get.
tasks.Task(q.get(), loop=loop) asyncio.Task(q.get(), loop=loop)
# Let it start waiting. # Let it start waiting.
yield from tasks.sleep(0.1, loop=loop) yield from asyncio.sleep(0.1, loop=loop)
self.assertTrue('_getters[1]' in fn(q)) self.assertTrue('_getters[1]' in fn(q))
# resume q.get coroutine to finish generator # resume q.get coroutine to finish generator
q.put_nowait(0) q.put_nowait(0)
loop.run_until_complete(add_getter()) loop.run_until_complete(add_getter())
@tasks.coroutine @asyncio.coroutine
def add_putter(): def add_putter():
q = queues.Queue(maxsize=1, loop=loop) q = asyncio.Queue(maxsize=1, loop=loop)
q.put_nowait(1) q.put_nowait(1)
# Start a task that waits to put. # Start a task that waits to put.
tasks.Task(q.put(2), loop=loop) asyncio.Task(q.put(2), loop=loop)
# Let it start waiting. # Let it start waiting.
yield from tasks.sleep(0.1, loop=loop) yield from asyncio.sleep(0.1, loop=loop)
self.assertTrue('_putters[1]' in fn(q)) self.assertTrue('_putters[1]' in fn(q))
# resume q.put coroutine to finish generator # resume q.put coroutine to finish generator
q.get_nowait() q.get_nowait()
loop.run_until_complete(add_putter()) loop.run_until_complete(add_putter())
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
q.put_nowait(1) q.put_nowait(1)
self.assertTrue('_queue=[1]' in fn(q)) self.assertTrue('_queue=[1]' in fn(q))
def test_ctor_loop(self): def test_ctor_loop(self):
loop = unittest.mock.Mock() loop = unittest.mock.Mock()
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
self.assertIs(q._loop, loop) self.assertIs(q._loop, loop)
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
self.assertIs(q._loop, self.loop) self.assertIs(q._loop, self.loop)
def test_ctor_noloop(self): def test_ctor_noloop(self):
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
q = queues.Queue() q = asyncio.Queue()
self.assertIs(q._loop, self.loop) self.assertIs(q._loop, self.loop)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
def test_repr(self): def test_repr(self):
self._test_repr_or_str(repr, True) self._test_repr_or_str(repr, True)
...@@ -98,7 +94,7 @@ class QueueBasicTests(_QueueTestBase): ...@@ -98,7 +94,7 @@ class QueueBasicTests(_QueueTestBase):
self._test_repr_or_str(str, False) self._test_repr_or_str(str, False)
def test_empty(self): def test_empty(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
self.assertTrue(q.empty()) self.assertTrue(q.empty())
q.put_nowait(1) q.put_nowait(1)
self.assertFalse(q.empty()) self.assertFalse(q.empty())
...@@ -106,15 +102,15 @@ class QueueBasicTests(_QueueTestBase): ...@@ -106,15 +102,15 @@ class QueueBasicTests(_QueueTestBase):
self.assertTrue(q.empty()) self.assertTrue(q.empty())
def test_full(self): def test_full(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
self.assertFalse(q.full()) self.assertFalse(q.full())
q = queues.Queue(maxsize=1, loop=self.loop) q = asyncio.Queue(maxsize=1, loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
self.assertTrue(q.full()) self.assertTrue(q.full())
def test_order(self): def test_order(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
for i in [1, 3, 2]: for i in [1, 3, 2]:
q.put_nowait(i) q.put_nowait(i)
...@@ -133,28 +129,28 @@ class QueueBasicTests(_QueueTestBase): ...@@ -133,28 +129,28 @@ class QueueBasicTests(_QueueTestBase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
q = queues.Queue(maxsize=2, loop=loop) q = asyncio.Queue(maxsize=2, loop=loop)
self.assertEqual(2, q.maxsize) self.assertEqual(2, q.maxsize)
have_been_put = [] have_been_put = []
@tasks.coroutine @asyncio.coroutine
def putter(): def putter():
for i in range(3): for i in range(3):
yield from q.put(i) yield from q.put(i)
have_been_put.append(i) have_been_put.append(i)
return True return True
@tasks.coroutine @asyncio.coroutine
def test(): def test():
t = tasks.Task(putter(), loop=loop) t = asyncio.Task(putter(), loop=loop)
yield from tasks.sleep(0.01, loop=loop) yield from asyncio.sleep(0.01, loop=loop)
# The putter is blocked after putting two items. # The putter is blocked after putting two items.
self.assertEqual([0, 1], have_been_put) self.assertEqual([0, 1], have_been_put)
self.assertEqual(0, q.get_nowait()) self.assertEqual(0, q.get_nowait())
# Let the putter resume and put last item. # Let the putter resume and put last item.
yield from tasks.sleep(0.01, loop=loop) yield from asyncio.sleep(0.01, loop=loop)
self.assertEqual([0, 1, 2], have_been_put) self.assertEqual([0, 1, 2], have_been_put)
self.assertEqual(1, q.get_nowait()) self.assertEqual(1, q.get_nowait())
self.assertEqual(2, q.get_nowait()) self.assertEqual(2, q.get_nowait())
...@@ -169,10 +165,10 @@ class QueueBasicTests(_QueueTestBase): ...@@ -169,10 +165,10 @@ class QueueBasicTests(_QueueTestBase):
class QueueGetTests(_QueueTestBase): class QueueGetTests(_QueueTestBase):
def test_blocking_get(self): def test_blocking_get(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
@tasks.coroutine @asyncio.coroutine
def queue_get(): def queue_get():
return (yield from q.get()) return (yield from q.get())
...@@ -180,10 +176,10 @@ class QueueGetTests(_QueueTestBase): ...@@ -180,10 +176,10 @@ class QueueGetTests(_QueueTestBase):
self.assertEqual(1, res) self.assertEqual(1, res)
def test_get_with_putters(self): def test_get_with_putters(self):
q = queues.Queue(1, loop=self.loop) q = asyncio.Queue(1, loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
waiter = futures.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
q._putters.append((2, waiter)) q._putters.append((2, waiter))
res = self.loop.run_until_complete(q.get()) res = self.loop.run_until_complete(q.get())
...@@ -201,11 +197,11 @@ class QueueGetTests(_QueueTestBase): ...@@ -201,11 +197,11 @@ class QueueGetTests(_QueueTestBase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
started = locks.Event(loop=loop) started = asyncio.Event(loop=loop)
finished = False finished = False
@tasks.coroutine @asyncio.coroutine
def queue_get(): def queue_get():
nonlocal finished nonlocal finished
started.set() started.set()
...@@ -213,10 +209,10 @@ class QueueGetTests(_QueueTestBase): ...@@ -213,10 +209,10 @@ class QueueGetTests(_QueueTestBase):
finished = True finished = True
return res return res
@tasks.coroutine @asyncio.coroutine
def queue_put(): def queue_put():
loop.call_later(0.01, q.put_nowait, 1) loop.call_later(0.01, q.put_nowait, 1)
queue_get_task = tasks.Task(queue_get(), loop=loop) queue_get_task = asyncio.Task(queue_get(), loop=loop)
yield from started.wait() yield from started.wait()
self.assertFalse(finished) self.assertFalse(finished)
res = yield from queue_get_task res = yield from queue_get_task
...@@ -228,13 +224,13 @@ class QueueGetTests(_QueueTestBase): ...@@ -228,13 +224,13 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.01, loop.time()) self.assertAlmostEqual(0.01, loop.time())
def test_nonblocking_get(self): def test_nonblocking_get(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
self.assertEqual(1, q.get_nowait()) self.assertEqual(1, q.get_nowait())
def test_nonblocking_get_exception(self): def test_nonblocking_get_exception(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
self.assertRaises(queues.Empty, q.get_nowait) self.assertRaises(asyncio.Empty, q.get_nowait)
def test_get_cancelled(self): def test_get_cancelled(self):
...@@ -248,16 +244,16 @@ class QueueGetTests(_QueueTestBase): ...@@ -248,16 +244,16 @@ class QueueGetTests(_QueueTestBase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
q = queues.Queue(loop=loop) q = asyncio.Queue(loop=loop)
@tasks.coroutine @asyncio.coroutine
def queue_get(): def queue_get():
return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) return (yield from asyncio.wait_for(q.get(), 0.051, loop=loop))
@tasks.coroutine @asyncio.coroutine
def test(): def test():
get_task = tasks.Task(queue_get(), loop=loop) get_task = asyncio.Task(queue_get(), loop=loop)
yield from tasks.sleep(0.01, loop=loop) # let the task start yield from asyncio.sleep(0.01, loop=loop) # let the task start
q.put_nowait(1) q.put_nowait(1)
return (yield from get_task) return (yield from get_task)
...@@ -265,10 +261,10 @@ class QueueGetTests(_QueueTestBase): ...@@ -265,10 +261,10 @@ class QueueGetTests(_QueueTestBase):
self.assertAlmostEqual(0.06, loop.time()) self.assertAlmostEqual(0.06, loop.time())
def test_get_cancelled_race(self): def test_get_cancelled_race(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
t1 = tasks.Task(q.get(), loop=self.loop) t1 = asyncio.Task(q.get(), loop=self.loop)
t2 = tasks.Task(q.get(), loop=self.loop) t2 = asyncio.Task(q.get(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
t1.cancel() t1.cancel()
...@@ -279,9 +275,9 @@ class QueueGetTests(_QueueTestBase): ...@@ -279,9 +275,9 @@ class QueueGetTests(_QueueTestBase):
self.assertEqual(t2.result(), 'a') self.assertEqual(t2.result(), 'a')
def test_get_with_waiting_putters(self): def test_get_with_waiting_putters(self):
q = queues.Queue(loop=self.loop, maxsize=1) q = asyncio.Queue(loop=self.loop, maxsize=1)
tasks.Task(q.put('a'), loop=self.loop) asyncio.Task(q.put('a'), loop=self.loop)
tasks.Task(q.put('b'), loop=self.loop) asyncio.Task(q.put('b'), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(self.loop.run_until_complete(q.get()), 'a') self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
self.assertEqual(self.loop.run_until_complete(q.get()), 'b') self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
...@@ -290,9 +286,9 @@ class QueueGetTests(_QueueTestBase): ...@@ -290,9 +286,9 @@ class QueueGetTests(_QueueTestBase):
class QueuePutTests(_QueueTestBase): class QueuePutTests(_QueueTestBase):
def test_blocking_put(self): def test_blocking_put(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def queue_put(): def queue_put():
# No maxsize, won't block. # No maxsize, won't block.
yield from q.put(1) yield from q.put(1)
...@@ -309,11 +305,11 @@ class QueuePutTests(_QueueTestBase): ...@@ -309,11 +305,11 @@ class QueuePutTests(_QueueTestBase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
q = queues.Queue(maxsize=1, loop=loop) q = asyncio.Queue(maxsize=1, loop=loop)
started = locks.Event(loop=loop) started = asyncio.Event(loop=loop)
finished = False finished = False
@tasks.coroutine @asyncio.coroutine
def queue_put(): def queue_put():
nonlocal finished nonlocal finished
started.set() started.set()
...@@ -321,10 +317,10 @@ class QueuePutTests(_QueueTestBase): ...@@ -321,10 +317,10 @@ class QueuePutTests(_QueueTestBase):
yield from q.put(2) yield from q.put(2)
finished = True finished = True
@tasks.coroutine @asyncio.coroutine
def queue_get(): def queue_get():
loop.call_later(0.01, q.get_nowait) loop.call_later(0.01, q.get_nowait)
queue_put_task = tasks.Task(queue_put(), loop=loop) queue_put_task = asyncio.Task(queue_put(), loop=loop)
yield from started.wait() yield from started.wait()
self.assertFalse(finished) self.assertFalse(finished)
yield from queue_put_task yield from queue_put_task
...@@ -334,38 +330,38 @@ class QueuePutTests(_QueueTestBase): ...@@ -334,38 +330,38 @@ class QueuePutTests(_QueueTestBase):
self.assertAlmostEqual(0.01, loop.time()) self.assertAlmostEqual(0.01, loop.time())
def test_nonblocking_put(self): def test_nonblocking_put(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
self.assertEqual(1, q.get_nowait()) self.assertEqual(1, q.get_nowait())
def test_nonblocking_put_exception(self): def test_nonblocking_put_exception(self):
q = queues.Queue(maxsize=1, loop=self.loop) q = asyncio.Queue(maxsize=1, loop=self.loop)
q.put_nowait(1) q.put_nowait(1)
self.assertRaises(queues.Full, q.put_nowait, 2) self.assertRaises(asyncio.Full, q.put_nowait, 2)
def test_put_cancelled(self): def test_put_cancelled(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def queue_put(): def queue_put():
yield from q.put(1) yield from q.put(1)
return True return True
@tasks.coroutine @asyncio.coroutine
def test(): def test():
return (yield from q.get()) return (yield from q.get())
t = tasks.Task(queue_put(), loop=self.loop) t = asyncio.Task(queue_put(), loop=self.loop)
self.assertEqual(1, self.loop.run_until_complete(test())) self.assertEqual(1, self.loop.run_until_complete(test()))
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertTrue(t.result()) self.assertTrue(t.result())
def test_put_cancelled_race(self): def test_put_cancelled_race(self):
q = queues.Queue(loop=self.loop, maxsize=1) q = asyncio.Queue(loop=self.loop, maxsize=1)
tasks.Task(q.put('a'), loop=self.loop) asyncio.Task(q.put('a'), loop=self.loop)
tasks.Task(q.put('c'), loop=self.loop) asyncio.Task(q.put('c'), loop=self.loop)
t = tasks.Task(q.put('b'), loop=self.loop) t = asyncio.Task(q.put('b'), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
t.cancel() t.cancel()
...@@ -375,8 +371,8 @@ class QueuePutTests(_QueueTestBase): ...@@ -375,8 +371,8 @@ class QueuePutTests(_QueueTestBase):
self.assertEqual(q.get_nowait(), 'c') self.assertEqual(q.get_nowait(), 'c')
def test_put_with_waiting_getters(self): def test_put_with_waiting_getters(self):
q = queues.Queue(loop=self.loop) q = asyncio.Queue(loop=self.loop)
t = tasks.Task(q.get(), loop=self.loop) t = asyncio.Task(q.get(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.loop.run_until_complete(q.put('a')) self.loop.run_until_complete(q.put('a'))
self.assertEqual(self.loop.run_until_complete(t), 'a') self.assertEqual(self.loop.run_until_complete(t), 'a')
...@@ -385,7 +381,7 @@ class QueuePutTests(_QueueTestBase): ...@@ -385,7 +381,7 @@ class QueuePutTests(_QueueTestBase):
class LifoQueueTests(_QueueTestBase): class LifoQueueTests(_QueueTestBase):
def test_order(self): def test_order(self):
q = queues.LifoQueue(loop=self.loop) q = asyncio.LifoQueue(loop=self.loop)
for i in [1, 3, 2]: for i in [1, 3, 2]:
q.put_nowait(i) q.put_nowait(i)
...@@ -396,7 +392,7 @@ class LifoQueueTests(_QueueTestBase): ...@@ -396,7 +392,7 @@ class LifoQueueTests(_QueueTestBase):
class PriorityQueueTests(_QueueTestBase): class PriorityQueueTests(_QueueTestBase):
def test_order(self): def test_order(self):
q = queues.PriorityQueue(loop=self.loop) q = asyncio.PriorityQueue(loop=self.loop)
for i in [1, 3, 2]: for i in [1, 3, 2]:
q.put_nowait(i) q.put_nowait(i)
...@@ -407,11 +403,11 @@ class PriorityQueueTests(_QueueTestBase): ...@@ -407,11 +403,11 @@ class PriorityQueueTests(_QueueTestBase):
class JoinableQueueTests(_QueueTestBase): class JoinableQueueTests(_QueueTestBase):
def test_task_done_underflow(self): def test_task_done_underflow(self):
q = queues.JoinableQueue(loop=self.loop) q = asyncio.JoinableQueue(loop=self.loop)
self.assertRaises(ValueError, q.task_done) self.assertRaises(ValueError, q.task_done)
def test_task_done(self): def test_task_done(self):
q = queues.JoinableQueue(loop=self.loop) q = asyncio.JoinableQueue(loop=self.loop)
for i in range(100): for i in range(100):
q.put_nowait(i) q.put_nowait(i)
...@@ -421,7 +417,7 @@ class JoinableQueueTests(_QueueTestBase): ...@@ -421,7 +417,7 @@ class JoinableQueueTests(_QueueTestBase):
# Join the queue and assert all items have been processed. # Join the queue and assert all items have been processed.
running = True running = True
@tasks.coroutine @asyncio.coroutine
def worker(): def worker():
nonlocal accumulator nonlocal accumulator
...@@ -430,10 +426,10 @@ class JoinableQueueTests(_QueueTestBase): ...@@ -430,10 +426,10 @@ class JoinableQueueTests(_QueueTestBase):
accumulator += item accumulator += item
q.task_done() q.task_done()
@tasks.coroutine @asyncio.coroutine
def test(): def test():
for _ in range(2): for _ in range(2):
tasks.Task(worker(), loop=self.loop) asyncio.Task(worker(), loop=self.loop)
yield from q.join() yield from q.join()
...@@ -446,12 +442,12 @@ class JoinableQueueTests(_QueueTestBase): ...@@ -446,12 +442,12 @@ class JoinableQueueTests(_QueueTestBase):
q.put_nowait(0) q.put_nowait(0)
def test_join_empty_queue(self): def test_join_empty_queue(self):
q = queues.JoinableQueue(loop=self.loop) q = asyncio.JoinableQueue(loop=self.loop)
# Test that a queue join()s successfully, and before anything else # Test that a queue join()s successfully, and before anything else
# (done twice for insurance). # (done twice for insurance).
@tasks.coroutine @asyncio.coroutine
def join(): def join():
yield from q.join() yield from q.join()
yield from q.join() yield from q.join()
...@@ -459,7 +455,7 @@ class JoinableQueueTests(_QueueTestBase): ...@@ -459,7 +455,7 @@ class JoinableQueueTests(_QueueTestBase):
self.loop.run_until_complete(join()) self.loop.run_until_complete(join())
def test_format(self): def test_format(self):
q = queues.JoinableQueue(loop=self.loop) q = asyncio.JoinableQueue(loop=self.loop)
self.assertEqual(q._format(), 'maxsize=0') self.assertEqual(q._format(), 'maxsize=0')
q._unfinished_tasks = 2 q._unfinished_tasks = 2
......
...@@ -13,18 +13,16 @@ try: ...@@ -13,18 +13,16 @@ try:
except ImportError: except ImportError:
ssl = None ssl = None
from asyncio import futures import asyncio
from asyncio import selectors from asyncio import selectors
from asyncio import test_utils from asyncio import test_utils
from asyncio.protocols import DatagramProtocol, Protocol
from asyncio.selector_events import BaseSelectorEventLoop
from asyncio.selector_events import _SelectorTransport from asyncio.selector_events import _SelectorTransport
from asyncio.selector_events import _SelectorSslTransport from asyncio.selector_events import _SelectorSslTransport
from asyncio.selector_events import _SelectorSocketTransport from asyncio.selector_events import _SelectorSocketTransport
from asyncio.selector_events import _SelectorDatagramTransport from asyncio.selector_events import _SelectorDatagramTransport
class TestBaseSelectorEventLoop(BaseSelectorEventLoop): class TestBaseSelectorEventLoop(asyncio.BaseSelectorEventLoop):
def _make_self_pipe(self): def _make_self_pipe(self):
self._ssock = unittest.mock.Mock() self._ssock = unittest.mock.Mock()
...@@ -127,13 +125,13 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -127,13 +125,13 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop._sock_recv = unittest.mock.Mock() self.loop._sock_recv = unittest.mock.Mock()
f = self.loop.sock_recv(sock, 1024) f = self.loop.sock_recv(sock, 1024)
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.loop._sock_recv.assert_called_with(f, False, sock, 1024) self.loop._sock_recv.assert_called_with(f, False, sock, 1024)
def test__sock_recv_canceled_fut(self): def test__sock_recv_canceled_fut(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop._sock_recv(f, False, sock, 1024) self.loop._sock_recv(f, False, sock, 1024)
...@@ -143,7 +141,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -143,7 +141,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
...@@ -151,7 +149,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -151,7 +149,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertEqual((10,), self.loop.remove_reader.call_args[0]) self.assertEqual((10,), self.loop.remove_reader.call_args[0])
def test__sock_recv_tryagain(self): def test__sock_recv_tryagain(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.recv.side_effect = BlockingIOError sock.recv.side_effect = BlockingIOError
...@@ -162,7 +160,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -162,7 +160,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_reader.call_args[0]) self.loop.add_reader.call_args[0])
def test__sock_recv_exception(self): def test__sock_recv_exception(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
err = sock.recv.side_effect = OSError() err = sock.recv.side_effect = OSError()
...@@ -175,7 +173,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -175,7 +173,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop._sock_sendall = unittest.mock.Mock() self.loop._sock_sendall = unittest.mock.Mock()
f = self.loop.sock_sendall(sock, b'data') f = self.loop.sock_sendall(sock, b'data')
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.assertEqual( self.assertEqual(
(f, False, sock, b'data'), (f, False, sock, b'data'),
self.loop._sock_sendall.call_args[0]) self.loop._sock_sendall.call_args[0])
...@@ -185,7 +183,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -185,7 +183,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop._sock_sendall = unittest.mock.Mock() self.loop._sock_sendall = unittest.mock.Mock()
f = self.loop.sock_sendall(sock, b'') f = self.loop.sock_sendall(sock, b'')
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertIsNone(f.result()) self.assertIsNone(f.result())
self.assertFalse(self.loop._sock_sendall.called) self.assertFalse(self.loop._sock_sendall.called)
...@@ -193,7 +191,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -193,7 +191,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_sendall_canceled_fut(self): def test__sock_sendall_canceled_fut(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop._sock_sendall(f, False, sock, b'data') self.loop._sock_sendall(f, False, sock, b'data')
...@@ -203,7 +201,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -203,7 +201,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
...@@ -211,7 +209,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -211,7 +209,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertEqual((10,), self.loop.remove_writer.call_args[0]) self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_sendall_tryagain(self): def test__sock_sendall_tryagain(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.send.side_effect = BlockingIOError sock.send.side_effect = BlockingIOError
...@@ -223,7 +221,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -223,7 +221,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer.call_args[0]) self.loop.add_writer.call_args[0])
def test__sock_sendall_interrupted(self): def test__sock_sendall_interrupted(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.send.side_effect = InterruptedError sock.send.side_effect = InterruptedError
...@@ -235,7 +233,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -235,7 +233,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer.call_args[0]) self.loop.add_writer.call_args[0])
def test__sock_sendall_exception(self): def test__sock_sendall_exception(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
err = sock.send.side_effect = OSError() err = sock.send.side_effect = OSError()
...@@ -246,7 +244,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -246,7 +244,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_sendall(self): def test__sock_sendall(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.send.return_value = 4 sock.send.return_value = 4
...@@ -257,7 +255,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -257,7 +255,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_sendall_partial(self): def test__sock_sendall_partial(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.send.return_value = 2 sock.send.return_value = 2
...@@ -271,7 +269,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -271,7 +269,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_sendall_none(self): def test__sock_sendall_none(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.send.return_value = 0 sock.send.return_value = 0
...@@ -287,13 +285,13 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -287,13 +285,13 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop._sock_connect = unittest.mock.Mock() self.loop._sock_connect = unittest.mock.Mock()
f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.assertEqual( self.assertEqual(
(f, False, sock, ('127.0.0.1', 8080)), (f, False, sock, ('127.0.0.1', 8080)),
self.loop._sock_connect.call_args[0]) self.loop._sock_connect.call_args[0])
def test__sock_connect(self): def test__sock_connect(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
...@@ -306,7 +304,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -306,7 +304,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_connect_canceled_fut(self): def test__sock_connect_canceled_fut(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
...@@ -316,7 +314,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -316,7 +314,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
...@@ -324,7 +322,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -324,7 +322,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertEqual((10,), self.loop.remove_writer.call_args[0]) self.assertEqual((10,), self.loop.remove_writer.call_args[0])
def test__sock_connect_tryagain(self): def test__sock_connect_tryagain(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.EAGAIN sock.getsockopt.return_value = errno.EAGAIN
...@@ -339,7 +337,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -339,7 +337,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer.call_args[0]) self.loop.add_writer.call_args[0])
def test__sock_connect_exception(self): def test__sock_connect_exception(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.getsockopt.return_value = errno.ENOTCONN sock.getsockopt.return_value = errno.ENOTCONN
...@@ -353,12 +351,12 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -353,12 +351,12 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop._sock_accept = unittest.mock.Mock() self.loop._sock_accept = unittest.mock.Mock()
f = self.loop.sock_accept(sock) f = self.loop.sock_accept(sock)
self.assertIsInstance(f, futures.Future) self.assertIsInstance(f, asyncio.Future)
self.assertEqual( self.assertEqual(
(f, False, sock), self.loop._sock_accept.call_args[0]) (f, False, sock), self.loop._sock_accept.call_args[0])
def test__sock_accept(self): def test__sock_accept(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
conn = unittest.mock.Mock() conn = unittest.mock.Mock()
...@@ -374,7 +372,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -374,7 +372,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test__sock_accept_canceled_fut(self): def test__sock_accept_canceled_fut(self):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop._sock_accept(f, False, sock) self.loop._sock_accept(f, False, sock)
...@@ -384,7 +382,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -384,7 +382,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
f.cancel() f.cancel()
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
...@@ -392,7 +390,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -392,7 +390,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.assertEqual((10,), self.loop.remove_reader.call_args[0]) self.assertEqual((10,), self.loop.remove_reader.call_args[0])
def test__sock_accept_tryagain(self): def test__sock_accept_tryagain(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
sock.accept.side_effect = BlockingIOError sock.accept.side_effect = BlockingIOError
...@@ -404,7 +402,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -404,7 +402,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_reader.call_args[0]) self.loop.add_reader.call_args[0])
def test__sock_accept_exception(self): def test__sock_accept_exception(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
sock = unittest.mock.Mock() sock = unittest.mock.Mock()
sock.fileno.return_value = 10 sock.fileno.return_value = 10
err = sock.accept.side_effect = OSError() err = sock.accept.side_effect = OSError()
...@@ -587,7 +585,7 @@ class SelectorTransportTests(unittest.TestCase): ...@@ -587,7 +585,7 @@ class SelectorTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = unittest.mock.Mock(socket.socket) self.sock = unittest.mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
...@@ -674,7 +672,7 @@ class SelectorSocketTransportTests(unittest.TestCase): ...@@ -674,7 +672,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = unittest.mock.Mock(socket.socket) self.sock = unittest.mock.Mock(socket.socket)
self.sock_fd = self.sock.fileno.return_value = 7 self.sock_fd = self.sock.fileno.return_value = 7
...@@ -686,7 +684,7 @@ class SelectorSocketTransportTests(unittest.TestCase): ...@@ -686,7 +684,7 @@ class SelectorSocketTransportTests(unittest.TestCase):
self.protocol.connection_made.assert_called_with(tr) self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): def test_ctor_with_waiter(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
_SelectorSocketTransport( _SelectorSocketTransport(
self.loop, self.sock, self.protocol, fut) self.loop, self.sock, self.protocol, fut)
...@@ -1039,7 +1037,7 @@ class SelectorSslTransportTests(unittest.TestCase): ...@@ -1039,7 +1037,7 @@ class SelectorSslTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.sock = unittest.mock.Mock(socket.socket) self.sock = unittest.mock.Mock(socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
self.sslsock = unittest.mock.Mock() self.sslsock = unittest.mock.Mock()
...@@ -1057,7 +1055,7 @@ class SelectorSslTransportTests(unittest.TestCase): ...@@ -1057,7 +1055,7 @@ class SelectorSslTransportTests(unittest.TestCase):
return transport return transport
def test_on_handshake(self): def test_on_handshake(self):
waiter = futures.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
tr = _SelectorSslTransport( tr = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext, self.loop, self.sock, self.protocol, self.sslcontext,
waiter=waiter) waiter=waiter)
...@@ -1085,7 +1083,7 @@ class SelectorSslTransportTests(unittest.TestCase): ...@@ -1085,7 +1083,7 @@ class SelectorSslTransportTests(unittest.TestCase):
self.sslsock.do_handshake.side_effect = exc self.sslsock.do_handshake.side_effect = exc
transport = _SelectorSslTransport( transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext) self.loop, self.sock, self.protocol, self.sslcontext)
transport._waiter = futures.Future(loop=self.loop) transport._waiter = asyncio.Future(loop=self.loop)
transport._on_handshake() transport._on_handshake()
self.assertTrue(self.sslsock.close.called) self.assertTrue(self.sslsock.close.called)
self.assertTrue(transport._waiter.done()) self.assertTrue(transport._waiter.done())
...@@ -1094,7 +1092,7 @@ class SelectorSslTransportTests(unittest.TestCase): ...@@ -1094,7 +1092,7 @@ class SelectorSslTransportTests(unittest.TestCase):
def test_on_handshake_base_exc(self): def test_on_handshake_base_exc(self):
transport = _SelectorSslTransport( transport = _SelectorSslTransport(
self.loop, self.sock, self.protocol, self.sslcontext) self.loop, self.sock, self.protocol, self.sslcontext)
transport._waiter = futures.Future(loop=self.loop) transport._waiter = asyncio.Future(loop=self.loop)
exc = BaseException() exc = BaseException()
self.sslsock.do_handshake.side_effect = exc self.sslsock.do_handshake.side_effect = exc
self.assertRaises(BaseException, transport._on_handshake) self.assertRaises(BaseException, transport._on_handshake)
...@@ -1368,7 +1366,7 @@ class SelectorDatagramTransportTests(unittest.TestCase): ...@@ -1368,7 +1366,7 @@ class SelectorDatagramTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(DatagramProtocol) self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
self.sock = unittest.mock.Mock(spec_set=socket.socket) self.sock = unittest.mock.Mock(spec_set=socket.socket)
self.sock.fileno.return_value = 7 self.sock.fileno.return_value = 7
......
...@@ -8,9 +8,7 @@ try: ...@@ -8,9 +8,7 @@ try:
except ImportError: except ImportError:
ssl = None ssl = None
from asyncio import events import asyncio
from asyncio import streams
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
...@@ -19,8 +17,8 @@ class StreamReaderTests(unittest.TestCase): ...@@ -19,8 +17,8 @@ class StreamReaderTests(unittest.TestCase):
DATA = b'line1\nline2\nline3\n' DATA = b'line1\nline2\nline3\n'
def setUp(self): def setUp(self):
self.loop = events.new_event_loop() self.loop = asyncio.new_event_loop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
# just in case if we have transport close callbacks # just in case if we have transport close callbacks
...@@ -31,12 +29,12 @@ class StreamReaderTests(unittest.TestCase): ...@@ -31,12 +29,12 @@ class StreamReaderTests(unittest.TestCase):
@unittest.mock.patch('asyncio.streams.events') @unittest.mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events): def test_ctor_global_loop(self, m_events):
stream = streams.StreamReader() stream = asyncio.StreamReader()
self.assertIs(stream._loop, m_events.get_event_loop.return_value) self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def test_open_connection(self): def test_open_connection(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
f = streams.open_connection(*httpd.address, loop=self.loop) f = asyncio.open_connection(*httpd.address, loop=self.loop)
reader, writer = self.loop.run_until_complete(f) reader, writer = self.loop.run_until_complete(f)
writer.write(b'GET / HTTP/1.0\r\n\r\n') writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline() f = reader.readline()
...@@ -52,12 +50,12 @@ class StreamReaderTests(unittest.TestCase): ...@@ -52,12 +50,12 @@ class StreamReaderTests(unittest.TestCase):
def test_open_connection_no_loop_ssl(self): def test_open_connection_no_loop_ssl(self):
with test_utils.run_test_server(use_ssl=True) as httpd: with test_utils.run_test_server(use_ssl=True) as httpd:
try: try:
events.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
f = streams.open_connection(*httpd.address, f = asyncio.open_connection(*httpd.address,
ssl=test_utils.dummy_ssl_context()) ssl=test_utils.dummy_ssl_context())
reader, writer = self.loop.run_until_complete(f) reader, writer = self.loop.run_until_complete(f)
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n') writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.read() f = reader.read()
data = self.loop.run_until_complete(f) data = self.loop.run_until_complete(f)
...@@ -67,7 +65,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -67,7 +65,7 @@ class StreamReaderTests(unittest.TestCase):
def test_open_connection_error(self): def test_open_connection_error(self):
with test_utils.run_test_server() as httpd: with test_utils.run_test_server() as httpd:
f = streams.open_connection(*httpd.address, loop=self.loop) f = asyncio.open_connection(*httpd.address, loop=self.loop)
reader, writer = self.loop.run_until_complete(f) reader, writer = self.loop.run_until_complete(f)
writer._protocol.connection_lost(ZeroDivisionError()) writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read() f = reader.read()
...@@ -78,20 +76,20 @@ class StreamReaderTests(unittest.TestCase): ...@@ -78,20 +76,20 @@ class StreamReaderTests(unittest.TestCase):
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
def test_feed_empty_data(self): def test_feed_empty_data(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'') stream.feed_data(b'')
self.assertEqual(0, stream._byte_count) self.assertEqual(0, stream._byte_count)
def test_feed_data_byte_count(self): def test_feed_data_byte_count(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
self.assertEqual(len(self.DATA), stream._byte_count) self.assertEqual(len(self.DATA), stream._byte_count)
def test_read_zero(self): def test_read_zero(self):
# Read zero bytes. # Read zero bytes.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
data = self.loop.run_until_complete(stream.read(0)) data = self.loop.run_until_complete(stream.read(0))
...@@ -100,8 +98,8 @@ class StreamReaderTests(unittest.TestCase): ...@@ -100,8 +98,8 @@ class StreamReaderTests(unittest.TestCase):
def test_read(self): def test_read(self):
# Read bytes. # Read bytes.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
read_task = tasks.Task(stream.read(30), loop=self.loop) read_task = asyncio.Task(stream.read(30), loop=self.loop)
def cb(): def cb():
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
...@@ -113,7 +111,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -113,7 +111,7 @@ class StreamReaderTests(unittest.TestCase):
def test_read_line_breaks(self): def test_read_line_breaks(self):
# Read bytes without line breaks. # Read bytes without line breaks.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'line1') stream.feed_data(b'line1')
stream.feed_data(b'line2') stream.feed_data(b'line2')
...@@ -124,8 +122,8 @@ class StreamReaderTests(unittest.TestCase): ...@@ -124,8 +122,8 @@ class StreamReaderTests(unittest.TestCase):
def test_read_eof(self): def test_read_eof(self):
# Read bytes, stop at eof. # Read bytes, stop at eof.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
read_task = tasks.Task(stream.read(1024), loop=self.loop) read_task = asyncio.Task(stream.read(1024), loop=self.loop)
def cb(): def cb():
stream.feed_eof() stream.feed_eof()
...@@ -137,8 +135,8 @@ class StreamReaderTests(unittest.TestCase): ...@@ -137,8 +135,8 @@ class StreamReaderTests(unittest.TestCase):
def test_read_until_eof(self): def test_read_until_eof(self):
# Read all bytes until eof. # Read all bytes until eof.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
read_task = tasks.Task(stream.read(-1), loop=self.loop) read_task = asyncio.Task(stream.read(-1), loop=self.loop)
def cb(): def cb():
stream.feed_data(b'chunk1\n') stream.feed_data(b'chunk1\n')
...@@ -152,7 +150,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -152,7 +150,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertFalse(stream._byte_count) self.assertFalse(stream._byte_count)
def test_read_exception(self): def test_read_exception(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'line\n') stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.read(2)) data = self.loop.run_until_complete(stream.read(2))
...@@ -164,9 +162,9 @@ class StreamReaderTests(unittest.TestCase): ...@@ -164,9 +162,9 @@ class StreamReaderTests(unittest.TestCase):
def test_readline(self): def test_readline(self):
# Read one line. # Read one line.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'chunk1 ') stream.feed_data(b'chunk1 ')
read_task = tasks.Task(stream.readline(), loop=self.loop) read_task = asyncio.Task(stream.readline(), loop=self.loop)
def cb(): def cb():
stream.feed_data(b'chunk2 ') stream.feed_data(b'chunk2 ')
...@@ -179,7 +177,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -179,7 +177,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
def test_readline_limit_with_existing_data(self): def test_readline_limit_with_existing_data(self):
stream = streams.StreamReader(3, loop=self.loop) stream = asyncio.StreamReader(3, loop=self.loop)
stream.feed_data(b'li') stream.feed_data(b'li')
stream.feed_data(b'ne1\nline2\n') stream.feed_data(b'ne1\nline2\n')
...@@ -187,7 +185,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -187,7 +185,7 @@ class StreamReaderTests(unittest.TestCase):
ValueError, self.loop.run_until_complete, stream.readline()) ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'line2\n'], list(stream._buffer)) self.assertEqual([b'line2\n'], list(stream._buffer))
stream = streams.StreamReader(3, loop=self.loop) stream = asyncio.StreamReader(3, loop=self.loop)
stream.feed_data(b'li') stream.feed_data(b'li')
stream.feed_data(b'ne1') stream.feed_data(b'ne1')
stream.feed_data(b'li') stream.feed_data(b'li')
...@@ -198,7 +196,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -198,7 +196,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(2, stream._byte_count) self.assertEqual(2, stream._byte_count)
def test_readline_limit(self): def test_readline_limit(self):
stream = streams.StreamReader(7, loop=self.loop) stream = asyncio.StreamReader(7, loop=self.loop)
def cb(): def cb():
stream.feed_data(b'chunk1') stream.feed_data(b'chunk1')
...@@ -213,7 +211,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -213,7 +211,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(7, stream._byte_count) self.assertEqual(7, stream._byte_count)
def test_readline_line_byte_count(self): def test_readline_line_byte_count(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[:6])
stream.feed_data(self.DATA[6:]) stream.feed_data(self.DATA[6:])
...@@ -223,7 +221,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -223,7 +221,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
def test_readline_eof(self): def test_readline_eof(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'some data') stream.feed_data(b'some data')
stream.feed_eof() stream.feed_eof()
...@@ -231,14 +229,14 @@ class StreamReaderTests(unittest.TestCase): ...@@ -231,14 +229,14 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(b'some data', line) self.assertEqual(b'some data', line)
def test_readline_empty_eof(self): def test_readline_empty_eof(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_eof() stream.feed_eof()
line = self.loop.run_until_complete(stream.readline()) line = self.loop.run_until_complete(stream.readline())
self.assertEqual(b'', line) self.assertEqual(b'', line)
def test_readline_read_byte_count(self): def test_readline_read_byte_count(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
self.loop.run_until_complete(stream.readline()) self.loop.run_until_complete(stream.readline())
...@@ -251,7 +249,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -251,7 +249,7 @@ class StreamReaderTests(unittest.TestCase):
stream._byte_count) stream._byte_count)
def test_readline_exception(self): def test_readline_exception(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'line\n') stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.readline()) data = self.loop.run_until_complete(stream.readline())
...@@ -263,7 +261,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -263,7 +261,7 @@ class StreamReaderTests(unittest.TestCase):
def test_readexactly_zero_or_less(self): def test_readexactly_zero_or_less(self):
# Read exact number of bytes (zero or less). # Read exact number of bytes (zero or less).
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
data = self.loop.run_until_complete(stream.readexactly(0)) data = self.loop.run_until_complete(stream.readexactly(0))
...@@ -276,10 +274,10 @@ class StreamReaderTests(unittest.TestCase): ...@@ -276,10 +274,10 @@ class StreamReaderTests(unittest.TestCase):
def test_readexactly(self): def test_readexactly(self):
# Read exact number of bytes. # Read exact number of bytes.
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
n = 2 * len(self.DATA) n = 2 * len(self.DATA)
read_task = tasks.Task(stream.readexactly(n), loop=self.loop) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
def cb(): def cb():
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
...@@ -293,21 +291,25 @@ class StreamReaderTests(unittest.TestCase): ...@@ -293,21 +291,25 @@ class StreamReaderTests(unittest.TestCase):
def test_readexactly_eof(self): def test_readexactly_eof(self):
# Read exact number of bytes (eof). # Read exact number of bytes (eof).
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
n = 2 * len(self.DATA) n = 2 * len(self.DATA)
read_task = tasks.Task(stream.readexactly(n), loop=self.loop) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
def cb(): def cb():
stream.feed_data(self.DATA) stream.feed_data(self.DATA)
stream.feed_eof() stream.feed_eof()
self.loop.call_soon(cb) self.loop.call_soon(cb)
data = self.loop.run_until_complete(read_task) with self.assertRaises(asyncio.IncompleteReadError) as cm:
self.assertEqual(self.DATA, data) self.loop.run_until_complete(read_task)
self.assertEqual(cm.exception.partial, self.DATA)
self.assertEqual(cm.exception.expected, n)
self.assertEqual(str(cm.exception),
'18 bytes read on a total of 36 expected bytes')
self.assertFalse(stream._byte_count) self.assertFalse(stream._byte_count)
def test_readexactly_exception(self): def test_readexactly_exception(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'line\n') stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.readexactly(2)) data = self.loop.run_until_complete(stream.readexactly(2))
...@@ -318,7 +320,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -318,7 +320,7 @@ class StreamReaderTests(unittest.TestCase):
ValueError, self.loop.run_until_complete, stream.readexactly(2)) ValueError, self.loop.run_until_complete, stream.readexactly(2))
def test_exception(self): def test_exception(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
self.assertIsNone(stream.exception()) self.assertIsNone(stream.exception())
exc = ValueError() exc = ValueError()
...@@ -326,31 +328,31 @@ class StreamReaderTests(unittest.TestCase): ...@@ -326,31 +328,31 @@ class StreamReaderTests(unittest.TestCase):
self.assertIs(stream.exception(), exc) self.assertIs(stream.exception(), exc)
def test_exception_waiter(self): def test_exception_waiter(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def set_err(): def set_err():
stream.set_exception(ValueError()) stream.set_exception(ValueError())
@tasks.coroutine @asyncio.coroutine
def readline(): def readline():
yield from stream.readline() yield from stream.readline()
t1 = tasks.Task(stream.readline(), loop=self.loop) t1 = asyncio.Task(stream.readline(), loop=self.loop)
t2 = tasks.Task(set_err(), loop=self.loop) t2 = asyncio.Task(set_err(), loop=self.loop)
self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop))
self.assertRaises(ValueError, t1.result) self.assertRaises(ValueError, t1.result)
def test_exception_cancel(self): def test_exception_cancel(self):
stream = streams.StreamReader(loop=self.loop) stream = asyncio.StreamReader(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def read_a_line(): def read_a_line():
yield from stream.readline() yield from stream.readline()
t = tasks.Task(read_a_line(), loop=self.loop) t = asyncio.Task(read_a_line(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
t.cancel() t.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -367,19 +369,19 @@ class StreamReaderTests(unittest.TestCase): ...@@ -367,19 +369,19 @@ class StreamReaderTests(unittest.TestCase):
self.server = None self.server = None
self.loop = loop self.loop = loop
@tasks.coroutine @asyncio.coroutine
def handle_client(self, client_reader, client_writer): def handle_client(self, client_reader, client_writer):
data = yield from client_reader.readline() data = yield from client_reader.readline()
client_writer.write(data) client_writer.write(data)
def start(self): def start(self):
self.server = self.loop.run_until_complete( self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client, asyncio.start_server(self.handle_client,
'127.0.0.1', 12345, '127.0.0.1', 12345,
loop=self.loop)) loop=self.loop))
def handle_client_callback(self, client_reader, client_writer): def handle_client_callback(self, client_reader, client_writer):
task = tasks.Task(client_reader.readline(), loop=self.loop) task = asyncio.Task(client_reader.readline(), loop=self.loop)
def done(task): def done(task):
client_writer.write(task.result()) client_writer.write(task.result())
...@@ -388,7 +390,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -388,7 +390,7 @@ class StreamReaderTests(unittest.TestCase):
def start_callback(self): def start_callback(self):
self.server = self.loop.run_until_complete( self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client_callback, asyncio.start_server(self.handle_client_callback,
'127.0.0.1', 12345, '127.0.0.1', 12345,
loop=self.loop)) loop=self.loop))
...@@ -398,9 +400,9 @@ class StreamReaderTests(unittest.TestCase): ...@@ -398,9 +400,9 @@ class StreamReaderTests(unittest.TestCase):
self.loop.run_until_complete(self.server.wait_closed()) self.loop.run_until_complete(self.server.wait_closed())
self.server = None self.server = None
@tasks.coroutine @asyncio.coroutine
def client(): def client():
reader, writer = yield from streams.open_connection( reader, writer = yield from asyncio.open_connection(
'127.0.0.1', 12345, loop=self.loop) '127.0.0.1', 12345, loop=self.loop)
# send a line # send a line
writer.write(b"hello world!\n") writer.write(b"hello world!\n")
...@@ -412,7 +414,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -412,7 +414,7 @@ class StreamReaderTests(unittest.TestCase):
# test the server variant with a coroutine as client handler # test the server variant with a coroutine as client handler
server = MyServer(self.loop) server = MyServer(self.loop)
server.start() server.start()
msg = self.loop.run_until_complete(tasks.Task(client(), msg = self.loop.run_until_complete(asyncio.Task(client(),
loop=self.loop)) loop=self.loop))
server.stop() server.stop()
self.assertEqual(msg, b"hello world!\n") self.assertEqual(msg, b"hello world!\n")
...@@ -420,7 +422,7 @@ class StreamReaderTests(unittest.TestCase): ...@@ -420,7 +422,7 @@ class StreamReaderTests(unittest.TestCase):
# test the server variant with a callback as client handler # test the server variant with a callback as client handler
server = MyServer(self.loop) server = MyServer(self.loop)
server.start_callback() server.start_callback()
msg = self.loop.run_until_complete(tasks.Task(client(), msg = self.loop.run_until_complete(asyncio.Task(client(),
loop=self.loop)) loop=self.loop))
server.stop() server.stop()
self.assertEqual(msg, b"hello world!\n") self.assertEqual(msg, b"hello world!\n")
......
...@@ -5,9 +5,7 @@ import unittest ...@@ -5,9 +5,7 @@ import unittest
import unittest.mock import unittest.mock
from unittest.mock import Mock from unittest.mock import Mock
from asyncio import events import asyncio
from asyncio import futures
from asyncio import tasks
from asyncio import test_utils from asyncio import test_utils
...@@ -24,115 +22,115 @@ class TaskTests(unittest.TestCase): ...@@ -24,115 +22,115 @@ class TaskTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
gc.collect() gc.collect()
def test_task_class(self): def test_task_class(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
return 'ok' return 'ok'
t = tasks.Task(notmuch(), loop=self.loop) t = asyncio.Task(notmuch(), loop=self.loop)
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertEqual(t.result(), 'ok') self.assertEqual(t.result(), 'ok')
self.assertIs(t._loop, self.loop) self.assertIs(t._loop, self.loop)
loop = events.new_event_loop() loop = asyncio.new_event_loop()
t = tasks.Task(notmuch(), loop=loop) t = asyncio.Task(notmuch(), loop=loop)
self.assertIs(t._loop, loop) self.assertIs(t._loop, loop)
loop.close() loop.close()
def test_async_coroutine(self): def test_async_coroutine(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
return 'ok' return 'ok'
t = tasks.async(notmuch(), loop=self.loop) t = asyncio.async(notmuch(), loop=self.loop)
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertEqual(t.result(), 'ok') self.assertEqual(t.result(), 'ok')
self.assertIs(t._loop, self.loop) self.assertIs(t._loop, self.loop)
loop = events.new_event_loop() loop = asyncio.new_event_loop()
t = tasks.async(notmuch(), loop=loop) t = asyncio.async(notmuch(), loop=loop)
self.assertIs(t._loop, loop) self.assertIs(t._loop, loop)
loop.close() loop.close()
def test_async_future(self): def test_async_future(self):
f_orig = futures.Future(loop=self.loop) f_orig = asyncio.Future(loop=self.loop)
f_orig.set_result('ko') f_orig.set_result('ko')
f = tasks.async(f_orig) f = asyncio.async(f_orig)
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
self.assertTrue(f.done()) self.assertTrue(f.done())
self.assertEqual(f.result(), 'ko') self.assertEqual(f.result(), 'ko')
self.assertIs(f, f_orig) self.assertIs(f, f_orig)
loop = events.new_event_loop() loop = asyncio.new_event_loop()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
f = tasks.async(f_orig, loop=loop) f = asyncio.async(f_orig, loop=loop)
loop.close() loop.close()
f = tasks.async(f_orig, loop=self.loop) f = asyncio.async(f_orig, loop=self.loop)
self.assertIs(f, f_orig) self.assertIs(f, f_orig)
def test_async_task(self): def test_async_task(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
return 'ok' return 'ok'
t_orig = tasks.Task(notmuch(), loop=self.loop) t_orig = asyncio.Task(notmuch(), loop=self.loop)
t = tasks.async(t_orig) t = asyncio.async(t_orig)
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertEqual(t.result(), 'ok') self.assertEqual(t.result(), 'ok')
self.assertIs(t, t_orig) self.assertIs(t, t_orig)
loop = events.new_event_loop() loop = asyncio.new_event_loop()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
t = tasks.async(t_orig, loop=loop) t = asyncio.async(t_orig, loop=loop)
loop.close() loop.close()
t = tasks.async(t_orig, loop=self.loop) t = asyncio.async(t_orig, loop=self.loop)
self.assertIs(t, t_orig) self.assertIs(t, t_orig)
def test_async_neither(self): def test_async_neither(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
tasks.async('ok') asyncio.async('ok')
def test_task_repr(self): def test_task_repr(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
yield from [] yield from []
return 'abc' return 'abc'
t = tasks.Task(notmuch(), loop=self.loop) t = asyncio.Task(notmuch(), loop=self.loop)
t.add_done_callback(Dummy()) t.add_done_callback(Dummy())
self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>') self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>')
t.cancel() # Does not take immediate effect! t.cancel() # Does not take immediate effect!
self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>') self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>')
self.assertRaises(futures.CancelledError, self.assertRaises(asyncio.CancelledError,
self.loop.run_until_complete, t) self.loop.run_until_complete, t)
self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>') self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>')
t = tasks.Task(notmuch(), loop=self.loop) t = asyncio.Task(notmuch(), loop=self.loop)
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>") self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>")
def test_task_repr_custom(self): def test_task_repr_custom(self):
@tasks.coroutine @asyncio.coroutine
def coro(): def coro():
pass pass
class T(futures.Future): class T(asyncio.Future):
def __repr__(self): def __repr__(self):
return 'T[]' return 'T[]'
class MyTask(tasks.Task, T): class MyTask(asyncio.Task, T):
def __repr__(self): def __repr__(self):
return super().__repr__() return super().__repr__()
...@@ -142,17 +140,17 @@ class TaskTests(unittest.TestCase): ...@@ -142,17 +140,17 @@ class TaskTests(unittest.TestCase):
gen.close() gen.close()
def test_task_basics(self): def test_task_basics(self):
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
a = yield from inner1() a = yield from inner1()
b = yield from inner2() b = yield from inner2()
return a+b return a+b
@tasks.coroutine @asyncio.coroutine
def inner1(): def inner1():
return 42 return 42
@tasks.coroutine @asyncio.coroutine
def inner2(): def inner2():
return 1000 return 1000
...@@ -169,66 +167,66 @@ class TaskTests(unittest.TestCase): ...@@ -169,66 +167,66 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield from tasks.sleep(10.0, loop=loop) yield from asyncio.sleep(10.0, loop=loop)
return 12 return 12
t = tasks.Task(task(), loop=loop) t = asyncio.Task(task(), loop=loop)
loop.call_soon(t.cancel) loop.call_soon(t.cancel)
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
loop.run_until_complete(t) loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertTrue(t.cancelled()) self.assertTrue(t.cancelled())
self.assertFalse(t.cancel()) self.assertFalse(t.cancel())
def test_cancel_yield(self): def test_cancel_yield(self):
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield yield
yield yield
return 12 return 12
t = tasks.Task(task(), loop=self.loop) t = asyncio.Task(task(), loop=self.loop)
test_utils.run_briefly(self.loop) # start coro test_utils.run_briefly(self.loop) # start coro
t.cancel() t.cancel()
self.assertRaises( self.assertRaises(
futures.CancelledError, self.loop.run_until_complete, t) asyncio.CancelledError, self.loop.run_until_complete, t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertTrue(t.cancelled()) self.assertTrue(t.cancelled())
self.assertFalse(t.cancel()) self.assertFalse(t.cancel())
def test_cancel_inner_future(self): def test_cancel_inner_future(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield from f yield from f
return 12 return 12
t = tasks.Task(task(), loop=self.loop) t = asyncio.Task(task(), loop=self.loop)
test_utils.run_briefly(self.loop) # start task test_utils.run_briefly(self.loop) # start task
f.cancel() f.cancel()
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertTrue(f.cancelled()) self.assertTrue(f.cancelled())
self.assertTrue(t.cancelled()) self.assertTrue(t.cancelled())
def test_cancel_both_task_and_inner_future(self): def test_cancel_both_task_and_inner_future(self):
f = futures.Future(loop=self.loop) f = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield from f yield from f
return 12 return 12
t = tasks.Task(task(), loop=self.loop) t = asyncio.Task(task(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
f.cancel() f.cancel()
t.cancel() t.cancel()
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(t) self.loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
...@@ -236,18 +234,18 @@ class TaskTests(unittest.TestCase): ...@@ -236,18 +234,18 @@ class TaskTests(unittest.TestCase):
self.assertTrue(t.cancelled()) self.assertTrue(t.cancelled())
def test_cancel_task_catching(self): def test_cancel_task_catching(self):
fut1 = futures.Future(loop=self.loop) fut1 = asyncio.Future(loop=self.loop)
fut2 = futures.Future(loop=self.loop) fut2 = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield from fut1 yield from fut1
try: try:
yield from fut2 yield from fut2
except futures.CancelledError: except asyncio.CancelledError:
return 42 return 42
t = tasks.Task(task(), loop=self.loop) t = asyncio.Task(task(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIs(t._fut_waiter, fut1) # White-box test. self.assertIs(t._fut_waiter, fut1) # White-box test.
fut1.set_result(None) fut1.set_result(None)
...@@ -260,21 +258,21 @@ class TaskTests(unittest.TestCase): ...@@ -260,21 +258,21 @@ class TaskTests(unittest.TestCase):
self.assertFalse(t.cancelled()) self.assertFalse(t.cancelled())
def test_cancel_task_ignoring(self): def test_cancel_task_ignoring(self):
fut1 = futures.Future(loop=self.loop) fut1 = asyncio.Future(loop=self.loop)
fut2 = futures.Future(loop=self.loop) fut2 = asyncio.Future(loop=self.loop)
fut3 = futures.Future(loop=self.loop) fut3 = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
yield from fut1 yield from fut1
try: try:
yield from fut2 yield from fut2
except futures.CancelledError: except asyncio.CancelledError:
pass pass
res = yield from fut3 res = yield from fut3
return res return res
t = tasks.Task(task(), loop=self.loop) t = asyncio.Task(task(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIs(t._fut_waiter, fut1) # White-box test. self.assertIs(t._fut_waiter, fut1) # White-box test.
fut1.set_result(None) fut1.set_result(None)
...@@ -291,20 +289,20 @@ class TaskTests(unittest.TestCase): ...@@ -291,20 +289,20 @@ class TaskTests(unittest.TestCase):
self.assertFalse(t.cancelled()) self.assertFalse(t.cancelled())
def test_cancel_current_task(self): def test_cancel_current_task(self):
loop = events.new_event_loop() loop = asyncio.new_event_loop()
self.addCleanup(loop.close) self.addCleanup(loop.close)
@tasks.coroutine @asyncio.coroutine
def task(): def task():
t.cancel() t.cancel()
self.assertTrue(t._must_cancel) # White-box test. self.assertTrue(t._must_cancel) # White-box test.
# The sleep should be cancelled immediately. # The sleep should be cancelled immediately.
yield from tasks.sleep(100, loop=loop) yield from asyncio.sleep(100, loop=loop)
return 12 return 12
t = tasks.Task(task(), loop=loop) t = asyncio.Task(task(), loop=loop)
self.assertRaises( self.assertRaises(
futures.CancelledError, loop.run_until_complete, t) asyncio.CancelledError, loop.run_until_complete, t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertFalse(t._must_cancel) # White-box test. self.assertFalse(t._must_cancel) # White-box test.
self.assertFalse(t.cancel()) self.assertFalse(t.cancel())
...@@ -326,17 +324,17 @@ class TaskTests(unittest.TestCase): ...@@ -326,17 +324,17 @@ class TaskTests(unittest.TestCase):
x = 0 x = 0
waiters = [] waiters = []
@tasks.coroutine @asyncio.coroutine
def task(): def task():
nonlocal x nonlocal x
while x < 10: while x < 10:
waiters.append(tasks.sleep(0.1, loop=loop)) waiters.append(asyncio.sleep(0.1, loop=loop))
yield from waiters[-1] yield from waiters[-1]
x += 1 x += 1
if x == 2: if x == 2:
loop.stop() loop.stop()
t = tasks.Task(task(), loop=loop) t = asyncio.Task(task(), loop=loop)
self.assertRaises( self.assertRaises(
RuntimeError, loop.run_until_complete, t) RuntimeError, loop.run_until_complete, t)
self.assertFalse(t.done()) self.assertFalse(t.done())
...@@ -361,20 +359,20 @@ class TaskTests(unittest.TestCase): ...@@ -361,20 +359,20 @@ class TaskTests(unittest.TestCase):
foo_running = None foo_running = None
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
nonlocal foo_running nonlocal foo_running
foo_running = True foo_running = True
try: try:
yield from tasks.sleep(0.2, loop=loop) yield from asyncio.sleep(0.2, loop=loop)
finally: finally:
foo_running = False foo_running = False
return 'done' return 'done'
fut = tasks.Task(foo(), loop=loop) fut = asyncio.Task(foo(), loop=loop)
with self.assertRaises(futures.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop))
self.assertTrue(fut.done()) self.assertTrue(fut.done())
# it should have been cancelled due to the timeout # it should have been cancelled due to the timeout
self.assertTrue(fut.cancelled()) self.assertTrue(fut.cancelled())
...@@ -394,18 +392,18 @@ class TaskTests(unittest.TestCase): ...@@ -394,18 +392,18 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
yield from tasks.sleep(0.2, loop=loop) yield from asyncio.sleep(0.2, loop=loop)
return 'done' return 'done'
events.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
fut = tasks.Task(foo(), loop=loop) fut = asyncio.Task(foo(), loop=loop)
with self.assertRaises(futures.TimeoutError): with self.assertRaises(asyncio.TimeoutError):
loop.run_until_complete(tasks.wait_for(fut, 0.01)) loop.run_until_complete(asyncio.wait_for(fut, 0.01))
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
self.assertAlmostEqual(0.01, loop.time()) self.assertAlmostEqual(0.01, loop.time())
self.assertTrue(fut.done()) self.assertTrue(fut.done())
...@@ -423,22 +421,22 @@ class TaskTests(unittest.TestCase): ...@@ -423,22 +421,22 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
done, pending = yield from tasks.wait([b, a], loop=loop) done, pending = yield from asyncio.wait([b, a], loop=loop)
self.assertEqual(done, set([a, b])) self.assertEqual(done, set([a, b]))
self.assertEqual(pending, set()) self.assertEqual(pending, set())
return 42 return 42
res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertEqual(res, 42) self.assertEqual(res, 42)
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
# Doing it again should take no time and exercise a different path. # Doing it again should take no time and exercise a different path.
res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
self.assertEqual(res, 42) self.assertEqual(res, 42)
...@@ -454,33 +452,33 @@ class TaskTests(unittest.TestCase): ...@@ -454,33 +452,33 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop)
b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
done, pending = yield from tasks.wait([b, a]) done, pending = yield from asyncio.wait([b, a])
self.assertEqual(done, set([a, b])) self.assertEqual(done, set([a, b]))
self.assertEqual(pending, set()) self.assertEqual(pending, set())
return 42 return 42
events.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
res = loop.run_until_complete( res = loop.run_until_complete(
tasks.Task(foo(), loop=loop)) asyncio.Task(foo(), loop=loop))
finally: finally:
events.set_event_loop(None) asyncio.set_event_loop(None)
self.assertEqual(res, 42) self.assertEqual(res, 42)
def test_wait_errors(self): def test_wait_errors(self):
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, ValueError, self.loop.run_until_complete,
tasks.wait(set(), loop=self.loop)) asyncio.wait(set(), loop=self.loop))
self.assertRaises( self.assertRaises(
ValueError, self.loop.run_until_complete, ValueError, self.loop.run_until_complete,
tasks.wait([tasks.sleep(10.0, loop=self.loop)], asyncio.wait([asyncio.sleep(10.0, loop=self.loop)],
return_when=-1, loop=self.loop)) return_when=-1, loop=self.loop))
def test_wait_first_completed(self): def test_wait_first_completed(self):
...@@ -495,10 +493,10 @@ class TaskTests(unittest.TestCase): ...@@ -495,10 +493,10 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
task = tasks.Task( task = asyncio.Task(
tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED,
loop=loop), loop=loop),
loop=loop) loop=loop)
...@@ -512,25 +510,25 @@ class TaskTests(unittest.TestCase): ...@@ -512,25 +510,25 @@ class TaskTests(unittest.TestCase):
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_wait_really_done(self): def test_wait_really_done(self):
# there is possibility that some tasks in the pending list # there is possibility that some tasks in the pending list
# became done but their callbacks haven't all been called yet # became done but their callbacks haven't all been called yet
@tasks.coroutine @asyncio.coroutine
def coro1(): def coro1():
yield yield
@tasks.coroutine @asyncio.coroutine
def coro2(): def coro2():
yield yield
yield yield
a = tasks.Task(coro1(), loop=self.loop) a = asyncio.Task(coro1(), loop=self.loop)
b = tasks.Task(coro2(), loop=self.loop) b = asyncio.Task(coro2(), loop=self.loop)
task = tasks.Task( task = asyncio.Task(
tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED,
loop=self.loop), loop=self.loop),
loop=self.loop) loop=self.loop)
...@@ -552,15 +550,15 @@ class TaskTests(unittest.TestCase): ...@@ -552,15 +550,15 @@ class TaskTests(unittest.TestCase):
self.addCleanup(loop.close) self.addCleanup(loop.close)
# first_exception, task already has exception # first_exception, task already has exception
a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def exc(): def exc():
raise ZeroDivisionError('err') raise ZeroDivisionError('err')
b = tasks.Task(exc(), loop=loop) b = asyncio.Task(exc(), loop=loop)
task = tasks.Task( task = asyncio.Task(
tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION,
loop=loop), loop=loop),
loop=loop) loop=loop)
...@@ -571,7 +569,7 @@ class TaskTests(unittest.TestCase): ...@@ -571,7 +569,7 @@ class TaskTests(unittest.TestCase):
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_wait_first_exception_in_wait(self): def test_wait_first_exception_in_wait(self):
...@@ -586,15 +584,15 @@ class TaskTests(unittest.TestCase): ...@@ -586,15 +584,15 @@ class TaskTests(unittest.TestCase):
self.addCleanup(loop.close) self.addCleanup(loop.close)
# first_exception, exception during waiting # first_exception, exception during waiting
a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def exc(): def exc():
yield from tasks.sleep(0.01, loop=loop) yield from asyncio.sleep(0.01, loop=loop)
raise ZeroDivisionError('err') raise ZeroDivisionError('err')
b = tasks.Task(exc(), loop=loop) b = asyncio.Task(exc(), loop=loop)
task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION,
loop=loop) loop=loop)
done, pending = loop.run_until_complete(task) done, pending = loop.run_until_complete(task)
...@@ -604,7 +602,7 @@ class TaskTests(unittest.TestCase): ...@@ -604,7 +602,7 @@ class TaskTests(unittest.TestCase):
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_wait_with_exception(self): def test_wait_with_exception(self):
...@@ -618,27 +616,27 @@ class TaskTests(unittest.TestCase): ...@@ -618,27 +616,27 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def sleeper(): def sleeper():
yield from tasks.sleep(0.15, loop=loop) yield from asyncio.sleep(0.15, loop=loop)
raise ZeroDivisionError('really') raise ZeroDivisionError('really')
b = tasks.Task(sleeper(), loop=loop) b = asyncio.Task(sleeper(), loop=loop)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
done, pending = yield from tasks.wait([b, a], loop=loop) done, pending = yield from asyncio.wait([b, a], loop=loop)
self.assertEqual(len(done), 2) self.assertEqual(len(done), 2)
self.assertEqual(pending, set()) self.assertEqual(pending, set())
errors = set(f for f in done if f.exception() is not None) errors = set(f for f in done if f.exception() is not None)
self.assertEqual(len(errors), 1) self.assertEqual(len(errors), 1)
loop.run_until_complete(tasks.Task(foo(), loop=loop)) loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
loop.run_until_complete(tasks.Task(foo(), loop=loop)) loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
def test_wait_with_timeout(self): def test_wait_with_timeout(self):
...@@ -655,22 +653,22 @@ class TaskTests(unittest.TestCase): ...@@ -655,22 +653,22 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
done, pending = yield from tasks.wait([b, a], timeout=0.11, done, pending = yield from asyncio.wait([b, a], timeout=0.11,
loop=loop) loop=loop)
self.assertEqual(done, set([a])) self.assertEqual(done, set([a]))
self.assertEqual(pending, set([b])) self.assertEqual(pending, set([b]))
loop.run_until_complete(tasks.Task(foo(), loop=loop)) loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.11, loop.time()) self.assertAlmostEqual(0.11, loop.time())
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_wait_concurrent_complete(self): def test_wait_concurrent_complete(self):
...@@ -686,11 +684,11 @@ class TaskTests(unittest.TestCase): ...@@ -686,11 +684,11 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
done, pending = loop.run_until_complete( done, pending = loop.run_until_complete(
tasks.wait([b, a], timeout=0.1, loop=loop)) asyncio.wait([b, a], timeout=0.1, loop=loop))
self.assertEqual(done, set([a])) self.assertEqual(done, set([a]))
self.assertEqual(pending, set([b])) self.assertEqual(pending, set([b]))
...@@ -698,7 +696,7 @@ class TaskTests(unittest.TestCase): ...@@ -698,7 +696,7 @@ class TaskTests(unittest.TestCase):
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_as_completed(self): def test_as_completed(self):
...@@ -713,10 +711,10 @@ class TaskTests(unittest.TestCase): ...@@ -713,10 +711,10 @@ class TaskTests(unittest.TestCase):
completed = set() completed = set()
time_shifted = False time_shifted = False
@tasks.coroutine @asyncio.coroutine
def sleeper(dt, x): def sleeper(dt, x):
nonlocal time_shifted nonlocal time_shifted
yield from tasks.sleep(dt, loop=loop) yield from asyncio.sleep(dt, loop=loop)
completed.add(x) completed.add(x)
if not time_shifted and 'a' in completed and 'b' in completed: if not time_shifted and 'a' in completed and 'b' in completed:
time_shifted = True time_shifted = True
...@@ -727,21 +725,21 @@ class TaskTests(unittest.TestCase): ...@@ -727,21 +725,21 @@ class TaskTests(unittest.TestCase):
b = sleeper(0.01, 'b') b = sleeper(0.01, 'b')
c = sleeper(0.15, 'c') c = sleeper(0.15, 'c')
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
values = [] values = []
for f in tasks.as_completed([b, c, a], loop=loop): for f in asyncio.as_completed([b, c, a], loop=loop):
values.append((yield from f)) values.append((yield from f))
return values return values
res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
self.assertTrue('a' in res[:2]) self.assertTrue('a' in res[:2])
self.assertTrue('b' in res[:2]) self.assertTrue('b' in res[:2])
self.assertEqual(res[2], 'c') self.assertEqual(res[2], 'c')
# Doing it again should take no time and exercise a different path. # Doing it again should take no time and exercise a different path.
res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertAlmostEqual(0.15, loop.time()) self.assertAlmostEqual(0.15, loop.time())
def test_as_completed_with_timeout(self): def test_as_completed_with_timeout(self):
...@@ -760,30 +758,30 @@ class TaskTests(unittest.TestCase): ...@@ -760,30 +758,30 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.sleep(0.1, 'a', loop=loop) a = asyncio.sleep(0.1, 'a', loop=loop)
b = tasks.sleep(0.15, 'b', loop=loop) b = asyncio.sleep(0.15, 'b', loop=loop)
@tasks.coroutine @asyncio.coroutine
def foo(): def foo():
values = [] values = []
for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop):
try: try:
v = yield from f v = yield from f
values.append((1, v)) values.append((1, v))
except futures.TimeoutError as exc: except asyncio.TimeoutError as exc:
values.append((2, exc)) values.append((2, exc))
return values return values
res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop))
self.assertEqual(len(res), 2, res) self.assertEqual(len(res), 2, res)
self.assertEqual(res[0], (1, 'a')) self.assertEqual(res[0], (1, 'a'))
self.assertEqual(res[1][0], 2) self.assertEqual(res[1][0], 2)
self.assertIsInstance(res[1][1], futures.TimeoutError) self.assertIsInstance(res[1][1], asyncio.TimeoutError)
self.assertAlmostEqual(0.12, loop.time()) self.assertAlmostEqual(0.12, loop.time())
# move forward to close generator # move forward to close generator
loop.advance_time(10) loop.advance_time(10)
loop.run_until_complete(tasks.wait([a, b], loop=loop)) loop.run_until_complete(asyncio.wait([a, b], loop=loop))
def test_as_completed_reverse_wait(self): def test_as_completed_reverse_wait(self):
...@@ -795,10 +793,10 @@ class TaskTests(unittest.TestCase): ...@@ -795,10 +793,10 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.sleep(0.05, 'a', loop=loop) a = asyncio.sleep(0.05, 'a', loop=loop)
b = tasks.sleep(0.10, 'b', loop=loop) b = asyncio.sleep(0.10, 'b', loop=loop)
fs = {a, b} fs = {a, b}
futs = list(tasks.as_completed(fs, loop=loop)) futs = list(asyncio.as_completed(fs, loop=loop))
self.assertEqual(len(futs), 2) self.assertEqual(len(futs), 2)
x = loop.run_until_complete(futs[1]) x = loop.run_until_complete(futs[1])
...@@ -821,12 +819,12 @@ class TaskTests(unittest.TestCase): ...@@ -821,12 +819,12 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
a = tasks.sleep(0.05, 'a', loop=loop) a = asyncio.sleep(0.05, 'a', loop=loop)
b = tasks.sleep(0.05, 'b', loop=loop) b = asyncio.sleep(0.05, 'b', loop=loop)
fs = {a, b} fs = {a, b}
futs = list(tasks.as_completed(fs, loop=loop)) futs = list(asyncio.as_completed(fs, loop=loop))
self.assertEqual(len(futs), 2) self.assertEqual(len(futs), 2)
waiter = tasks.wait(futs, loop=loop) waiter = asyncio.wait(futs, loop=loop)
done, pending = loop.run_until_complete(waiter) done, pending = loop.run_until_complete(waiter)
self.assertEqual(set(f.result() for f in done), {'a', 'b'}) self.assertEqual(set(f.result() for f in done), {'a', 'b'})
...@@ -842,13 +840,13 @@ class TaskTests(unittest.TestCase): ...@@ -842,13 +840,13 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
@tasks.coroutine @asyncio.coroutine
def sleeper(dt, arg): def sleeper(dt, arg):
yield from tasks.sleep(dt/2, loop=loop) yield from asyncio.sleep(dt/2, loop=loop)
res = yield from tasks.sleep(dt/2, arg, loop=loop) res = yield from asyncio.sleep(dt/2, arg, loop=loop)
return res return res
t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop)
loop.run_until_complete(t) loop.run_until_complete(t)
self.assertTrue(t.done()) self.assertTrue(t.done())
self.assertEqual(t.result(), 'yeah') self.assertEqual(t.result(), 'yeah')
...@@ -864,7 +862,7 @@ class TaskTests(unittest.TestCase): ...@@ -864,7 +862,7 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop),
loop=loop) loop=loop)
handle = None handle = None
...@@ -898,19 +896,19 @@ class TaskTests(unittest.TestCase): ...@@ -898,19 +896,19 @@ class TaskTests(unittest.TestCase):
sleepfut = None sleepfut = None
@tasks.coroutine @asyncio.coroutine
def sleep(dt): def sleep(dt):
nonlocal sleepfut nonlocal sleepfut
sleepfut = tasks.sleep(dt, loop=loop) sleepfut = asyncio.sleep(dt, loop=loop)
yield from sleepfut yield from sleepfut
@tasks.coroutine @asyncio.coroutine
def doit(): def doit():
sleeper = tasks.Task(sleep(5000), loop=loop) sleeper = asyncio.Task(sleep(5000), loop=loop)
loop.call_later(0.1, sleeper.cancel) loop.call_later(0.1, sleeper.cancel)
try: try:
yield from sleeper yield from sleeper
except futures.CancelledError: except asyncio.CancelledError:
return 'cancelled' return 'cancelled'
else: else:
return 'slept in' return 'slept in'
...@@ -920,37 +918,37 @@ class TaskTests(unittest.TestCase): ...@@ -920,37 +918,37 @@ class TaskTests(unittest.TestCase):
self.assertAlmostEqual(0.1, loop.time()) self.assertAlmostEqual(0.1, loop.time())
def test_task_cancel_waiter_future(self): def test_task_cancel_waiter_future(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def coro(): def coro():
yield from fut yield from fut
task = tasks.Task(coro(), loop=self.loop) task = asyncio.Task(coro(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIs(task._fut_waiter, fut) self.assertIs(task._fut_waiter, fut)
task.cancel() task.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertRaises( self.assertRaises(
futures.CancelledError, self.loop.run_until_complete, task) asyncio.CancelledError, self.loop.run_until_complete, task)
self.assertIsNone(task._fut_waiter) self.assertIsNone(task._fut_waiter)
self.assertTrue(fut.cancelled()) self.assertTrue(fut.cancelled())
def test_step_in_completed_task(self): def test_step_in_completed_task(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
return 'ko' return 'ko'
gen = notmuch() gen = notmuch()
task = tasks.Task(gen, loop=self.loop) task = asyncio.Task(gen, loop=self.loop)
task.set_result('ok') task.set_result('ok')
self.assertRaises(AssertionError, task._step) self.assertRaises(AssertionError, task._step)
gen.close() gen.close()
def test_step_result(self): def test_step_result(self):
@tasks.coroutine @asyncio.coroutine
def notmuch(): def notmuch():
yield None yield None
yield 1 yield 1
...@@ -962,7 +960,7 @@ class TaskTests(unittest.TestCase): ...@@ -962,7 +960,7 @@ class TaskTests(unittest.TestCase):
def test_step_result_future(self): def test_step_result_future(self):
# If coroutine returns future, task waits on this future. # If coroutine returns future, task waits on this future.
class Fut(futures.Future): class Fut(asyncio.Future):
def __init__(self, *args, **kwds): def __init__(self, *args, **kwds):
self.cb_added = False self.cb_added = False
super().__init__(*args, **kwds) super().__init__(*args, **kwds)
...@@ -974,12 +972,12 @@ class TaskTests(unittest.TestCase): ...@@ -974,12 +972,12 @@ class TaskTests(unittest.TestCase):
fut = Fut(loop=self.loop) fut = Fut(loop=self.loop)
result = None result = None
@tasks.coroutine @asyncio.coroutine
def wait_for_future(): def wait_for_future():
nonlocal result nonlocal result
result = yield from fut result = yield from fut
t = tasks.Task(wait_for_future(), loop=self.loop) t = asyncio.Task(wait_for_future(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertTrue(fut.cb_added) self.assertTrue(fut.cb_added)
...@@ -991,11 +989,11 @@ class TaskTests(unittest.TestCase): ...@@ -991,11 +989,11 @@ class TaskTests(unittest.TestCase):
self.assertIsNone(t.result()) self.assertIsNone(t.result())
def test_step_with_baseexception(self): def test_step_with_baseexception(self):
@tasks.coroutine @asyncio.coroutine
def notmutch(): def notmutch():
raise BaseException() raise BaseException()
task = tasks.Task(notmutch(), loop=self.loop) task = asyncio.Task(notmutch(), loop=self.loop)
self.assertRaises(BaseException, task._step) self.assertRaises(BaseException, task._step)
self.assertTrue(task.done()) self.assertTrue(task.done())
...@@ -1011,20 +1009,20 @@ class TaskTests(unittest.TestCase): ...@@ -1011,20 +1009,20 @@ class TaskTests(unittest.TestCase):
loop = test_utils.TestLoop(gen) loop = test_utils.TestLoop(gen)
self.addCleanup(loop.close) self.addCleanup(loop.close)
@tasks.coroutine @asyncio.coroutine
def sleeper(): def sleeper():
yield from tasks.sleep(10, loop=loop) yield from asyncio.sleep(10, loop=loop)
base_exc = BaseException() base_exc = BaseException()
@tasks.coroutine @asyncio.coroutine
def notmutch(): def notmutch():
try: try:
yield from sleeper() yield from sleeper()
except futures.CancelledError: except asyncio.CancelledError:
raise base_exc raise base_exc
task = tasks.Task(notmutch(), loop=loop) task = asyncio.Task(notmutch(), loop=loop)
test_utils.run_briefly(loop) test_utils.run_briefly(loop)
task.cancel() task.cancel()
...@@ -1040,21 +1038,21 @@ class TaskTests(unittest.TestCase): ...@@ -1040,21 +1038,21 @@ class TaskTests(unittest.TestCase):
def fn(): def fn():
pass pass
self.assertFalse(tasks.iscoroutinefunction(fn)) self.assertFalse(asyncio.iscoroutinefunction(fn))
def fn1(): def fn1():
yield yield
self.assertFalse(tasks.iscoroutinefunction(fn1)) self.assertFalse(asyncio.iscoroutinefunction(fn1))
@tasks.coroutine @asyncio.coroutine
def fn2(): def fn2():
yield yield
self.assertTrue(tasks.iscoroutinefunction(fn2)) self.assertTrue(asyncio.iscoroutinefunction(fn2))
def test_yield_vs_yield_from(self): def test_yield_vs_yield_from(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def wait_for_future(): def wait_for_future():
yield fut yield fut
...@@ -1065,11 +1063,11 @@ class TaskTests(unittest.TestCase): ...@@ -1065,11 +1063,11 @@ class TaskTests(unittest.TestCase):
self.assertFalse(fut.done()) self.assertFalse(fut.done())
def test_yield_vs_yield_from_generator(self): def test_yield_vs_yield_from_generator(self):
@tasks.coroutine @asyncio.coroutine
def coro(): def coro():
yield yield
@tasks.coroutine @asyncio.coroutine
def wait_for_future(): def wait_for_future():
gen = coro() gen = coro()
try: try:
...@@ -1083,72 +1081,72 @@ class TaskTests(unittest.TestCase): ...@@ -1083,72 +1081,72 @@ class TaskTests(unittest.TestCase):
self.loop.run_until_complete, task) self.loop.run_until_complete, task)
def test_coroutine_non_gen_function(self): def test_coroutine_non_gen_function(self):
@tasks.coroutine @asyncio.coroutine
def func(): def func():
return 'test' return 'test'
self.assertTrue(tasks.iscoroutinefunction(func)) self.assertTrue(asyncio.iscoroutinefunction(func))
coro = func() coro = func()
self.assertTrue(tasks.iscoroutine(coro)) self.assertTrue(asyncio.iscoroutine(coro))
res = self.loop.run_until_complete(coro) res = self.loop.run_until_complete(coro)
self.assertEqual(res, 'test') self.assertEqual(res, 'test')
def test_coroutine_non_gen_function_return_future(self): def test_coroutine_non_gen_function_return_future(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def func(): def func():
return fut return fut
@tasks.coroutine @asyncio.coroutine
def coro(): def coro():
fut.set_result('test') fut.set_result('test')
t1 = tasks.Task(func(), loop=self.loop) t1 = asyncio.Task(func(), loop=self.loop)
t2 = tasks.Task(coro(), loop=self.loop) t2 = asyncio.Task(coro(), loop=self.loop)
res = self.loop.run_until_complete(t1) res = self.loop.run_until_complete(t1)
self.assertEqual(res, 'test') self.assertEqual(res, 'test')
self.assertIsNone(t2.result()) self.assertIsNone(t2.result())
def test_current_task(self): def test_current_task(self):
self.assertIsNone(tasks.Task.current_task(loop=self.loop)) self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
@tasks.coroutine @asyncio.coroutine
def coro(loop): def coro(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task) self.assertTrue(asyncio.Task.current_task(loop=loop) is task)
task = tasks.Task(coro(self.loop), loop=self.loop) task = asyncio.Task(coro(self.loop), loop=self.loop)
self.loop.run_until_complete(task) self.loop.run_until_complete(task)
self.assertIsNone(tasks.Task.current_task(loop=self.loop)) self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
def test_current_task_with_interleaving_tasks(self): def test_current_task_with_interleaving_tasks(self):
self.assertIsNone(tasks.Task.current_task(loop=self.loop)) self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
fut1 = futures.Future(loop=self.loop) fut1 = asyncio.Future(loop=self.loop)
fut2 = futures.Future(loop=self.loop) fut2 = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def coro1(loop): def coro1(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task1) self.assertTrue(asyncio.Task.current_task(loop=loop) is task1)
yield from fut1 yield from fut1
self.assertTrue(tasks.Task.current_task(loop=loop) is task1) self.assertTrue(asyncio.Task.current_task(loop=loop) is task1)
fut2.set_result(True) fut2.set_result(True)
@tasks.coroutine @asyncio.coroutine
def coro2(loop): def coro2(loop):
self.assertTrue(tasks.Task.current_task(loop=loop) is task2) self.assertTrue(asyncio.Task.current_task(loop=loop) is task2)
fut1.set_result(True) fut1.set_result(True)
yield from fut2 yield from fut2
self.assertTrue(tasks.Task.current_task(loop=loop) is task2) self.assertTrue(asyncio.Task.current_task(loop=loop) is task2)
task1 = tasks.Task(coro1(self.loop), loop=self.loop) task1 = asyncio.Task(coro1(self.loop), loop=self.loop)
task2 = tasks.Task(coro2(self.loop), loop=self.loop) task2 = asyncio.Task(coro2(self.loop), loop=self.loop)
self.loop.run_until_complete(tasks.wait((task1, task2), self.loop.run_until_complete(asyncio.wait((task1, task2),
loop=self.loop)) loop=self.loop))
self.assertIsNone(tasks.Task.current_task(loop=self.loop)) self.assertIsNone(asyncio.Task.current_task(loop=self.loop))
# Some thorough tests for cancellation propagation through # Some thorough tests for cancellation propagation through
# coroutines, tasks and wait(). # coroutines, tasks and wait().
...@@ -1156,30 +1154,30 @@ class TaskTests(unittest.TestCase): ...@@ -1156,30 +1154,30 @@ class TaskTests(unittest.TestCase):
def test_yield_future_passes_cancel(self): def test_yield_future_passes_cancel(self):
# Cancelling outer() cancels inner() cancels waiter. # Cancelling outer() cancels inner() cancels waiter.
proof = 0 proof = 0
waiter = futures.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def inner(): def inner():
nonlocal proof nonlocal proof
try: try:
yield from waiter yield from waiter
except futures.CancelledError: except asyncio.CancelledError:
proof += 1 proof += 1
raise raise
else: else:
self.fail('got past sleep() in inner()') self.fail('got past sleep() in inner()')
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
nonlocal proof nonlocal proof
try: try:
yield from inner() yield from inner()
except futures.CancelledError: except asyncio.CancelledError:
proof += 100 # Expect this path. proof += 100 # Expect this path.
else: else:
proof += 10 proof += 10
f = tasks.async(outer(), loop=self.loop) f = asyncio.async(outer(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
f.cancel() f.cancel()
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
...@@ -1190,39 +1188,39 @@ class TaskTests(unittest.TestCase): ...@@ -1190,39 +1188,39 @@ class TaskTests(unittest.TestCase):
# Cancelling outer() makes wait() return early, leaves inner() # Cancelling outer() makes wait() return early, leaves inner()
# running. # running.
proof = 0 proof = 0
waiter = futures.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def inner(): def inner():
nonlocal proof nonlocal proof
yield from waiter yield from waiter
proof += 1 proof += 1
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
nonlocal proof nonlocal proof
d, p = yield from tasks.wait([inner()], loop=self.loop) d, p = yield from asyncio.wait([inner()], loop=self.loop)
proof += 100 proof += 100
f = tasks.async(outer(), loop=self.loop) f = asyncio.async(outer(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
f.cancel() f.cancel()
self.assertRaises( self.assertRaises(
futures.CancelledError, self.loop.run_until_complete, f) asyncio.CancelledError, self.loop.run_until_complete, f)
waiter.set_result(None) waiter.set_result(None)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(proof, 1) self.assertEqual(proof, 1)
def test_shield_result(self): def test_shield_result(self):
inner = futures.Future(loop=self.loop) inner = asyncio.Future(loop=self.loop)
outer = tasks.shield(inner) outer = asyncio.shield(inner)
inner.set_result(42) inner.set_result(42)
res = self.loop.run_until_complete(outer) res = self.loop.run_until_complete(outer)
self.assertEqual(res, 42) self.assertEqual(res, 42)
def test_shield_exception(self): def test_shield_exception(self):
inner = futures.Future(loop=self.loop) inner = asyncio.Future(loop=self.loop)
outer = tasks.shield(inner) outer = asyncio.shield(inner)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
exc = RuntimeError('expected') exc = RuntimeError('expected')
inner.set_exception(exc) inner.set_exception(exc)
...@@ -1230,50 +1228,50 @@ class TaskTests(unittest.TestCase): ...@@ -1230,50 +1228,50 @@ class TaskTests(unittest.TestCase):
self.assertIs(outer.exception(), exc) self.assertIs(outer.exception(), exc)
def test_shield_cancel(self): def test_shield_cancel(self):
inner = futures.Future(loop=self.loop) inner = asyncio.Future(loop=self.loop)
outer = tasks.shield(inner) outer = asyncio.shield(inner)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
inner.cancel() inner.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertTrue(outer.cancelled()) self.assertTrue(outer.cancelled())
def test_shield_shortcut(self): def test_shield_shortcut(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
fut.set_result(42) fut.set_result(42)
res = self.loop.run_until_complete(tasks.shield(fut)) res = self.loop.run_until_complete(asyncio.shield(fut))
self.assertEqual(res, 42) self.assertEqual(res, 42)
def test_shield_effect(self): def test_shield_effect(self):
# Cancelling outer() does not affect inner(). # Cancelling outer() does not affect inner().
proof = 0 proof = 0
waiter = futures.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
@tasks.coroutine @asyncio.coroutine
def inner(): def inner():
nonlocal proof nonlocal proof
yield from waiter yield from waiter
proof += 1 proof += 1
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
nonlocal proof nonlocal proof
yield from tasks.shield(inner(), loop=self.loop) yield from asyncio.shield(inner(), loop=self.loop)
proof += 100 proof += 100
f = tasks.async(outer(), loop=self.loop) f = asyncio.async(outer(), loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
f.cancel() f.cancel()
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
waiter.set_result(None) waiter.set_result(None)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertEqual(proof, 1) self.assertEqual(proof, 1)
def test_shield_gather(self): def test_shield_gather(self):
child1 = futures.Future(loop=self.loop) child1 = asyncio.Future(loop=self.loop)
child2 = futures.Future(loop=self.loop) child2 = asyncio.Future(loop=self.loop)
parent = tasks.gather(child1, child2, loop=self.loop) parent = asyncio.gather(child1, child2, loop=self.loop)
outer = tasks.shield(parent, loop=self.loop) outer = asyncio.shield(parent, loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
outer.cancel() outer.cancel()
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -1284,16 +1282,16 @@ class TaskTests(unittest.TestCase): ...@@ -1284,16 +1282,16 @@ class TaskTests(unittest.TestCase):
self.assertEqual(parent.result(), [1, 2]) self.assertEqual(parent.result(), [1, 2])
def test_gather_shield(self): def test_gather_shield(self):
child1 = futures.Future(loop=self.loop) child1 = asyncio.Future(loop=self.loop)
child2 = futures.Future(loop=self.loop) child2 = asyncio.Future(loop=self.loop)
inner1 = tasks.shield(child1, loop=self.loop) inner1 = asyncio.shield(child1, loop=self.loop)
inner2 = tasks.shield(child2, loop=self.loop) inner2 = asyncio.shield(child2, loop=self.loop)
parent = tasks.gather(inner1, inner2, loop=self.loop) parent = asyncio.gather(inner1, inner2, loop=self.loop)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
parent.cancel() parent.cancel()
# This should cancel inner1 and inner2 but bot child1 and child2. # This should cancel inner1 and inner2 but bot child1 and child2.
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.assertIsInstance(parent.exception(), futures.CancelledError) self.assertIsInstance(parent.exception(), asyncio.CancelledError)
self.assertTrue(inner1.cancelled()) self.assertTrue(inner1.cancelled())
self.assertTrue(inner2.cancelled()) self.assertTrue(inner2.cancelled())
child1.set_result(1) child1.set_result(1)
...@@ -1316,8 +1314,8 @@ class GatherTestsBase: ...@@ -1316,8 +1314,8 @@ class GatherTestsBase:
test_utils.run_briefly(loop) test_utils.run_briefly(loop)
def _check_success(self, **kwargs): def _check_success(self, **kwargs):
a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
cb = Mock() cb = Mock()
fut.add_done_callback(cb) fut.add_done_callback(cb)
b.set_result(1) b.set_result(1)
...@@ -1338,8 +1336,8 @@ class GatherTestsBase: ...@@ -1338,8 +1336,8 @@ class GatherTestsBase:
self._check_success(return_exceptions=True) self._check_success(return_exceptions=True)
def test_one_exception(self): def test_one_exception(self):
a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
cb = Mock() cb = Mock()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
...@@ -1356,8 +1354,8 @@ class GatherTestsBase: ...@@ -1356,8 +1354,8 @@ class GatherTestsBase:
e.exception() e.exception()
def test_return_exceptions(self): def test_return_exceptions(self):
a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
fut = tasks.gather(*self.wrap_futures(a, b, c, d), fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
return_exceptions=True) return_exceptions=True)
cb = Mock() cb = Mock()
fut.add_done_callback(cb) fut.add_done_callback(cb)
...@@ -1381,15 +1379,15 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1381,15 +1379,15 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
return futures return futures
def _check_empty_sequence(self, seq_or_iter): def _check_empty_sequence(self, seq_or_iter):
events.set_event_loop(self.one_loop) asyncio.set_event_loop(self.one_loop)
self.addCleanup(events.set_event_loop, None) self.addCleanup(asyncio.set_event_loop, None)
fut = tasks.gather(*seq_or_iter) fut = asyncio.gather(*seq_or_iter)
self.assertIsInstance(fut, futures.Future) self.assertIsInstance(fut, asyncio.Future)
self.assertIs(fut._loop, self.one_loop) self.assertIs(fut._loop, self.one_loop)
self._run_loop(self.one_loop) self._run_loop(self.one_loop)
self.assertTrue(fut.done()) self.assertTrue(fut.done())
self.assertEqual(fut.result(), []) self.assertEqual(fut.result(), [])
fut = tasks.gather(*seq_or_iter, loop=self.other_loop) fut = asyncio.gather(*seq_or_iter, loop=self.other_loop)
self.assertIs(fut._loop, self.other_loop) self.assertIs(fut._loop, self.other_loop)
def test_constructor_empty_sequence(self): def test_constructor_empty_sequence(self):
...@@ -1399,27 +1397,27 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1399,27 +1397,27 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
self._check_empty_sequence(iter("")) self._check_empty_sequence(iter(""))
def test_constructor_heterogenous_futures(self): def test_constructor_heterogenous_futures(self):
fut1 = futures.Future(loop=self.one_loop) fut1 = asyncio.Future(loop=self.one_loop)
fut2 = futures.Future(loop=self.other_loop) fut2 = asyncio.Future(loop=self.other_loop)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tasks.gather(fut1, fut2) asyncio.gather(fut1, fut2)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tasks.gather(fut1, loop=self.other_loop) asyncio.gather(fut1, loop=self.other_loop)
def test_constructor_homogenous_futures(self): def test_constructor_homogenous_futures(self):
children = [futures.Future(loop=self.other_loop) for i in range(3)] children = [asyncio.Future(loop=self.other_loop) for i in range(3)]
fut = tasks.gather(*children) fut = asyncio.gather(*children)
self.assertIs(fut._loop, self.other_loop) self.assertIs(fut._loop, self.other_loop)
self._run_loop(self.other_loop) self._run_loop(self.other_loop)
self.assertFalse(fut.done()) self.assertFalse(fut.done())
fut = tasks.gather(*children, loop=self.other_loop) fut = asyncio.gather(*children, loop=self.other_loop)
self.assertIs(fut._loop, self.other_loop) self.assertIs(fut._loop, self.other_loop)
self._run_loop(self.other_loop) self._run_loop(self.other_loop)
self.assertFalse(fut.done()) self.assertFalse(fut.done())
def test_one_cancellation(self): def test_one_cancellation(self):
a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = tasks.gather(a, b, c, d, e) fut = asyncio.gather(a, b, c, d, e)
cb = Mock() cb = Mock()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
...@@ -1428,7 +1426,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1428,7 +1426,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
self.assertTrue(fut.done()) self.assertTrue(fut.done())
cb.assert_called_once_with(fut) cb.assert_called_once_with(fut)
self.assertFalse(fut.cancelled()) self.assertFalse(fut.cancelled())
self.assertIsInstance(fut.exception(), futures.CancelledError) self.assertIsInstance(fut.exception(), asyncio.CancelledError)
# Does nothing # Does nothing
c.set_result(3) c.set_result(3)
d.cancel() d.cancel()
...@@ -1436,9 +1434,9 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1436,9 +1434,9 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
e.exception() e.exception()
def test_result_exception_one_cancellation(self): def test_result_exception_one_cancellation(self):
a, b, c, d, e, f = [futures.Future(loop=self.one_loop) a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
for i in range(6)] for i in range(6)]
fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
cb = Mock() cb = Mock()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
...@@ -1452,8 +1450,8 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1452,8 +1450,8 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
rte = RuntimeError() rte = RuntimeError()
f.set_exception(rte) f.set_exception(rte)
res = self.one_loop.run_until_complete(fut) res = self.one_loop.run_until_complete(fut)
self.assertIsInstance(res[2], futures.CancelledError) self.assertIsInstance(res[2], asyncio.CancelledError)
self.assertIsInstance(res[4], futures.CancelledError) self.assertIsInstance(res[4], asyncio.CancelledError)
res[2] = res[4] = None res[2] = res[4] = None
self.assertEqual(res, [1, zde, None, 3, None, rte]) self.assertEqual(res, [1, zde, None, 3, None, rte])
cb.assert_called_once_with(fut) cb.assert_called_once_with(fut)
...@@ -1463,34 +1461,34 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1463,34 +1461,34 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
events.set_event_loop(self.one_loop) asyncio.set_event_loop(self.one_loop)
def tearDown(self): def tearDown(self):
events.set_event_loop(None) asyncio.set_event_loop(None)
super().tearDown() super().tearDown()
def wrap_futures(self, *futures): def wrap_futures(self, *futures):
coros = [] coros = []
for fut in futures: for fut in futures:
@tasks.coroutine @asyncio.coroutine
def coro(fut=fut): def coro(fut=fut):
return (yield from fut) return (yield from fut)
coros.append(coro()) coros.append(coro())
return coros return coros
def test_constructor_loop_selection(self): def test_constructor_loop_selection(self):
@tasks.coroutine @asyncio.coroutine
def coro(): def coro():
return 'abc' return 'abc'
gen1 = coro() gen1 = coro()
gen2 = coro() gen2 = coro()
fut = tasks.gather(gen1, gen2) fut = asyncio.gather(gen1, gen2)
self.assertIs(fut._loop, self.one_loop) self.assertIs(fut._loop, self.one_loop)
gen1.close() gen1.close()
gen2.close() gen2.close()
gen3 = coro() gen3 = coro()
gen4 = coro() gen4 = coro()
fut = tasks.gather(gen3, gen4, loop=self.other_loop) fut = asyncio.gather(gen3, gen4, loop=self.other_loop)
self.assertIs(fut._loop, self.other_loop) self.assertIs(fut._loop, self.other_loop)
gen3.close() gen3.close()
gen4.close() gen4.close()
...@@ -1498,29 +1496,29 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1498,29 +1496,29 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
def test_cancellation_broadcast(self): def test_cancellation_broadcast(self):
# Cancelling outer() cancels all children. # Cancelling outer() cancels all children.
proof = 0 proof = 0
waiter = futures.Future(loop=self.one_loop) waiter = asyncio.Future(loop=self.one_loop)
@tasks.coroutine @asyncio.coroutine
def inner(): def inner():
nonlocal proof nonlocal proof
yield from waiter yield from waiter
proof += 1 proof += 1
child1 = tasks.async(inner(), loop=self.one_loop) child1 = asyncio.async(inner(), loop=self.one_loop)
child2 = tasks.async(inner(), loop=self.one_loop) child2 = asyncio.async(inner(), loop=self.one_loop)
gatherer = None gatherer = None
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
nonlocal proof, gatherer nonlocal proof, gatherer
gatherer = tasks.gather(child1, child2, loop=self.one_loop) gatherer = asyncio.gather(child1, child2, loop=self.one_loop)
yield from gatherer yield from gatherer
proof += 100 proof += 100
f = tasks.async(outer(), loop=self.one_loop) f = asyncio.async(outer(), loop=self.one_loop)
test_utils.run_briefly(self.one_loop) test_utils.run_briefly(self.one_loop)
self.assertTrue(f.cancel()) self.assertTrue(f.cancel())
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
self.one_loop.run_until_complete(f) self.one_loop.run_until_complete(f)
self.assertFalse(gatherer.cancel()) self.assertFalse(gatherer.cancel())
self.assertTrue(waiter.cancelled()) self.assertTrue(waiter.cancelled())
...@@ -1532,19 +1530,19 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1532,19 +1530,19 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
def test_exception_marking(self): def test_exception_marking(self):
# Test for the first line marked "Mark exception retrieved." # Test for the first line marked "Mark exception retrieved."
@tasks.coroutine @asyncio.coroutine
def inner(f): def inner(f):
yield from f yield from f
raise RuntimeError('should not be ignored') raise RuntimeError('should not be ignored')
a = futures.Future(loop=self.one_loop) a = asyncio.Future(loop=self.one_loop)
b = futures.Future(loop=self.one_loop) b = asyncio.Future(loop=self.one_loop)
@tasks.coroutine @asyncio.coroutine
def outer(): def outer():
yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop)
f = tasks.async(outer(), loop=self.one_loop) f = asyncio.async(outer(), loop=self.one_loop)
test_utils.run_briefly(self.one_loop) test_utils.run_briefly(self.one_loop)
a.set_result(None) a.set_result(None)
test_utils.run_briefly(self.one_loop) test_utils.run_briefly(self.one_loop)
......
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
import unittest import unittest
import unittest.mock import unittest.mock
from asyncio import transports import asyncio
class TransportTests(unittest.TestCase): class TransportTests(unittest.TestCase):
def test_ctor_extra_is_none(self): def test_ctor_extra_is_none(self):
transport = transports.Transport() transport = asyncio.Transport()
self.assertEqual(transport._extra, {}) self.assertEqual(transport._extra, {})
def test_get_extra_info(self): def test_get_extra_info(self):
transport = transports.Transport({'extra': 'info'}) transport = asyncio.Transport({'extra': 'info'})
self.assertEqual('info', transport.get_extra_info('extra')) self.assertEqual('info', transport.get_extra_info('extra'))
self.assertIsNone(transport.get_extra_info('unknown')) self.assertIsNone(transport.get_extra_info('unknown'))
...@@ -21,7 +21,7 @@ class TransportTests(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TransportTests(unittest.TestCase):
self.assertIs(default, transport.get_extra_info('unknown', default)) self.assertIs(default, transport.get_extra_info('unknown', default))
def test_writelines(self): def test_writelines(self):
transport = transports.Transport() transport = asyncio.Transport()
transport.write = unittest.mock.Mock() transport.write = unittest.mock.Mock()
transport.writelines([b'line1', transport.writelines([b'line1',
...@@ -31,7 +31,7 @@ class TransportTests(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TransportTests(unittest.TestCase):
transport.write.assert_called_with(b'line1line2line3') transport.write.assert_called_with(b'line1line2line3')
def test_not_implemented(self): def test_not_implemented(self):
transport = transports.Transport() transport = asyncio.Transport()
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
transport.set_write_buffer_limits) transport.set_write_buffer_limits)
...@@ -45,13 +45,13 @@ class TransportTests(unittest.TestCase): ...@@ -45,13 +45,13 @@ class TransportTests(unittest.TestCase):
self.assertRaises(NotImplementedError, transport.abort) self.assertRaises(NotImplementedError, transport.abort)
def test_dgram_not_implemented(self): def test_dgram_not_implemented(self):
transport = transports.DatagramTransport() transport = asyncio.DatagramTransport()
self.assertRaises(NotImplementedError, transport.sendto, 'data') self.assertRaises(NotImplementedError, transport.sendto, 'data')
self.assertRaises(NotImplementedError, transport.abort) self.assertRaises(NotImplementedError, transport.abort)
def test_subprocess_transport_not_implemented(self): def test_subprocess_transport_not_implemented(self):
transport = transports.SubprocessTransport() transport = asyncio.SubprocessTransport()
self.assertRaises(NotImplementedError, transport.get_pid) self.assertRaises(NotImplementedError, transport.get_pid)
self.assertRaises(NotImplementedError, transport.get_returncode) self.assertRaises(NotImplementedError, transport.get_returncode)
......
...@@ -17,9 +17,8 @@ if sys.platform == 'win32': ...@@ -17,9 +17,8 @@ if sys.platform == 'win32':
raise unittest.SkipTest('UNIX only') raise unittest.SkipTest('UNIX only')
from asyncio import events import asyncio
from asyncio import futures from asyncio import log
from asyncio import protocols
from asyncio import test_utils from asyncio import test_utils
from asyncio import unix_events from asyncio import unix_events
...@@ -28,8 +27,8 @@ from asyncio import unix_events ...@@ -28,8 +27,8 @@ from asyncio import unix_events
class SelectorEventLoopTests(unittest.TestCase): class SelectorEventLoopTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = unix_events.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
events.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
...@@ -44,7 +43,7 @@ class SelectorEventLoopTests(unittest.TestCase): ...@@ -44,7 +43,7 @@ class SelectorEventLoopTests(unittest.TestCase):
self.loop._handle_signal(signal.NSIG + 1, ()) self.loop._handle_signal(signal.NSIG + 1, ())
def test_handle_signal_cancelled_handler(self): def test_handle_signal_cancelled_handler(self):
h = events.Handle(unittest.mock.Mock(), ()) h = asyncio.Handle(unittest.mock.Mock(), ())
h.cancel() h.cancel()
self.loop._signal_handlers[signal.NSIG + 1] = h self.loop._signal_handlers[signal.NSIG + 1] = h
self.loop.remove_signal_handler = unittest.mock.Mock() self.loop.remove_signal_handler = unittest.mock.Mock()
...@@ -68,7 +67,7 @@ class SelectorEventLoopTests(unittest.TestCase): ...@@ -68,7 +67,7 @@ class SelectorEventLoopTests(unittest.TestCase):
cb = lambda: True cb = lambda: True
self.loop.add_signal_handler(signal.SIGHUP, cb) self.loop.add_signal_handler(signal.SIGHUP, cb)
h = self.loop._signal_handlers.get(signal.SIGHUP) h = self.loop._signal_handlers.get(signal.SIGHUP)
self.assertIsInstance(h, events.Handle) self.assertIsInstance(h, asyncio.Handle)
self.assertEqual(h._callback, cb) self.assertEqual(h._callback, cb)
@unittest.mock.patch('asyncio.unix_events.signal') @unittest.mock.patch('asyncio.unix_events.signal')
...@@ -205,7 +204,7 @@ class UnixReadPipeTransportTests(unittest.TestCase): ...@@ -205,7 +204,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(protocols.Protocol) self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
...@@ -228,7 +227,7 @@ class UnixReadPipeTransportTests(unittest.TestCase): ...@@ -228,7 +227,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
self.protocol.connection_made.assert_called_with(tr) self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): def test_ctor_with_waiter(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
unix_events._UnixReadPipeTransport( unix_events._UnixReadPipeTransport(
self.loop, self.pipe, self.protocol, fut) self.loop, self.pipe, self.protocol, fut)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
...@@ -368,7 +367,7 @@ class UnixWritePipeTransportTests(unittest.TestCase): ...@@ -368,7 +367,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
self.pipe.fileno.return_value = 5 self.pipe.fileno.return_value = 5
...@@ -391,7 +390,7 @@ class UnixWritePipeTransportTests(unittest.TestCase): ...@@ -391,7 +390,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
self.protocol.connection_made.assert_called_with(tr) self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): def test_ctor_with_waiter(self):
fut = futures.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
tr = unix_events._UnixWritePipeTransport( tr = unix_events._UnixWritePipeTransport(
self.loop, self.pipe, self.protocol, fut) self.loop, self.pipe, self.protocol, fut)
self.loop.assert_reader(5, tr._read_ready) self.loop.assert_reader(5, tr._read_ready)
...@@ -682,7 +681,7 @@ class AbstractChildWatcherTests(unittest.TestCase): ...@@ -682,7 +681,7 @@ class AbstractChildWatcherTests(unittest.TestCase):
def test_not_implemented(self): def test_not_implemented(self):
f = unittest.mock.Mock() f = unittest.mock.Mock()
watcher = unix_events.AbstractChildWatcher() watcher = asyncio.AbstractChildWatcher()
self.assertRaises( self.assertRaises(
NotImplementedError, watcher.add_child_handler, f, f) NotImplementedError, watcher.add_child_handler, f, f)
self.assertRaises( self.assertRaises(
...@@ -717,7 +716,7 @@ WaitPidMocks = collections.namedtuple("WaitPidMocks", ...@@ -717,7 +716,7 @@ WaitPidMocks = collections.namedtuple("WaitPidMocks",
class ChildWatcherTestsMixin: class ChildWatcherTestsMixin:
ignore_warnings = unittest.mock.patch.object(unix_events.logger, "warning") ignore_warnings = unittest.mock.patch.object(log.logger, "warning")
def setUp(self): def setUp(self):
self.loop = test_utils.TestLoop() self.loop = test_utils.TestLoop()
...@@ -730,7 +729,7 @@ class ChildWatcherTestsMixin: ...@@ -730,7 +729,7 @@ class ChildWatcherTestsMixin:
self.watcher.attach_loop(self.loop) self.watcher.attach_loop(self.loop)
def waitpid(self, pid, flags): def waitpid(self, pid, flags):
if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1: if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1:
self.assertGreater(pid, 0) self.assertGreater(pid, 0)
try: try:
if pid < 0: if pid < 0:
...@@ -1205,7 +1204,7 @@ class ChildWatcherTestsMixin: ...@@ -1205,7 +1204,7 @@ class ChildWatcherTestsMixin:
# raise an exception # raise an exception
m.waitpid.side_effect = ValueError m.waitpid.side_effect = ValueError
with unittest.mock.patch.object(unix_events.logger, with unittest.mock.patch.object(log.logger,
"exception") as m_exception: "exception") as m_exception:
self.assertEqual(self.watcher._sig_chld(), None) self.assertEqual(self.watcher._sig_chld(), None)
...@@ -1240,7 +1239,7 @@ class ChildWatcherTestsMixin: ...@@ -1240,7 +1239,7 @@ class ChildWatcherTestsMixin:
self.watcher._sig_chld() self.watcher._sig_chld()
callback.assert_called(m.waitpid) callback.assert_called(m.waitpid)
if isinstance(self.watcher, unix_events.FastChildWatcher): if isinstance(self.watcher, asyncio.FastChildWatcher):
# here the FastChildWatche enters a deadlock # here the FastChildWatche enters a deadlock
# (there is no way to prevent it) # (there is no way to prevent it)
self.assertFalse(callback.called) self.assertFalse(callback.called)
...@@ -1380,7 +1379,7 @@ class ChildWatcherTestsMixin: ...@@ -1380,7 +1379,7 @@ class ChildWatcherTestsMixin:
self.watcher.add_child_handler(64, callback1) self.watcher.add_child_handler(64, callback1)
self.assertEqual(len(self.watcher._callbacks), 1) self.assertEqual(len(self.watcher._callbacks), 1)
if isinstance(self.watcher, unix_events.FastChildWatcher): if isinstance(self.watcher, asyncio.FastChildWatcher):
self.assertEqual(len(self.watcher._zombies), 1) self.assertEqual(len(self.watcher._zombies), 1)
with unittest.mock.patch.object( with unittest.mock.patch.object(
...@@ -1392,31 +1391,31 @@ class ChildWatcherTestsMixin: ...@@ -1392,31 +1391,31 @@ class ChildWatcherTestsMixin:
m_remove_signal_handler.assert_called_once_with( m_remove_signal_handler.assert_called_once_with(
signal.SIGCHLD) signal.SIGCHLD)
self.assertFalse(self.watcher._callbacks) self.assertFalse(self.watcher._callbacks)
if isinstance(self.watcher, unix_events.FastChildWatcher): if isinstance(self.watcher, asyncio.FastChildWatcher):
self.assertFalse(self.watcher._zombies) self.assertFalse(self.watcher._zombies)
class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
def create_watcher(self): def create_watcher(self):
return unix_events.SafeChildWatcher() return asyncio.SafeChildWatcher()
class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
def create_watcher(self): def create_watcher(self):
return unix_events.FastChildWatcher() return asyncio.FastChildWatcher()
class PolicyTests(unittest.TestCase): class PolicyTests(unittest.TestCase):
def create_policy(self): def create_policy(self):
return unix_events.DefaultEventLoopPolicy() return asyncio.DefaultEventLoopPolicy()
def test_get_child_watcher(self): def test_get_child_watcher(self):
policy = self.create_policy() policy = self.create_policy()
self.assertIsNone(policy._watcher) self.assertIsNone(policy._watcher)
watcher = policy.get_child_watcher() watcher = policy.get_child_watcher()
self.assertIsInstance(watcher, unix_events.SafeChildWatcher) self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
self.assertIs(policy._watcher, watcher) self.assertIs(policy._watcher, watcher)
...@@ -1425,7 +1424,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1425,7 +1424,7 @@ class PolicyTests(unittest.TestCase):
def test_get_child_watcher_after_set(self): def test_get_child_watcher_after_set(self):
policy = self.create_policy() policy = self.create_policy()
watcher = unix_events.FastChildWatcher() watcher = asyncio.FastChildWatcher()
policy.set_child_watcher(watcher) policy.set_child_watcher(watcher)
self.assertIs(policy._watcher, watcher) self.assertIs(policy._watcher, watcher)
...@@ -1438,7 +1437,7 @@ class PolicyTests(unittest.TestCase): ...@@ -1438,7 +1437,7 @@ class PolicyTests(unittest.TestCase):
self.assertIsNone(policy._watcher) self.assertIsNone(policy._watcher)
watcher = policy.get_child_watcher() watcher = policy.get_child_watcher()
self.assertIsInstance(watcher, unix_events.SafeChildWatcher) self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
self.assertIs(watcher._loop, loop) self.assertIs(watcher._loop, loop)
loop.close() loop.close()
...@@ -1449,10 +1448,10 @@ class PolicyTests(unittest.TestCase): ...@@ -1449,10 +1448,10 @@ class PolicyTests(unittest.TestCase):
policy.set_event_loop(policy.new_event_loop()) policy.set_event_loop(policy.new_event_loop())
self.assertIsInstance(policy.get_event_loop(), self.assertIsInstance(policy.get_event_loop(),
events.AbstractEventLoop) asyncio.AbstractEventLoop)
watcher = policy.get_child_watcher() watcher = policy.get_child_watcher()
self.assertIsInstance(watcher, unix_events.SafeChildWatcher) self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
self.assertIsNone(watcher._loop) self.assertIsNone(watcher._loop)
policy.get_event_loop().close() policy.get_event_loop().close()
......
...@@ -8,17 +8,12 @@ if sys.platform != 'win32': ...@@ -8,17 +8,12 @@ if sys.platform != 'win32':
import _winapi import _winapi
import asyncio import asyncio
from asyncio import windows_events
from asyncio import futures
from asyncio import protocols
from asyncio import streams
from asyncio import transports
from asyncio import test_utils from asyncio import test_utils
from asyncio import _overlapped from asyncio import _overlapped
from asyncio import windows_events
class UpperProto(protocols.Protocol): class UpperProto(asyncio.Protocol):
def __init__(self): def __init__(self):
self.buf = [] self.buf = []
...@@ -35,7 +30,7 @@ class UpperProto(protocols.Protocol): ...@@ -35,7 +30,7 @@ class UpperProto(protocols.Protocol):
class ProactorTests(unittest.TestCase): class ProactorTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = windows_events.ProactorEventLoop() self.loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
def tearDown(self): def tearDown(self):
...@@ -44,7 +39,7 @@ class ProactorTests(unittest.TestCase): ...@@ -44,7 +39,7 @@ class ProactorTests(unittest.TestCase):
def test_close(self): def test_close(self):
a, b = self.loop._socketpair() a, b = self.loop._socketpair()
trans = self.loop._make_socket_transport(a, protocols.Protocol()) trans = self.loop._make_socket_transport(a, asyncio.Protocol())
f = asyncio.async(self.loop.sock_recv(b, 100)) f = asyncio.async(self.loop.sock_recv(b, 100))
trans.close() trans.close()
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
...@@ -67,7 +62,7 @@ class ProactorTests(unittest.TestCase): ...@@ -67,7 +62,7 @@ class ProactorTests(unittest.TestCase):
with self.assertRaises(FileNotFoundError): with self.assertRaises(FileNotFoundError):
yield from self.loop.create_pipe_connection( yield from self.loop.create_pipe_connection(
protocols.Protocol, ADDRESS) asyncio.Protocol, ADDRESS)
[server] = yield from self.loop.start_serving_pipe( [server] = yield from self.loop.start_serving_pipe(
UpperProto, ADDRESS) UpperProto, ADDRESS)
...@@ -75,11 +70,11 @@ class ProactorTests(unittest.TestCase): ...@@ -75,11 +70,11 @@ class ProactorTests(unittest.TestCase):
clients = [] clients = []
for i in range(5): for i in range(5):
stream_reader = streams.StreamReader(loop=self.loop) stream_reader = asyncio.StreamReader(loop=self.loop)
protocol = streams.StreamReaderProtocol(stream_reader) protocol = asyncio.StreamReaderProtocol(stream_reader)
trans, proto = yield from self.loop.create_pipe_connection( trans, proto = yield from self.loop.create_pipe_connection(
lambda: protocol, ADDRESS) lambda: protocol, ADDRESS)
self.assertIsInstance(trans, transports.Transport) self.assertIsInstance(trans, asyncio.Transport)
self.assertEqual(protocol, proto) self.assertEqual(protocol, proto)
clients.append((stream_reader, trans)) clients.append((stream_reader, trans))
...@@ -95,7 +90,7 @@ class ProactorTests(unittest.TestCase): ...@@ -95,7 +90,7 @@ class ProactorTests(unittest.TestCase):
with self.assertRaises(FileNotFoundError): with self.assertRaises(FileNotFoundError):
yield from self.loop.create_pipe_connection( yield from self.loop.create_pipe_connection(
protocols.Protocol, ADDRESS) asyncio.Protocol, ADDRESS)
return 'done' return 'done'
...@@ -130,7 +125,7 @@ class ProactorTests(unittest.TestCase): ...@@ -130,7 +125,7 @@ class ProactorTests(unittest.TestCase):
f = self.loop._proactor.wait_for_handle(event, 10) f = self.loop._proactor.wait_for_handle(event, 10)
f.cancel() f.cancel()
start = self.loop.time() start = self.loop.time()
with self.assertRaises(futures.CancelledError): with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(f) self.loop.run_until_complete(f)
elapsed = self.loop.time() - start elapsed = self.loop.time() - start
self.assertTrue(0 <= elapsed < 0.1, elapsed) self.assertTrue(0 <= elapsed < 0.1, elapsed)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment