Commit 64c0c794 authored by Stefan Behnel's avatar Stefan Behnel

follow PEP492 change in Py3.5.2 that makes __aiter__() a simple function...

follow PEP492 change in Py3.5.2 that makes __aiter__() a simple function instead of an async function
See https://bugs.python.org/issue27243
parent 7a6d7364
......@@ -9237,6 +9237,13 @@ class AwaitExprNode(YieldFromExprNode):
return "__Pyx_Coroutine_Yield_From"
class AIterAwaitExprNode(AwaitExprNode):
# 'await' expression node used in async-for loops to support the pre-Py3.5.2 'aiter' protocol
def yield_from_func(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("CoroutineAIterYieldFrom", "Coroutine.c"))
return "__Pyx_Coroutine_AIter_Yield_From"
class AwaitIterNextExprNode(AwaitExprNode):
# 'await' expression node as part of 'async for' iteration
#
......
......@@ -6144,7 +6144,7 @@ class _ForInStatNode(LoopNode, StatNode):
# Base class of 'for-in' statements.
#
# target ExprNode
# iterator IteratorNode | AwaitExprNode(AsyncIteratorNode)
# iterator IteratorNode | AIterAwaitExprNode(AsyncIteratorNode)
# body StatNode
# else_clause StatNode
# item NextNode | AwaitExprNode(AsyncNextNode)
......@@ -6250,7 +6250,7 @@ class ForInStatNode(_ForInStatNode):
class AsyncForStatNode(_ForInStatNode):
# 'async for' statement
#
# iterator AwaitExprNode(AsyncIteratorNode)
# iterator AIterAwaitExprNode(AsyncIteratorNode)
# item AwaitIterNextExprNode(AsyncIteratorNode)
is_async = True
......@@ -6259,7 +6259,7 @@ class AsyncForStatNode(_ForInStatNode):
assert 'item' not in kw
from . import ExprNodes
# AwaitExprNodes must appear before running MarkClosureVisitor
kw['iterator'] = ExprNodes.AwaitExprNode(iterator.pos, arg=iterator)
kw['iterator'] = ExprNodes.AIterAwaitExprNode(iterator.pos, arg=iterator)
kw['item'] = ExprNodes.AwaitIterNextExprNode(iterator.pos, arg=None)
_ForInStatNode.__init__(self, pos, **kw)
......
......@@ -45,15 +45,40 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject
//////////////////// CoroutineYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source);
#define __Pyx_Coroutine_Yield_From(gen, source) __Pyx__Coroutine_Yield_From(gen, source, 0)
static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn);
//////////////////// CoroutineYieldFrom ////////////////////
//@requires: Coroutine
//@requires: GetAwaitIter
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
static int __Pyx_WarnAIterDeprecation(PyObject *aiter) {
int result;
#if PY_MAJOR_VERSION >= 3
result = PyErr_WarnFormat(
PyExc_PendingDeprecationWarning, 1,
"'%.100s' implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous "
"iterator, not awaitable",
Py_TYPE(aiter)->tp_name);
#else
result = PyErr_WarnEx(
PyExc_PendingDeprecationWarning,
"object implements legacy __aiter__ protocol; "
"__aiter__ should return an asynchronous "
"iterator, not awaitable",
1);
#endif
return result != 0;
}
static CYTHON_INLINE PyObject* __Pyx__Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source, int warn) {
PyObject *retval;
if (__Pyx_Coroutine_CheckExact(source)) {
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
return NULL;
}
retval = __Pyx_Generator_Next(source);
if (retval) {
Py_INCREF(source);
......@@ -64,6 +89,11 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject
PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source);
if (unlikely(!source_gen))
return NULL;
if (warn && unlikely(__Pyx_WarnAIterDeprecation(source))) {
/* Warning was converted to an error. */
Py_DECREF(source_gen);
return NULL;
}
// source_gen is now the iterator, make the first next() call
if (__Pyx_Coroutine_CheckExact(source_gen)) {
retval = __Pyx_Generator_Next(source_gen);
......@@ -80,6 +110,53 @@ static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject
}
//////////////////// CoroutineAIterYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineObject *gen, PyObject *source);
//////////////////// CoroutineAIterYieldFrom ////////////////////
//@requires: CoroutineYieldFrom
static CYTHON_INLINE PyObject* __Pyx_Coroutine_AIter_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
#if PY_MAJOR_VERSION >= 3
__Pyx_PyAsyncMethodsStruct* am = __Pyx_PyType_AsAsync(source);
if (likely(am && am->am_anext)) {
// Starting with CPython 3.5.2, __aiter__ should return
// asynchronous iterators directly (not awaitables that
// resolve to asynchronous iterators.)
//
// Therefore, we check if the object that was returned
// from __aiter__ has an __anext__ method. If it does,
// we return it directly as StopIteration result,
// which avoids yielding.
//
// See http://bugs.python.org/issue27243 for more
// details.
PyErr_SetObject(PyExc_StopIteration, source);
return NULL;
}
#endif
#if PY_VERSION_HEX < 0x030500B2
if (!__Pyx_PyType_AsAsync(source)) {
#ifdef __Pyx_Coroutine_USED
if (!__Pyx_Coroutine_CheckExact(source)) // quickly rule out a likely case
#endif
{
// same as above in slow
PyObject *method = __Pyx_PyObject_GetAttrStr(source, PYIDENT("__anext__"));
if (method) {
Py_DECREF(method);
PyErr_SetObject(PyExc_StopIteration, source);
return NULL;
}
PyErr_Clear();
}
}
#endif
return __Pyx__Coroutine_Yield_From(gen, source, 1);
}
//////////////////// GetAwaitIter.proto ////////////////////
static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o); /*proto*/
......@@ -196,7 +273,7 @@ static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAsyncIter(PyObject *obj) {
return NULL;
}
#else
// avoid 'unused function' warning
// avoid C warning about 'unused function'
if (0) (void) __Pyx_PyObject_CallMethod0(obj, PYIDENT("__aiter__"));
#endif
......
......@@ -5,7 +5,9 @@
import re
import gc
import sys
import copy
#import types
import pickle
import os.path
#import inspect
import unittest
......@@ -128,6 +130,17 @@ def silence_coro_gc():
gc.collect()
def min_py27(method):
return None if sys.version_info < (2, 7) else method
def ignore_py26(manager):
@contextlib.contextmanager
def dummy():
yield
return dummy() if sys.version_info < (2, 7) else manager
class AsyncBadSyntaxTest(unittest.TestCase):
@contextlib.contextmanager
......@@ -1238,8 +1251,9 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test1():
async for i1, i2 in AsyncIter():
buffer.append(i1 + i2)
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
async for i1, i2 in AsyncIter():
buffer.append(i1 + i2)
yielded, _ = run_async(test1())
# Make sure that __aiter__ was called only once
......@@ -1251,12 +1265,13 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test2():
nonlocal buffer
async for i in AsyncIter():
buffer.append(i[0])
if i[0] == 20:
break
else:
buffer.append('what?')
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
async for i in AsyncIter():
buffer.append(i[0])
if i[0] == 20:
break
else:
buffer.append('what?')
buffer.append('end')
yielded, _ = run_async(test2())
......@@ -1269,12 +1284,13 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test3():
nonlocal buffer
async for i in AsyncIter():
if i[0] > 20:
continue
buffer.append(i[0])
else:
buffer.append('what?')
with ignore_py26(self.assertWarnsRegex(PendingDeprecationWarning, "legacy")):
async for i in AsyncIter():
if i[0] > 20:
continue
buffer.append(i[0])
else:
buffer.append('what?')
buffer.append('end')
yielded, _ = run_async(test3())
......@@ -1321,7 +1337,7 @@ class CoroutineTest(unittest.TestCase):
def test_for_4(self):
class I(object):
async def __aiter__(self):
def __aiter__(self):
return self
def __anext__(self):
......@@ -1351,8 +1367,9 @@ class CoroutineTest(unittest.TestCase):
return 123
async def foo():
async for i in I():
print('never going to happen')
with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
async for i in I():
print('never going to happen')
with self.assertRaisesRegex(
TypeError,
......@@ -1376,7 +1393,7 @@ class CoroutineTest(unittest.TestCase):
def __init__(self):
self.i = 0
async def __aiter__(self):
def __aiter__(self):
return self
async def __anext__(self):
......@@ -1453,6 +1470,21 @@ class CoroutineTest(unittest.TestCase):
class AI(object):
async def __aiter__(self):
1/0
async def foo():
nonlocal CNT
with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
async for i in AI():
CNT += 1
CNT += 10
with self.assertRaises(ZeroDivisionError):
run_async(foo())
self.assertEqual(CNT, 0)
def test_for_8(self):
CNT = 0
class AI:
def __aiter__(self):
1/0
async def foo():
nonlocal CNT
async for i in AI():
......@@ -1460,8 +1492,74 @@ class CoroutineTest(unittest.TestCase):
CNT += 10
with self.assertRaises(ZeroDivisionError):
run_async(foo())
with warnings.catch_warnings():
warnings.simplefilter("error")
# Test that if __aiter__ raises an exception it propagates
# without any kind of warning.
run_async(foo())
self.assertEqual(CNT, 0)
@min_py27
def test_for_9(self):
# Test that PendingDeprecationWarning can safely be converted into
# an exception (__aiter__ should not have a chance to raise
# a ZeroDivisionError.)
class AI:
async def __aiter__(self):
1/0
async def foo():
async for i in AI():
pass
with self.assertRaises(PendingDeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
@min_py27
def test_for_10(self):
# Test that PendingDeprecationWarning can safely be converted into
# an exception.
class AI:
async def __aiter__(self):
pass
async def foo():
async for i in AI():
pass
with self.assertRaises(PendingDeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
def test_copy(self):
async def func(): pass
coro = func()
with self.assertRaises(TypeError):
copy.copy(coro)
aw = coro.__await__()
try:
with self.assertRaises(TypeError):
copy.copy(aw)
finally:
aw.close()
def test_pickle(self):
async def func(): pass
coro = func()
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((TypeError, pickle.PicklingError)):
pickle.dumps(coro, proto)
aw = coro.__await__()
try:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((TypeError, pickle.PicklingError)):
pickle.dumps(aw, proto)
finally:
aw.close()
class CoroAsyncIOCompatTest(unittest.TestCase):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment