Commit 00e33723 authored by Yury Selivanov's avatar Yury Selivanov

Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper implementation.

parent 66f8828b
"""Tests support for new syntax introduced by PEP 492.""" """Tests support for new syntax introduced by PEP 492."""
import collections.abc import collections.abc
import types
import unittest import unittest
from test import support from test import support
...@@ -164,5 +165,23 @@ class CoroutineTests(BaseTest): ...@@ -164,5 +165,23 @@ class CoroutineTests(BaseTest):
self.loop.run_until_complete(start()) self.loop.run_until_complete(start())
def test_types_coroutine(self):
def gen():
yield from ()
return 'spam'
@types.coroutine
def func():
return gen()
async def coro():
wrapper = func()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
return await wrapper
data = self.loop.run_until_complete(coro())
self.assertEqual(data, 'spam')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -7,7 +7,8 @@ import pickle ...@@ -7,7 +7,8 @@ import pickle
import locale import locale
import sys import sys
import types import types
import unittest import unittest.mock
import weakref
class TypesTests(unittest.TestCase): class TypesTests(unittest.TestCase):
...@@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase): ...@@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase):
class CoroutineTests(unittest.TestCase): class CoroutineTests(unittest.TestCase):
def test_wrong_args(self): def test_wrong_args(self):
class Foo:
def __call__(self):
pass
def bar(): pass
samples = [None, 1, object()] samples = [None, 1, object()]
for sample in samples: for sample in samples:
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
'types.coroutine.*expects a callable'): 'types.coroutine.*expects a callable'):
types.coroutine(sample) types.coroutine(sample)
def test_wrong_func(self): def test_non_gen_values(self):
@types.coroutine @types.coroutine
def foo(): def foo():
return 'spam' return 'spam'
self.assertEqual(foo(), 'spam') self.assertEqual(foo(), 'spam')
class Awaitable:
def __await__(self):
return ()
aw = Awaitable()
@types.coroutine
def foo():
return aw
self.assertIs(aw, foo())
def test_async_def(self): def test_async_def(self):
# Test that types.coroutine passes 'async def' coroutines # Test that types.coroutine passes 'async def' coroutines
# without modification # without modification
...@@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase): ...@@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase):
def send(self): pass def send(self): pass
def throw(self): pass def throw(self): pass
def close(self): pass def close(self): pass
def __iter__(self): return self def __iter__(self): pass
def __next__(self): pass def __next__(self): pass
gen = GenLike() # Setup generator mock object
gen = unittest.mock.MagicMock(GenLike)
gen.__iter__ = lambda gen: gen
gen.__name__ = 'gen'
gen.__qualname__ = 'test.gen'
self.assertIsInstance(gen, collections.abc.Generator)
self.assertIs(gen, iter(gen))
@types.coroutine @types.coroutine
def foo(): def foo(): return gen
return gen
self.assertIs(foo().__await__(), gen) wrapper = foo()
self.assertTrue(isinstance(foo(), collections.abc.Coroutine)) self.assertIsInstance(wrapper, types._GeneratorWrapper)
with self.assertRaises(AttributeError): self.assertIs(wrapper.__await__(), wrapper)
foo().gi_code # Wrapper proxies duck generators completely:
self.assertIs(iter(wrapper), wrapper)
self.assertIsInstance(wrapper, collections.abc.Coroutine)
self.assertIsInstance(wrapper, collections.abc.Awaitable)
self.assertIs(wrapper.__qualname__, gen.__qualname__)
self.assertIs(wrapper.__name__, gen.__name__)
# Test AttributeErrors
for name in {'gi_running', 'gi_frame', 'gi_code',
'cr_running', 'cr_frame', 'cr_code'}:
with self.assertRaises(AttributeError):
getattr(wrapper, name)
# Test attributes pass-through
gen.gi_running = object()
gen.gi_frame = object()
gen.gi_code = object()
self.assertIs(wrapper.gi_running, gen.gi_running)
self.assertIs(wrapper.gi_frame, gen.gi_frame)
self.assertIs(wrapper.gi_code, gen.gi_code)
self.assertIs(wrapper.cr_running, gen.gi_running)
self.assertIs(wrapper.cr_frame, gen.gi_frame)
self.assertIs(wrapper.cr_code, gen.gi_code)
wrapper.close()
gen.close.assert_called_once_with()
wrapper.send(1)
gen.send.assert_called_once_with(1)
wrapper.throw(1, 2, 3)
gen.throw.assert_called_once_with(1, 2, 3)
gen.reset_mock()
wrapper.throw(1, 2)
gen.throw.assert_called_once_with(1, 2)
gen.reset_mock()
wrapper.throw(1)
gen.throw.assert_called_once_with(1)
gen.reset_mock()
# Test exceptions propagation
error = Exception()
gen.throw.side_effect = error
try:
wrapper.throw(1)
except Exception as ex:
self.assertIs(ex, error)
else:
self.fail('wrapper did not propagate an exception')
# Test invalid args
gen.reset_mock()
with self.assertRaises(TypeError):
wrapper.throw()
self.assertFalse(gen.throw.called)
with self.assertRaises(TypeError):
wrapper.close(1)
self.assertFalse(gen.close.called)
with self.assertRaises(TypeError):
wrapper.send()
self.assertFalse(gen.send.called)
# Test that we do not double wrap
@types.coroutine
def bar(): return wrapper
self.assertIs(wrapper, bar())
# Test weakrefs support
ref = weakref.ref(wrapper)
self.assertIs(ref(), wrapper)
def test_duck_functional_gen(self):
class Generator:
"""Emulates the following generator (very clumsy):
def gen(fut):
result = yield fut
return result * 2
"""
def __init__(self, fut):
self._i = 0
self._fut = fut
def __iter__(self):
return self
def __next__(self):
return self.send(None)
def send(self, v):
try:
if self._i == 0:
assert v is None
return self._fut
if self._i == 1:
raise StopIteration(v * 2)
if self._i > 1:
raise StopIteration
finally:
self._i += 1
def throw(self, tp, *exc):
self._i = 100
if tp is not GeneratorExit:
raise tp
def close(self):
self.throw(GeneratorExit)
@types.coroutine
def foo(): return Generator('spam')
wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
async def corofunc():
return await foo() + 100
coro = corofunc()
self.assertEqual(coro.send(None), 'spam')
try:
coro.send(20)
except StopIteration as ex:
self.assertEqual(ex.args[0], 140)
else:
self.fail('StopIteration was expected')
def test_gen(self): def test_gen(self):
def gen(): yield def gen(): yield
gen = gen() gen = gen()
@types.coroutine @types.coroutine
def foo(): return gen def foo(): return gen
self.assertIs(foo().__await__(), gen) wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
self.assertIs(wrapper.__await__(), gen)
for name in ('__name__', '__qualname__', 'gi_code', for name in ('__name__', '__qualname__', 'gi_code',
'gi_running', 'gi_frame'): 'gi_running', 'gi_frame'):
...@@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase): ...@@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase):
self.assertIs(foo().cr_code, gen.gi_code) self.assertIs(foo().cr_code, gen.gi_code)
def test_genfunc(self): def test_genfunc(self):
def gen(): def gen(): yield
yield self.assertIs(types.coroutine(gen), gen)
self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
gen_code = gen.__code__
decorated_gen = types.coroutine(gen)
self.assertIs(decorated_gen, gen)
self.assertIsNot(decorated_gen.__code__, gen_code)
decorated_gen2 = types.coroutine(decorated_gen)
self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE) self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
...@@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase): ...@@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase):
g = gen() g = gen()
self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE) self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)
self.assertTrue(isinstance(g, collections.abc.Coroutine)) self.assertIsInstance(g, collections.abc.Coroutine)
self.assertTrue(isinstance(g, collections.abc.Awaitable)) self.assertIsInstance(g, collections.abc.Awaitable)
g.close() # silence warning g.close() # silence warning
self.assertIs(types.coroutine(gen), gen)
def test_wrapper_object(self):
def gen():
yield
@types.coroutine
def coro():
return gen()
wrapper = coro()
self.assertIn('GeneratorWrapper', repr(wrapper))
self.assertEqual(repr(wrapper), str(wrapper))
self.assertTrue(set(dir(wrapper)).issuperset({
'__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
'close', 'throw'}))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -166,6 +166,39 @@ class DynamicClassAttribute: ...@@ -166,6 +166,39 @@ class DynamicClassAttribute:
import functools as _functools import functools as _functools
import collections.abc as _collections_abc import collections.abc as _collections_abc
class _GeneratorWrapper:
# TODO: Implement this in C.
def __init__(self, gen):
self.__wrapped__ = gen
self.__isgen__ = gen.__class__ is GeneratorType
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, tp, *rest):
return self.__wrapped__.throw(tp, *rest)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
if self.__isgen__:
return self.__wrapped__
return self
__await__ = __iter__
def coroutine(func): def coroutine(func):
"""Convert regular generator function to a coroutine.""" """Convert regular generator function to a coroutine."""
...@@ -201,36 +234,6 @@ def coroutine(func): ...@@ -201,36 +234,6 @@ def coroutine(func):
# return generator-like objects (for instance generators # return generator-like objects (for instance generators
# compiled with Cython). # compiled with Cython).
class GeneratorWrapper:
def __init__(self, gen):
self.__wrapped__ = gen
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, *args):
return self.__wrapped__.throw(*args)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
return self.__wrapped__
def __await__(self):
return self.__wrapped__
@_functools.wraps(func) @_functools.wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
coro = func(*args, **kwargs) coro = func(*args, **kwargs)
...@@ -243,7 +246,7 @@ def coroutine(func): ...@@ -243,7 +246,7 @@ def coroutine(func):
# 'coro' is either a pure Python generator iterator, or it # 'coro' is either a pure Python generator iterator, or it
# implements collections.abc.Generator (and does not implement # implements collections.abc.Generator (and does not implement
# collections.abc.Coroutine). # collections.abc.Coroutine).
return GeneratorWrapper(coro) return _GeneratorWrapper(coro)
# 'coro' is either an instance of collections.abc.Coroutine or # 'coro' is either an instance of collections.abc.Coroutine or
# some other object -- pass it through. # some other object -- pass it through.
return coro return coro
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment