Commit a9365c2e authored by Stefan Behnel's avatar Stefan Behnel

make Cython's Coroutine type and extension types work with Python's async-for...

make Cython's Coroutine type and extension types work with Python's async-for loop, implement new slots in Py3.5
parent 498371be
...@@ -63,6 +63,13 @@ uncachable_builtins = [ ...@@ -63,6 +63,13 @@ uncachable_builtins = [
'_', # e.g. gettext '_', # e.g. gettext
] ]
special_py_methods = set([
'__cinit__', '__dealloc__', '__richcmp__', '__next__',
'__await__', '__aiter__', '__anext__',
'__getreadbuffer__', '__getwritebuffer__', '__getsegcount__',
'__getcharbuffer__', '__getbuffer__', '__releasebuffer__'
])
modifier_output_mapper = { modifier_output_mapper = {
'inline': 'CYTHON_INLINE' 'inline': 'CYTHON_INLINE'
}.get }.get
...@@ -1999,7 +2006,7 @@ class CCodeWriter(object): ...@@ -1999,7 +2006,7 @@ class CCodeWriter(object):
def put_pymethoddef(self, entry, term, allow_skip=True): def put_pymethoddef(self, entry, term, allow_skip=True):
if entry.is_special or entry.name == '__getattribute__': if entry.is_special or entry.name == '__getattribute__':
if entry.name not in ['__cinit__', '__dealloc__', '__richcmp__', '__next__', '__getreadbuffer__', '__getwritebuffer__', '__getsegcount__', '__getcharbuffer__', '__getbuffer__', '__releasebuffer__']: if entry.name not in special_py_methods:
if entry.name == '__getattr__' and not self.globalstate.directives['fast_getattr']: if entry.name == '__getattr__' and not self.globalstate.directives['fast_getattr']:
pass pass
# Python's typeobject.c will automatically fill in our slot # Python's typeobject.c will automatically fill in our slot
......
...@@ -431,8 +431,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -431,8 +431,8 @@ class SuiteSlot(SlotDescriptor):
# #
# sub_slots [SlotDescriptor] # sub_slots [SlotDescriptor]
def __init__(self, sub_slots, slot_type, slot_name): def __init__(self, sub_slots, slot_type, slot_name, ifdef=None):
SlotDescriptor.__init__(self, slot_name) SlotDescriptor.__init__(self, slot_name, ifdef=ifdef)
self.sub_slots = sub_slots self.sub_slots = sub_slots
self.slot_type = slot_type self.slot_type = slot_type
substructures.append(self) substructures.append(self)
...@@ -454,6 +454,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -454,6 +454,8 @@ class SuiteSlot(SlotDescriptor):
def generate_substructure(self, scope, code): def generate_substructure(self, scope, code):
if not self.is_empty(scope): if not self.is_empty(scope):
code.putln("") code.putln("")
if self.ifdef:
code.putln("#if %s" % self.ifdef)
code.putln( code.putln(
"static %s %s = {" % ( "static %s %s = {" % (
self.slot_type, self.slot_type,
...@@ -461,6 +463,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -461,6 +463,8 @@ class SuiteSlot(SlotDescriptor):
for slot in self.sub_slots: for slot in self.sub_slots:
slot.generate(scope, code) slot.generate(scope, code)
code.putln("};") code.putln("};")
if self.ifdef:
code.putln("#endif")
substructures = [] # List of all SuiteSlot instances substructures = [] # List of all SuiteSlot instances
...@@ -506,6 +510,29 @@ class BaseClassSlot(SlotDescriptor): ...@@ -506,6 +510,29 @@ class BaseClassSlot(SlotDescriptor):
base_type.typeptr_cname)) base_type.typeptr_cname))
class AlternativeSlot(SlotDescriptor):
"""Slot descriptor that delegates to different slots using C macros."""
def __init__(self, alternatives):
SlotDescriptor.__init__(self, "")
self.alternatives = alternatives
def generate(self, scope, code):
# state machine: "#if ... (#elif ...)* #else ... #endif"
test = 'if'
for guard, slot in self.alternatives:
if guard:
assert test in ('if', 'elif'), test
else:
assert test == 'elif', test
test = 'else'
code.putln("#%s %s" % (test, guard))
slot.generate(scope, code)
if test == 'if':
test = 'elif'
assert test == 'else', test
code.putln("#endif")
# The following dictionary maps __xxx__ method names to slot descriptors. # The following dictionary maps __xxx__ method names to slot descriptors.
method_name_to_slot = {} method_name_to_slot = {}
...@@ -748,6 +775,12 @@ PyBufferProcs = ( ...@@ -748,6 +775,12 @@ PyBufferProcs = (
MethodSlot(releasebufferproc, "bf_releasebuffer", "__releasebuffer__") MethodSlot(releasebufferproc, "bf_releasebuffer", "__releasebuffer__")
) )
PyAsyncMethods = (
MethodSlot(unaryfunc, "am_await", "__await__"),
MethodSlot(unaryfunc, "am_aiter", "__aiter__"),
MethodSlot(unaryfunc, "am_anext", "__anext__"),
)
#------------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------------
# #
# The main slot table. This table contains descriptors for all the # The main slot table. This table contains descriptors for all the
...@@ -761,7 +794,11 @@ slot_table = ( ...@@ -761,7 +794,11 @@ slot_table = (
EmptySlot("tp_print"), #MethodSlot(printfunc, "tp_print", "__print__"), EmptySlot("tp_print"), #MethodSlot(printfunc, "tp_print", "__print__"),
EmptySlot("tp_getattr"), EmptySlot("tp_getattr"),
EmptySlot("tp_setattr"), EmptySlot("tp_setattr"),
MethodSlot(cmpfunc, "tp_compare", "__cmp__", py3 = '<RESERVED>'), AlternativeSlot([
("PY_MAJOR_VERSION < 3", MethodSlot(cmpfunc, "tp_compare", "__cmp__")),
("PY_VERSION_HEX < 0x030500B1", EmptySlot("tp_reserved")),
("", SuiteSlot(PyAsyncMethods, "PyAsyncMethods", "tp_as_async", ifdef="PY_VERSION_HEX >= 0x030500B1")),
]),
MethodSlot(reprfunc, "tp_repr", "__repr__"), MethodSlot(reprfunc, "tp_repr", "__repr__"),
SuiteSlot(PyNumberMethods, "PyNumberMethods", "tp_as_number"), SuiteSlot(PyNumberMethods, "PyNumberMethods", "tp_as_number"),
......
...@@ -894,13 +894,6 @@ static PyMemberDef __pyx_Coroutine_memberlist[] = { ...@@ -894,13 +894,6 @@ static PyMemberDef __pyx_Coroutine_memberlist[] = {
{0, 0, 0, 0, 0} {0, 0, 0, 0, 0}
}; };
static PyMethodDef __pyx_Coroutine_methods[] = {
{"send", (PyCFunction) __Pyx_Coroutine_Send, METH_O, 0},
{"throw", (PyCFunction) __Pyx_Coroutine_Throw, METH_VARARGS, 0},
{"close", (PyCFunction) __Pyx_Coroutine_Close, METH_NOARGS, 0},
{0, 0, 0, 0}
};
static __pyx_CoroutineObject *__Pyx__Coroutine_New(PyTypeObject* type, __pyx_coroutine_body_t body, static __pyx_CoroutineObject *__Pyx__Coroutine_New(PyTypeObject* type, __pyx_coroutine_body_t body,
PyObject *closure, PyObject *name, PyObject *qualname) { PyObject *closure, PyObject *name, PyObject *qualname) {
__pyx_CoroutineObject *gen = PyObject_GC_New(__pyx_CoroutineObject, type); __pyx_CoroutineObject *gen = PyObject_GC_New(__pyx_CoroutineObject, type);
...@@ -981,9 +974,24 @@ static void __Pyx_Coroutine_check_and_dealloc(PyObject *self) { ...@@ -981,9 +974,24 @@ static void __Pyx_Coroutine_check_and_dealloc(PyObject *self) {
__Pyx_Coroutine_dealloc(self); __Pyx_Coroutine_dealloc(self);
} }
static PyObject *__Pyx_Coroutine_return_self(PyObject *self) {
Py_INCREF(self);
return self;
}
static PyMethodDef __pyx_Coroutine_methods[] = {
{"send", (PyCFunction) __Pyx_Coroutine_Send, METH_O, 0},
{"throw", (PyCFunction) __Pyx_Coroutine_Throw, METH_VARARGS, 0},
{"close", (PyCFunction) __Pyx_Coroutine_Close, METH_NOARGS, 0},
#if PY_VERSION_HEX < 0x030500B1
{"__await__", (PyCFunction) __Pyx_Coroutine_return_self, METH_NOARGS, 0},
#endif
{0, 0, 0, 0}
};
#if PY_VERSION_HEX >= 0x030500B1 #if PY_VERSION_HEX >= 0x030500B1
static PyAsyncMethods __pyx_Coroutine_as_async = { static PyAsyncMethods __pyx_Coroutine_as_async = {
0, /*am_await*/ __Pyx_Coroutine_return_self, /*am_await*/
0, /*am_aiter*/ 0, /*am_aiter*/
0, /*am_anext*/ 0, /*am_anext*/
}; };
...@@ -1066,6 +1074,13 @@ static int __pyx_Coroutine_init(void) { ...@@ -1066,6 +1074,13 @@ static int __pyx_Coroutine_init(void) {
//@requires: CoroutineBase //@requires: CoroutineBase
//@requires: PatchGeneratorABC //@requires: PatchGeneratorABC
static PyMethodDef __pyx_Generator_methods[] = {
{"send", (PyCFunction) __Pyx_Coroutine_Send, METH_O, 0},
{"throw", (PyCFunction) __Pyx_Coroutine_Throw, METH_VARARGS, 0},
{"close", (PyCFunction) __Pyx_Coroutine_Close, METH_NOARGS, 0},
{0, 0, 0, 0}
};
static PyTypeObject __pyx_GeneratorType_type = { static PyTypeObject __pyx_GeneratorType_type = {
PyVarObject_HEAD_INIT(0, 0) PyVarObject_HEAD_INIT(0, 0)
"generator", /*tp_name*/ "generator", /*tp_name*/
...@@ -1075,11 +1090,7 @@ static PyTypeObject __pyx_GeneratorType_type = { ...@@ -1075,11 +1090,7 @@ static PyTypeObject __pyx_GeneratorType_type = {
0, /*tp_print*/ 0, /*tp_print*/
0, /*tp_getattr*/ 0, /*tp_getattr*/
0, /*tp_setattr*/ 0, /*tp_setattr*/
#if PY_MAJOR_VERSION < 3 0, /*tp_compare / tp_as_async*/
0, /*tp_compare*/
#else
0, /*reserved*/
#endif
0, /*tp_repr*/ 0, /*tp_repr*/
0, /*tp_as_number*/ 0, /*tp_as_number*/
0, /*tp_as_sequence*/ 0, /*tp_as_sequence*/
...@@ -1098,7 +1109,7 @@ static PyTypeObject __pyx_GeneratorType_type = { ...@@ -1098,7 +1109,7 @@ static PyTypeObject __pyx_GeneratorType_type = {
offsetof(__pyx_CoroutineObject, gi_weakreflist), /*tp_weaklistoffset*/ offsetof(__pyx_CoroutineObject, gi_weakreflist), /*tp_weaklistoffset*/
0, /*tp_iter*/ 0, /*tp_iter*/
(iternextfunc) __Pyx_Generator_Next, /*tp_iternext*/ (iternextfunc) __Pyx_Generator_Next, /*tp_iternext*/
__pyx_Coroutine_methods, /*tp_methods*/ __pyx_Generator_methods, /*tp_methods*/
__pyx_Coroutine_memberlist, /*tp_members*/ __pyx_Coroutine_memberlist, /*tp_members*/
__pyx_Coroutine_getsets, /*tp_getset*/ __pyx_Coroutine_getsets, /*tp_getset*/
0, /*tp_base*/ 0, /*tp_base*/
......
# mode: run
# tag: pep492, asyncfor, await
import sys
if sys.version_info >= (3, 5, 0, 'beta'):
# pass Cython implemented AsyncIter() into a Python async-for loop
__doc__ = u"""
>>> def test_py35():
... buffer = []
... async def coro():
... async for i1, i2 in AsyncIter(1):
... buffer.append(i1 + i2)
... return coro, buffer
>>> testfunc, buffer = test_py35()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc(), check_type=False)
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i*2 for i in range(1, 101)] or buffer
True
"""
cdef class AsyncYieldFrom:
cdef object obj
def __init__(self, obj):
self.obj = obj
def __await__(self):
yield from self.obj
cdef class AsyncYield:
cdef object value
def __init__(self, value):
self.value = value
def __await__(self):
yield self.value
def run_async(coro, check_type='coroutine'):
if check_type:
assert coro.__class__.__name__ == check_type, \
'type(%s) != %s' % (coro.__class__, check_type)
buffer = []
result = None
while True:
try:
buffer.append(coro.send(None))
except StopIteration as ex:
result = ex.args[0] if ex.args else None
break
return buffer, result
cdef class AsyncIter:
cdef long i
cdef long aiter_calls
cdef long max_iter_calls
def __init__(self, long max_iter_calls=1):
self.i = 0
self.aiter_calls = 0
self.max_iter_calls = max_iter_calls
async def __aiter__(self):
self.aiter_calls += 1
return self
async def __anext__(self):
self.i += 1
assert self.aiter_calls <= self.max_iter_calls
if not (self.i % 10):
await AsyncYield(self.i * 10)
if self.i > 100:
raise StopAsyncIteration
return self.i, self.i
def test_for_1():
"""
>>> testfunc, buffer = test_for_1()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i*2 for i in range(1, 101)] or buffer
True
"""
buffer = []
async def test1():
async for i1, i2 in AsyncIter(1):
buffer.append(i1 + i2)
return test1, buffer
def test_for_2():
"""
>>> testfunc, buffer = test_for_2()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [100, 200] or yielded
True
>>> buffer == [i for i in range(1, 21)] + ['end'] or buffer
True
"""
buffer = []
async def test2():
nonlocal buffer
async for i in AsyncIter(2):
buffer.append(i[0])
if i[0] == 20:
break
else:
buffer.append('what?')
buffer.append('end')
return test2, buffer
def test_for_3():
"""
>>> testfunc, buffer = test_for_3()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i for i in range(1, 21)] + ['what?', 'end'] or buffer
True
"""
buffer = []
async def test3():
nonlocal buffer
async for i in AsyncIter(3):
if i[0] > 20:
continue
buffer.append(i[0])
else:
buffer.append('what?')
buffer.append('end')
return test3, buffer
cdef class NonAwaitableFromAnext:
async def __aiter__(self):
return self
def __anext__(self):
return 123
def test_broken_anext():
"""
>>> testfunc = test_broken_anext()
>>> try: run_async(testfunc())
... except TypeError as exc:
... assert ' int ' in str(exc)
... else:
... print("NOT RAISED!")
"""
async def foo():
async for i in NonAwaitableFromAnext():
print('never going to happen')
return foo
cdef class Manager:
cdef readonly list counter
def __init__(self, counter):
self.counter = counter
async def __aenter__(self):
self.counter[0] += 10000
async def __aexit__(self, *args):
self.counter[0] += 100000
cdef class Iterable:
cdef long i
def __init__(self):
self.i = 0
async def __aiter__(self):
return self
async def __anext__(self):
if self.i > 10:
raise StopAsyncIteration
self.i += 1
return self.i
def test_with_for():
"""
>>> test_with_for()
111011
333033
20555255
"""
I = [0]
manager = Manager(I)
iterable = Iterable()
mrefs_before = sys.getrefcount(manager)
irefs_before = sys.getrefcount(iterable)
async def main():
async with manager:
async for i in iterable:
I[0] += 1
I[0] += 1000
run_async(main())
print(I[0])
assert sys.getrefcount(manager) == mrefs_before
assert sys.getrefcount(iterable) == irefs_before
##############
async def main():
nonlocal I
async with Manager(I):
async for i in Iterable():
I[0] += 1
I[0] += 1000
async with Manager(I):
async for i in Iterable():
I[0] += 1
I[0] += 1000
run_async(main())
print(I[0])
##############
async def main():
async with Manager(I):
I[0] += 100
async for i in Iterable():
I[0] += 1
else:
I[0] += 10000000
I[0] += 1000
async with Manager(I):
I[0] += 100
async for i in Iterable():
I[0] += 1
else:
I[0] += 10000000
I[0] += 1000
run_async(main())
print(I[0])
cdef class AI:
async def __aiter__(self):
1/0
def test_aiter_raises():
"""
>>> test_aiter_raises()
RAISED
0
"""
CNT = 0
async def foo():
nonlocal CNT
async for i in AI():
CNT += 1
CNT += 10
try:
run_async(foo())
except ZeroDivisionError:
print("RAISED")
else:
print("NOT RAISED")
return CNT
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment