Commit cf63ff71 authored by Stefan Behnel's avatar Stefan Behnel

patch Generator ABC into "collections.abc" when using Cython generators so...

patch Generator ABC into "collections.abc" when using Cython generators so that foreign code can test for the Generator protocol instead of the generator type
parent bb4d9c2d
...@@ -396,6 +396,9 @@ class UtilityCode(UtilityCodeBase): ...@@ -396,6 +396,9 @@ class UtilityCode(UtilityCodeBase):
def inject_string_constants(self, impl, output): def inject_string_constants(self, impl, output):
"""Replace 'PYIDENT("xyz")' by a constant Python identifier cname. """Replace 'PYIDENT("xyz")' by a constant Python identifier cname.
""" """
if 'PYIDENT(' not in impl:
return False, impl
replacements = {} replacements = {}
def externalise(matchobj): def externalise(matchobj):
name = matchobj.group(1) name = matchobj.group(1)
...@@ -406,9 +409,26 @@ class UtilityCode(UtilityCodeBase): ...@@ -406,9 +409,26 @@ class UtilityCode(UtilityCodeBase):
StringEncoding.EncodedString(name)).cname StringEncoding.EncodedString(name)).cname
return cname return cname
impl = re.sub('PYIDENT\("([^"]+)"\)', externalise, impl) impl = re.sub(r'PYIDENT\("([^"]+)"\)', externalise, impl)
assert 'PYIDENT(' not in impl
return bool(replacements), impl return bool(replacements), impl
def wrap_c_strings(self, impl):
"""Replace CSTRING('''xyz''') by a C compatible string
"""
if 'CSTRING(' not in impl:
return impl
def split_string(matchobj):
content = matchobj.group(1).replace('"', '\042')
return ''.join(
'"%s\\n"\n' % line if not line.endswith('\\') or line.endswith('\\\\') else '"%s"\n' % line[:-1]
for line in content.splitlines())
impl = re.sub(r'CSTRING\(\s*"""([^"]+|"[^"])"""\s*\)', split_string, impl)
assert 'CSTRING(' not in impl
return impl
def put_code(self, output): def put_code(self, output):
if self.requires: if self.requires:
for dependency in self.requires: for dependency in self.requires:
...@@ -418,7 +438,7 @@ class UtilityCode(UtilityCodeBase): ...@@ -418,7 +438,7 @@ class UtilityCode(UtilityCodeBase):
self.format_code(self.proto), self.format_code(self.proto),
'%s_proto' % self.name) '%s_proto' % self.name)
if self.impl: if self.impl:
impl = self.format_code(self.impl) impl = self.format_code(self.wrap_c_strings(self.impl))
is_specialised, impl = self.inject_string_constants(impl, output) is_specialised, impl = self.inject_string_constants(impl, output)
if not is_specialised: if not is_specialised:
# no module specific adaptations => can be reused # no module specific adaptations => can be reused
......
...@@ -2151,6 +2151,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2151,6 +2151,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/") code.putln("/*--- Execution code ---*/")
code.mark_pos(None) code.mark_pos(None)
code.putln("#ifdef __Pyx_Generator_USED")
code.put_error_if_neg(self.pos, "__Pyx_patch_abc()")
code.putln("#endif")
if profile or linetrace: if profile or linetrace:
code.put_trace_call(header3, self.pos, nogil=not code.funcstate.gil_owned) code.put_trace_call(header3, self.pos, nogil=not code.funcstate.gil_owned)
code.funcstate.can_trace = True code.funcstate.can_trace = True
......
...@@ -84,6 +84,7 @@ static int __Pyx_PyGen_FetchStopIterationValue(PyObject **pvalue); ...@@ -84,6 +84,7 @@ static int __Pyx_PyGen_FetchStopIterationValue(PyObject **pvalue);
//@requires: ObjectHandling.c::PyObjectCallMethod1 //@requires: ObjectHandling.c::PyObjectCallMethod1
//@requires: ObjectHandling.c::PyObjectGetAttrStr //@requires: ObjectHandling.c::PyObjectGetAttrStr
//@requires: CommonTypes.c::FetchCommonType //@requires: CommonTypes.c::FetchCommonType
//@requires: PatchGeneratorABC
static PyObject *__Pyx_Generator_Next(PyObject *self); static PyObject *__Pyx_Generator_Next(PyObject *self);
static PyObject *__Pyx_Generator_Send(PyObject *self, PyObject *value); static PyObject *__Pyx_Generator_Send(PyObject *self, PyObject *value);
...@@ -793,7 +794,7 @@ static PyObject* __Pyx_Generator_patch_module(PyObject* module, const char* py_c ...@@ -793,7 +794,7 @@ static PyObject* __Pyx_Generator_patch_module(PyObject* module, const char* py_c
globals = PyDict_New(); if (unlikely(!globals)) goto ignore; globals = PyDict_New(); if (unlikely(!globals)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_cython_generator_type", (PyObject*)__pyx_GeneratorType) < 0)) goto ignore; if (unlikely(PyDict_SetItemString(globals, "_cython_generator_type", (PyObject*)__pyx_GeneratorType) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_module", module) < 0)) goto ignore; if (unlikely(PyDict_SetItemString(globals, "_module", module) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_b", $builtins_cname) < 0)) goto ignore; if (unlikely(PyDict_SetItemString(globals, "__builtins__", $builtins_cname) < 0)) goto ignore;
result_obj = PyRun_String(py_code, Py_file_input, globals, globals); result_obj = PyRun_String(py_code, Py_file_input, globals, globals);
if (unlikely(!result_obj)) goto ignore; if (unlikely(!result_obj)) goto ignore;
Py_DECREF(result_obj); Py_DECREF(result_obj);
...@@ -815,6 +816,110 @@ ignore: ...@@ -815,6 +816,110 @@ ignore:
} }
//////////////////// PatchGeneratorABC.proto ////////////////////
// patch 'collections.abc' if it lacks generator support
// see https://bugs.python.org/issue24018
static int __Pyx_patch_abc(void); /*proto*/
//////////////////// PatchGeneratorABC ////////////////////
//@requires: PatchModuleWithGenerator
static int __Pyx_patch_abc(void) {
#if defined(__Pyx_Generator_USED) && (!defined(CYTHON_PATCH_ABC) || CYTHON_PATCH_ABC)
static int abc_patched = 0;
if (!abc_patched) {
PyObject *module;
module = PyImport_ImportModule("collections.abc");
if (!module) {
PyErr_Clear();
module = PyImport_ImportModule("collections");
if (!module)
PyErr_Clear();
}
if (module) {
PyObject *abc = PyObject_GetAttrString(module, "Generator");
if (abc) {
abc_patched = 1;
Py_DECREF(abc);
} else {
PyErr_Clear();
module = __Pyx_Generator_patch_module(
module, CSTRING("""\
def mk_gen():
from abc import abstractmethod
required_methods = (
'__iter__', '__next__' if hasattr(iter(()), '__next__') else 'next',
'send', 'throw', 'close')
class Generator(_module.Iterator):
__slots__ = ()
if '__next__' in required_methods:
def __next__(self):
return self.send(None)
else:
def next(self):
return self.send(None)
@abstractmethod
def send(self, value):
raise StopIteration
def throw(self, typ, val=None, tb=None):
if val is None:
if tb is None:
raise typ
val = typ()
if tb is not None:
val = val.with_traceback(tb)
raise val
def close(self):
try:
self.throw(GeneratorExit)
except (GeneratorExit, StopIteration):
pass
else:
raise RuntimeError('generator ignored GeneratorExit')
@classmethod
def __subclasshook__(cls, C):
if cls is Generator:
mro = C.__mro__
for method in required_methods:
for base in mro:
if method in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
generator = type((lambda: (yield))())
Generator.register(generator)
Generator.register(_cython_generator_type)
return Generator
_module.Generator = mk_gen()
""")
);
abc_patched = 1;
if (unlikely(!module))
return -1;
}
Py_DECREF(module);
}
}
#else
// avoid "unused" warning for __Pyx_Generator_patch_module()
if (0) __Pyx_Generator_patch_module(NULL, NULL);
#endif
return 0;
}
//////////////////// PatchAsyncIO.proto //////////////////// //////////////////// PatchAsyncIO.proto ////////////////////
// run after importing "asyncio" to patch Cython generator support into it // run after importing "asyncio" to patch Cython generator support into it
...@@ -833,10 +938,11 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) { ...@@ -833,10 +938,11 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) {
package = __Pyx_Import(PYIDENT("asyncio.coroutines"), NULL, 0); package = __Pyx_Import(PYIDENT("asyncio.coroutines"), NULL, 0);
if (package) { if (package) {
patch_module = __Pyx_Generator_patch_module( patch_module = __Pyx_Generator_patch_module(
PyObject_GetAttrString(package, "coroutines"), PyObject_GetAttrString(package, "coroutines"), CSTRING("""\
"old_types = _b.getattr(_module, '_COROUTINE_TYPES', None)\n" old_types = getattr(_module, '_COROUTINE_TYPES', None)
"if old_types is not None and _cython_generator_type not in old_types:\n" if old_types is not None and _cython_generator_type not in old_types:
" _module._COROUTINE_TYPES = _b.type(old_types) (_b.tuple(old_types) + (_cython_generator_type,))\n" _module._COROUTINE_TYPES = type(old_types) (tuple(old_types) + (_cython_generator_type,))
""")
); );
#if PY_VERSION_HEX < 0x03050000 #if PY_VERSION_HEX < 0x03050000
} else { } else {
...@@ -845,14 +951,15 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) { ...@@ -845,14 +951,15 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) {
package = __Pyx_Import(PYIDENT("asyncio.tasks"), NULL, 0); package = __Pyx_Import(PYIDENT("asyncio.tasks"), NULL, 0);
if (unlikely(!package)) goto asyncio_done; if (unlikely(!package)) goto asyncio_done;
patch_module = __Pyx_Generator_patch_module( patch_module = __Pyx_Generator_patch_module(
PyObject_GetAttrString(package, "tasks"), PyObject_GetAttrString(package, "tasks"), CSTRING("""\
"if (_b.hasattr(_module, 'iscoroutine') and" if (hasattr(_module, 'iscoroutine') and
" _b.getattr(_module.iscoroutine, '_cython_generator_type', None) is not _cython_generator_type):\n" getattr(_module.iscoroutine, '_cython_generator_type', None) is not _cython_generator_type):
" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=_b.type):\n" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=type):
" def cy_iscoroutine(obj): return type(obj) is cython_generator_type or orig_func(obj)\n" def cy_iscoroutine(obj): return type(obj) is cython_generator_type or orig_func(obj)
" cy_iscoroutine._cython_generator_type = cython_generator_type\n" cy_iscoroutine._cython_generator_type = cython_generator_type
" return cy_iscoroutine\n" return cy_iscoroutine
" _module.iscoroutine = cy_wrap(_module.iscoroutine)\n" _module.iscoroutine = cy_wrap(_module.iscoroutine)
""")
); );
#endif #endif
} }
...@@ -909,13 +1016,14 @@ static PyObject* __Pyx_patch_inspect(PyObject* module) { ...@@ -909,13 +1016,14 @@ static PyObject* __Pyx_patch_inspect(PyObject* module) {
static int inspect_patched = 0; static int inspect_patched = 0;
if (unlikely((!inspect_patched) && module)) { if (unlikely((!inspect_patched) && module)) {
module = __Pyx_Generator_patch_module( module = __Pyx_Generator_patch_module(
module, module, CSTRING("""\
"if _b.getattr(_module.isgenerator, '_cython_generator_type', None) is not _cython_generator_type:\n" if getattr(_module.isgenerator, '_cython_generator_type', None) is not _cython_generator_type:
" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=_b.type):\n" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=type):
" def cy_isgenerator(obj): return type(obj) is cython_generator_type or orig_func(obj)\n" def cy_isgenerator(obj): return type(obj) is cython_generator_type or orig_func(obj)
" cy_isgenerator._cython_generator_type = cython_generator_type\n" cy_isgenerator._cython_generator_type = cython_generator_type
" return cy_isgenerator\n" return cy_isgenerator
" _module.isgenerator = cy_wrap(_module.isgenerator)\n" _module.isgenerator = cy_wrap(_module.isgenerator)
""")
); );
inspect_patched = 1; inspect_patched = 1;
} }
......
...@@ -63,6 +63,15 @@ runloop(import_asyncio.wait3) # 2b) ...@@ -63,6 +63,15 @@ runloop(import_asyncio.wait3) # 2b)
runloop(from_asyncio_import.wait3) # 3a) runloop(from_asyncio_import.wait3) # 3a)
runloop(import_asyncio.wait3) # 3b) runloop(import_asyncio.wait3) # 3b)
try:
from collections.abc import Generator
except ImportError:
from collections import Generator
assert isinstance(from_asyncio_import.wait3(), Generator)
assert isinstance(import_asyncio.wait3(), Generator)
assert isinstance((lambda:(yield))(), Generator)
######## import_asyncio.pyx ######## ######## import_asyncio.pyx ########
# cython: binding=True # cython: binding=True
......
# mode: run # mode: run
# tag: generators # tag: generators
try:
from collections.abc import Generator
except ImportError:
from collections import Generator
def very_simple(): def very_simple():
""" """
>>> x = very_simple() >>> x = very_simple()
...@@ -450,3 +456,21 @@ def test_double_with_gil_section(): ...@@ -450,3 +456,21 @@ def test_double_with_gil_section():
pass pass
with gil: with gil:
pass pass
def test_generator_abc():
"""
>>> isinstance(test_generator_abc(), Generator)
True
>>> try:
... from collections.abc import Generator
... except ImportError:
... from collections import Generator
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
"""
yield 1
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
import cython import cython
try: try:
from builtins import next # Py3k from collections.abc import Generator
except ImportError: except ImportError:
def next(it): from collections import Generator
return it.next()
def very_simple(): def very_simple():
...@@ -384,3 +383,23 @@ def test_yield_in_const_conditional_true(): ...@@ -384,3 +383,23 @@ def test_yield_in_const_conditional_true():
""" """
if True: if True:
print((yield 1)) print((yield 1))
def test_generator_abc():
"""
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
>>> try:
... from collections.abc import Generator
... except ImportError:
... from collections import Generator
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
"""
yield 1
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment