Commit 00e33723 authored by Yury Selivanov's avatar Yury Selivanov

Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper implementation.

parent 66f8828b
"""Tests support for new syntax introduced by PEP 492."""
import collections.abc
import types
import unittest
from test import support
......@@ -164,5 +165,23 @@ class CoroutineTests(BaseTest):
self.loop.run_until_complete(start())
def test_types_coroutine(self):
def gen():
yield from ()
return 'spam'
@types.coroutine
def func():
return gen()
async def coro():
wrapper = func()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
return await wrapper
data = self.loop.run_until_complete(coro())
self.assertEqual(data, 'spam')
if __name__ == '__main__':
unittest.main()
......@@ -7,7 +7,8 @@ import pickle
import locale
import sys
import types
import unittest
import unittest.mock
import weakref
class TypesTests(unittest.TestCase):
......@@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase):
class CoroutineTests(unittest.TestCase):
def test_wrong_args(self):
class Foo:
def __call__(self):
pass
def bar(): pass
samples = [None, 1, object()]
for sample in samples:
with self.assertRaisesRegex(TypeError,
'types.coroutine.*expects a callable'):
types.coroutine(sample)
def test_wrong_func(self):
def test_non_gen_values(self):
@types.coroutine
def foo():
return 'spam'
self.assertEqual(foo(), 'spam')
class Awaitable:
def __await__(self):
return ()
aw = Awaitable()
@types.coroutine
def foo():
return aw
self.assertIs(aw, foo())
def test_async_def(self):
# Test that types.coroutine passes 'async def' coroutines
# without modification
......@@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase):
def send(self): pass
def throw(self): pass
def close(self): pass
def __iter__(self): return self
def __iter__(self): pass
def __next__(self): pass
gen = GenLike()
# Setup generator mock object
gen = unittest.mock.MagicMock(GenLike)
gen.__iter__ = lambda gen: gen
gen.__name__ = 'gen'
gen.__qualname__ = 'test.gen'
self.assertIsInstance(gen, collections.abc.Generator)
self.assertIs(gen, iter(gen))
@types.coroutine
def foo():
return gen
self.assertIs(foo().__await__(), gen)
self.assertTrue(isinstance(foo(), collections.abc.Coroutine))
with self.assertRaises(AttributeError):
foo().gi_code
def foo(): return gen
wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
self.assertIs(wrapper.__await__(), wrapper)
# Wrapper proxies duck generators completely:
self.assertIs(iter(wrapper), wrapper)
self.assertIsInstance(wrapper, collections.abc.Coroutine)
self.assertIsInstance(wrapper, collections.abc.Awaitable)
self.assertIs(wrapper.__qualname__, gen.__qualname__)
self.assertIs(wrapper.__name__, gen.__name__)
# Test AttributeErrors
for name in {'gi_running', 'gi_frame', 'gi_code',
'cr_running', 'cr_frame', 'cr_code'}:
with self.assertRaises(AttributeError):
getattr(wrapper, name)
# Test attributes pass-through
gen.gi_running = object()
gen.gi_frame = object()
gen.gi_code = object()
self.assertIs(wrapper.gi_running, gen.gi_running)
self.assertIs(wrapper.gi_frame, gen.gi_frame)
self.assertIs(wrapper.gi_code, gen.gi_code)
self.assertIs(wrapper.cr_running, gen.gi_running)
self.assertIs(wrapper.cr_frame, gen.gi_frame)
self.assertIs(wrapper.cr_code, gen.gi_code)
wrapper.close()
gen.close.assert_called_once_with()
wrapper.send(1)
gen.send.assert_called_once_with(1)
wrapper.throw(1, 2, 3)
gen.throw.assert_called_once_with(1, 2, 3)
gen.reset_mock()
wrapper.throw(1, 2)
gen.throw.assert_called_once_with(1, 2)
gen.reset_mock()
wrapper.throw(1)
gen.throw.assert_called_once_with(1)
gen.reset_mock()
# Test exceptions propagation
error = Exception()
gen.throw.side_effect = error
try:
wrapper.throw(1)
except Exception as ex:
self.assertIs(ex, error)
else:
self.fail('wrapper did not propagate an exception')
# Test invalid args
gen.reset_mock()
with self.assertRaises(TypeError):
wrapper.throw()
self.assertFalse(gen.throw.called)
with self.assertRaises(TypeError):
wrapper.close(1)
self.assertFalse(gen.close.called)
with self.assertRaises(TypeError):
wrapper.send()
self.assertFalse(gen.send.called)
# Test that we do not double wrap
@types.coroutine
def bar(): return wrapper
self.assertIs(wrapper, bar())
# Test weakrefs support
ref = weakref.ref(wrapper)
self.assertIs(ref(), wrapper)
def test_duck_functional_gen(self):
class Generator:
"""Emulates the following generator (very clumsy):
def gen(fut):
result = yield fut
return result * 2
"""
def __init__(self, fut):
self._i = 0
self._fut = fut
def __iter__(self):
return self
def __next__(self):
return self.send(None)
def send(self, v):
try:
if self._i == 0:
assert v is None
return self._fut
if self._i == 1:
raise StopIteration(v * 2)
if self._i > 1:
raise StopIteration
finally:
self._i += 1
def throw(self, tp, *exc):
self._i = 100
if tp is not GeneratorExit:
raise tp
def close(self):
self.throw(GeneratorExit)
@types.coroutine
def foo(): return Generator('spam')
wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
async def corofunc():
return await foo() + 100
coro = corofunc()
self.assertEqual(coro.send(None), 'spam')
try:
coro.send(20)
except StopIteration as ex:
self.assertEqual(ex.args[0], 140)
else:
self.fail('StopIteration was expected')
def test_gen(self):
def gen(): yield
gen = gen()
@types.coroutine
def foo(): return gen
self.assertIs(foo().__await__(), gen)
wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
self.assertIs(wrapper.__await__(), gen)
for name in ('__name__', '__qualname__', 'gi_code',
'gi_running', 'gi_frame'):
......@@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase):
self.assertIs(foo().cr_code, gen.gi_code)
def test_genfunc(self):
def gen():
yield
self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
gen_code = gen.__code__
decorated_gen = types.coroutine(gen)
self.assertIs(decorated_gen, gen)
self.assertIsNot(decorated_gen.__code__, gen_code)
decorated_gen2 = types.coroutine(decorated_gen)
self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
def gen(): yield
self.assertIs(types.coroutine(gen), gen)
self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
......@@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase):
g = gen()
self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)
self.assertTrue(isinstance(g, collections.abc.Coroutine))
self.assertTrue(isinstance(g, collections.abc.Awaitable))
self.assertIsInstance(g, collections.abc.Coroutine)
self.assertIsInstance(g, collections.abc.Awaitable)
g.close() # silence warning
self.assertIs(types.coroutine(gen), gen)
def test_wrapper_object(self):
def gen():
yield
@types.coroutine
def coro():
return gen()
wrapper = coro()
self.assertIn('GeneratorWrapper', repr(wrapper))
self.assertEqual(repr(wrapper), str(wrapper))
self.assertTrue(set(dir(wrapper)).issuperset({
'__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
'close', 'throw'}))
if __name__ == '__main__':
unittest.main()
......@@ -166,6 +166,39 @@ class DynamicClassAttribute:
import functools as _functools
import collections.abc as _collections_abc
class _GeneratorWrapper:
# TODO: Implement this in C.
def __init__(self, gen):
self.__wrapped__ = gen
self.__isgen__ = gen.__class__ is GeneratorType
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, tp, *rest):
return self.__wrapped__.throw(tp, *rest)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
if self.__isgen__:
return self.__wrapped__
return self
__await__ = __iter__
def coroutine(func):
"""Convert regular generator function to a coroutine."""
......@@ -201,36 +234,6 @@ def coroutine(func):
# return generator-like objects (for instance generators
# compiled with Cython).
class GeneratorWrapper:
def __init__(self, gen):
self.__wrapped__ = gen
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, *args):
return self.__wrapped__.throw(*args)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
return self.__wrapped__
def __await__(self):
return self.__wrapped__
@_functools.wraps(func)
def wrapped(*args, **kwargs):
coro = func(*args, **kwargs)
......@@ -243,7 +246,7 @@ def coroutine(func):
# 'coro' is either a pure Python generator iterator, or it
# implements collections.abc.Generator (and does not implement
# collections.abc.Coroutine).
return GeneratorWrapper(coro)
return _GeneratorWrapper(coro)
# 'coro' is either an instance of collections.abc.Coroutine or
# some other object -- pass it through.
return coro
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment