Commit 620279b9 authored by Yury Selivanov's avatar Yury Selivanov

asyncio: ensure_future() now understands awaitables

parent e2382c59
...@@ -512,7 +512,7 @@ def async(coro_or_future, *, loop=None): ...@@ -512,7 +512,7 @@ def async(coro_or_future, *, loop=None):
def ensure_future(coro_or_future, *, loop=None): def ensure_future(coro_or_future, *, loop=None):
"""Wrap a coroutine in a future. """Wrap a coroutine or an awaitable in a future.
If the argument is a Future, it is returned directly. If the argument is a Future, it is returned directly.
""" """
...@@ -527,8 +527,20 @@ def ensure_future(coro_or_future, *, loop=None): ...@@ -527,8 +527,20 @@ def ensure_future(coro_or_future, *, loop=None):
if task._source_traceback: if task._source_traceback:
del task._source_traceback[-1] del task._source_traceback[-1]
return task return task
elif compat.PY35 and inspect.isawaitable(coro_or_future):
return ensure_future(_wrap_awaitable(coro_or_future), loop=loop)
else: else:
raise TypeError('A Future or coroutine is required') raise TypeError('A Future, a coroutine or an awaitable is required')
@coroutine
def _wrap_awaitable(awaitable):
"""Helper for asyncio.ensure_future().
Wraps awaitable (an object with __await__) into a coroutine
that will later be wrapped in a Task by ensure_future().
"""
return (yield from awaitable.__await__())
class _GatheringFuture(futures.Future): class _GatheringFuture(futures.Future):
......
...@@ -153,6 +153,24 @@ class TaskTests(test_utils.TestCase): ...@@ -153,6 +153,24 @@ class TaskTests(test_utils.TestCase):
t = asyncio.ensure_future(t_orig, loop=self.loop) t = asyncio.ensure_future(t_orig, loop=self.loop)
self.assertIs(t, t_orig) self.assertIs(t, t_orig)
@unittest.skipUnless(PY35, 'need python 3.5 or later')
def test_ensure_future_awaitable(self):
class Aw:
def __init__(self, coro):
self.coro = coro
def __await__(self):
return (yield from self.coro)
@asyncio.coroutine
def coro():
return 'ok'
loop = asyncio.new_event_loop()
self.set_event_loop(loop)
fut = asyncio.ensure_future(Aw(coro()), loop=loop)
loop.run_until_complete(fut)
assert fut.result() == 'ok'
def test_ensure_future_neither(self): def test_ensure_future_neither(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
asyncio.ensure_future('ok') asyncio.ensure_future('ok')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment