Commit 0016d14d authored by Stefan Behnel's avatar Stefan Behnel

follow PEP492 change in Py3.5.2 that makes __aiter__() a simple function...

follow PEP492 change in Py3.5.2 that makes __aiter__() a simple function instead of an async function
See https://bugs.python.org/issue27243
parent 6c2afddc
...@@ -9271,6 +9271,13 @@ class AwaitExprNode(YieldFromExprNode): ...@@ -9271,6 +9271,13 @@ class AwaitExprNode(YieldFromExprNode):
return "__Pyx_Coroutine_Yield_From" return "__Pyx_Coroutine_Yield_From"
class AIterAwaitExprNode(AwaitExprNode):
# 'await' expression node used in async-for loops to support the pre-Py3.5.2 'aiter' protocol
def yield_from_func(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("CoroutineAIterYieldFrom", "Coroutine.c"))
return "__Pyx_Coroutine_AIter_Yield_From"
class AwaitIterNextExprNode(AwaitExprNode): class AwaitIterNextExprNode(AwaitExprNode):
# 'await' expression node as part of 'async for' iteration # 'await' expression node as part of 'async for' iteration
# #
......
...@@ -6145,7 +6145,7 @@ class _ForInStatNode(LoopNode, StatNode): ...@@ -6145,7 +6145,7 @@ class _ForInStatNode(LoopNode, StatNode):
# Base class of 'for-in' statements. # Base class of 'for-in' statements.
# #
# target ExprNode # target ExprNode
# iterator IteratorNode | AwaitExprNode(AsyncIteratorNode) # iterator IteratorNode | AIterAwaitExprNode(AsyncIteratorNode)
# body StatNode # body StatNode
# else_clause StatNode # else_clause StatNode
# item NextNode | AwaitExprNode(AsyncNextNode) # item NextNode | AwaitExprNode(AsyncNextNode)
...@@ -6251,7 +6251,7 @@ class ForInStatNode(_ForInStatNode): ...@@ -6251,7 +6251,7 @@ class ForInStatNode(_ForInStatNode):
class AsyncForStatNode(_ForInStatNode): class AsyncForStatNode(_ForInStatNode):
# 'async for' statement # 'async for' statement
# #
# iterator AwaitExprNode(AsyncIteratorNode) # iterator AIterAwaitExprNode(AsyncIteratorNode)
# item AwaitIterNextExprNode(AsyncIteratorNode) # item AwaitIterNextExprNode(AsyncIteratorNode)
is_async = True is_async = True
...@@ -6260,7 +6260,7 @@ class AsyncForStatNode(_ForInStatNode): ...@@ -6260,7 +6260,7 @@ class AsyncForStatNode(_ForInStatNode):
assert 'item' not in kw assert 'item' not in kw
from . import ExprNodes from . import ExprNodes
# AwaitExprNodes must appear before running MarkClosureVisitor # AwaitExprNodes must appear before running MarkClosureVisitor
kw['iterator'] = ExprNodes.AwaitExprNode(iterator.pos, arg=iterator) kw['iterator'] = ExprNodes.AIterAwaitExprNode(iterator.pos, arg=iterator)
kw['item'] = ExprNodes.AwaitIterNextExprNode(iterator.pos, arg=None) kw['item'] = ExprNodes.AwaitIterNextExprNode(iterator.pos, arg=None)
_ForInStatNode.__init__(self, pos, **kw) _ForInStatNode.__init__(self, pos, **kw)
......
...@@ -45,15 +45,40 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject ...@@ -45,15 +45,40 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject
//////////////////// CoroutineYieldFrom.proto //////////////////// //////////////////// CoroutineYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source); #define __Pyx_Coroutine_Yield_From(gen, source) __Pyx__Coroutine_Yield_From(gen, source, 0)
static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn);
//////////////////// CoroutineYieldFrom //////////////////// //////////////////// CoroutineYieldFrom ////////////////////
//@requires: Coroutine //@requires: Coroutine
//@requires: GetAwaitIter //@requires: GetAwaitIter
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) { static int __Pyx_WarnAIterDeprecation(PyObject *aiter) {
int result;
#if PY_MAJOR_VERSION >= 3
result = PyErr_WarnFormat(
PyExc_PendingDeprecationWarning, 1,
"'%.100s' implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous "
"iterator, not awaitable",
Py_TYPE(aiter)->tp_name);
#else
result = PyErr_WarnEx(
PyExc_PendingDeprecationWarning,
"object implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous "
"iterator, not awaitable",
1);
#endif
return result != 0;
}
static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn) {
PyObject *retval; PyObject *retval;
if (__Pyx_Coroutine_CheckExact(source)) { if (__Pyx_Coroutine_CheckExact(source)) {
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
return NULL;
}
retval = __Pyx_Generator_Next(source); retval = __Pyx_Generator_Next(source);
if (retval) { if (retval) {
Py_INCREF(source); Py_INCREF(source);
...@@ -64,6 +89,11 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject ...@@ -64,6 +89,11 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject
PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source); PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source);
if (unlikely(!source_gen)) if (unlikely(!source_gen))
return NULL; return NULL;
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
Py_DECREF(source_gen);
return NULL;
}
// source_gen is now the iterator, make the first next() call // source_gen is now the iterator, make the first next() call
if (__Pyx_Coroutine_CheckExact(source_gen)) { if (__Pyx_Coroutine_CheckExact(source_gen)) {
retval = __Pyx_Generator_Next(source_gen); retval = __Pyx_Generator_Next(source_gen);
...@@ -80,6 +110,53 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject ...@@ -80,6 +110,53 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject
} }
//////////////////// CoroutineAIterYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineObject *gen, PyObject *source);
//////////////////// CoroutineAIterYieldFrom ////////////////////
//@requires: CoroutineYieldFrom
static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
#if PY_MAJOR_VERSION >= 3
__Pyx_PyAsyncMethodsStruct* am = __Pyx_PyType_AsAsync(source);
if (likely(am && am->am_anext)) {
// Starting with CPython 3.5.2, __aiter__ should return
// asynchronous iterators directly (not awaitables that
// resolve to asynchronous iterators.)
//
// Therefore, we check if the object that was returned
// from __aiter__ has an __anext__ method. If it does,
// we return it directly as StopIteration result,
// which avoids yielding.
//
// See http://bugs.python.org/issue27243 for more
// details.
PyErr_SetObject(PyExc_StopIteration, source);
return NULL;
}
#endif
#if PY_VERSION_HEX < 0x030500B2
if (!__Pyx_PyType_AsAsync(source)) {
#ifdef __Pyx_Coroutine_USED
if (!__Pyx_Coroutine_CheckExact(source)) // quickly rule out a likely case
#endif
{
// same as above in slow
PyObject *method = __Pyx_PyObject_GetAttrStr(source, PYIDENT("__anext__"));
if (method) {
Py_DECREF(method);
PyErr_SetObject(PyExc_StopIteration, source);
return NULL;
}
PyErr_Clear();
}
}
#endif
return __Pyx__Coroutine_Yield_From(gen, source, 1);
}
//////////////////// GetAwaitIter.proto //////////////////// //////////////////// GetAwaitIter.proto ////////////////////
static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o); /*proto*/ static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o); /*proto*/
...@@ -196,7 +273,7 @@ static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAsyncIter(PyObject *obj) { ...@@ -196,7 +273,7 @@ static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAsyncIter(PyObject *obj) {
return NULL; return NULL;
} }
#else #else
// avoid 'unused function' warning // avoid C warning about 'unused function'
if (0) (void) __Pyx_PyObject_CallMethod0(obj, PYIDENT("__aiter__")); if (0) (void) __Pyx_PyObject_CallMethod0(obj, PYIDENT("__aiter__"));
#endif #endif
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
import re import re
import gc import gc
import sys import sys
import copy
#import types #import types
import pickle
import os.path import os.path
#import inspect #import inspect
import unittest import unittest
...@@ -128,6 +130,17 @@ def silence_coro_gc(): ...@@ -128,6 +130,17 @@ def silence_coro_gc():
gc.collect() gc.collect()
def min_py27(method):
return None if sys.version_info < (2, 7) else method
def ignore_py26(manager):
@contextlib.contextmanager
def dummy():
yield
return dummy() if sys.version_info < (2, 7) else manager
class AsyncBadSyntaxTest(unittest.TestCase): class AsyncBadSyntaxTest(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
...@@ -1247,8 +1260,9 @@ class CoroutineTest(unittest.TestCase): ...@@ -1247,8 +1260,9 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test1(): async def test1():
async for i1, i2 in AsyncIter(): with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
buffer.append(i1 + i2) async for i1, i2 in AsyncIter():
buffer.append(i1 + i2)
yielded, _ = run_async(test1()) yielded, _ = run_async(test1())
# Make sure that __aiter__ was called only once # Make sure that __aiter__ was called only once
...@@ -1260,12 +1274,13 @@ class CoroutineTest(unittest.TestCase): ...@@ -1260,12 +1274,13 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test2(): async def test2():
nonlocal buffer nonlocal buffer
async for i in AsyncIter(): with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
buffer.append(i[0]) async for i in AsyncIter():
if i[0] == 20: buffer.append(i[0])
break if i[0] == 20:
else: break
buffer.append('what?') else:
buffer.append('what?')
buffer.append('end') buffer.append('end')
yielded, _ = run_async(test2()) yielded, _ = run_async(test2())
...@@ -1278,12 +1293,13 @@ class CoroutineTest(unittest.TestCase): ...@@ -1278,12 +1293,13 @@ class CoroutineTest(unittest.TestCase):
buffer = [] buffer = []
async def test3(): async def test3():
nonlocal buffer nonlocal buffer
async for i in AsyncIter(): with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
if i[0] > 20: async for i in AsyncIter():
continue if i[0] > 20:
buffer.append(i[0]) continue
else: buffer.append(i[0])
buffer.append('what?') else:
buffer.append('what?')
buffer.append('end') buffer.append('end')
yielded, _ = run_async(test3()) yielded, _ = run_async(test3())
...@@ -1330,7 +1346,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1330,7 +1346,7 @@ class CoroutineTest(unittest.TestCase):
def test_for_4(self): def test_for_4(self):
class I(object): class I(object):
async def __aiter__(self): def __aiter__(self):
return self return self
def __anext__(self): def __anext__(self):
...@@ -1360,8 +1376,9 @@ class CoroutineTest(unittest.TestCase): ...@@ -1360,8 +1376,9 @@ class CoroutineTest(unittest.TestCase):
return 123 return 123
async def foo(): async def foo():
async for i in I(): with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
print('never going to happen') async for i in I():
print('never going to happen')
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, TypeError,
...@@ -1385,7 +1402,7 @@ class CoroutineTest(unittest.TestCase): ...@@ -1385,7 +1402,7 @@ class CoroutineTest(unittest.TestCase):
def __init__(self): def __init__(self):
self.i = 0 self.i = 0
async def __aiter__(self): def __aiter__(self):
return self return self
async def __anext__(self): async def __anext__(self):
...@@ -1462,6 +1479,21 @@ class CoroutineTest(unittest.TestCase): ...@@ -1462,6 +1479,21 @@ class CoroutineTest(unittest.TestCase):
class AI(object): class AI(object):
async def __aiter__(self): async def __aiter__(self):
1/0 1/0
async def foo():
nonlocal CNT
with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
async for i in AI():
CNT += 1
CNT += 10
with self.assertRaises(ZeroDivisionError):
run_async(foo())
self.assertEqual(CNT, 0)
def test_for_8(self):
CNT = 0
class AI:
def __aiter__(self):
1/0
async def foo(): async def foo():
nonlocal CNT nonlocal CNT
async for i in AI(): async for i in AI():
...@@ -1469,8 +1501,74 @@ class CoroutineTest(unittest.TestCase): ...@@ -1469,8 +1501,74 @@ class CoroutineTest(unittest.TestCase):
CNT += 10 CNT += 10
with self.assertRaises(ZeroDivisionError): with self.assertRaises(ZeroDivisionError):
run_async(foo()) run_async(foo())
with warnings.catch_warnings():
warnings.simplefilter("error")
# Test that if __aiter__ raises an exception it propagates
# without any kind of warning.
run_async(foo())
self.assertEqual(CNT, 0) self.assertEqual(CNT, 0)
@min_py27
def test_for_9(self):
# Test that PendingDeprecationWarning can safely be converted into
# an exception (__aiter__ should not have a chance to raise
# a ZeroDivisionError.)
class AI:
async def __aiter__(self):
1/0
async def foo():
async for i in AI():
pass
with self.assertRaises(PendingDeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
@min_py27
def test_for_10(self):
# Test that PendingDeprecationWarning can safely be converted into
# an exception.
class AI:
async def __aiter__(self):
pass
async def foo():
async for i in AI():
pass
with self.assertRaises(PendingDeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
def test_copy(self):
async def func(): pass
coro = func()
with self.assertRaises(TypeError):
copy.copy(coro)
aw = coro.__await__()
try:
with self.assertRaises(TypeError):
copy.copy(aw)
finally:
aw.close()
def test_pickle(self):
async def func(): pass
coro = func()
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((TypeError, pickle.PicklingError)):
pickle.dumps(coro, proto)
aw = coro.__await__()
try:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((TypeError, pickle.PicklingError)):
pickle.dumps(aw, proto)
finally:
aw.close()
class CoroAsyncIOCompatTest(unittest.TestCase): class CoroAsyncIOCompatTest(unittest.TestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment