Commit cf63ff71 authored by Stefan Behnel's avatar Stefan Behnel

patch Generator ABC into "collections.abc" when using Cython generators so...

patch Generator ABC into "collections.abc" when using Cython generators so that foreign code can test for the Generator protocol instead of the generator type
parent bb4d9c2d
......@@ -396,6 +396,9 @@ class UtilityCode(UtilityCodeBase):
def inject_string_constants(self, impl, output):
"""Replace 'PYIDENT("xyz")' by a constant Python identifier cname.
"""
if 'PYIDENT(' not in impl:
return False, impl
replacements = {}
def externalise(matchobj):
name = matchobj.group(1)
......@@ -406,9 +409,26 @@ class UtilityCode(UtilityCodeBase):
StringEncoding.EncodedString(name)).cname
return cname
impl = re.sub('PYIDENT\("([^"]+)"\)', externalise, impl)
impl = re.sub(r'PYIDENT\("([^"]+)"\)', externalise, impl)
assert 'PYIDENT(' not in impl
return bool(replacements), impl
def wrap_c_strings(self, impl):
"""Replace CSTRING('''xyz''') by a C compatible string
"""
if 'CSTRING(' not in impl:
return impl
def split_string(matchobj):
content = matchobj.group(1).replace('"', '\042')
return ''.join(
'"%s\\n"\n' % line if not line.endswith('\\') or line.endswith('\\\\') else '"%s"\n' % line[:-1]
for line in content.splitlines())
impl = re.sub(r'CSTRING\(\s*"""([^"]+|"[^"])"""\s*\)', split_string, impl)
assert 'CSTRING(' not in impl
return impl
def put_code(self, output):
if self.requires:
for dependency in self.requires:
......@@ -418,7 +438,7 @@ class UtilityCode(UtilityCodeBase):
self.format_code(self.proto),
'%s_proto' % self.name)
if self.impl:
impl = self.format_code(self.impl)
impl = self.format_code(self.wrap_c_strings(self.impl))
is_specialised, impl = self.inject_string_constants(impl, output)
if not is_specialised:
# no module specific adaptations => can be reused
......
......@@ -2151,6 +2151,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Execution code ---*/")
code.mark_pos(None)
code.putln("#ifdef __Pyx_Generator_USED")
code.put_error_if_neg(self.pos, "__Pyx_patch_abc()")
code.putln("#endif")
if profile or linetrace:
code.put_trace_call(header3, self.pos, nogil=not code.funcstate.gil_owned)
code.funcstate.can_trace = True
......
......@@ -84,6 +84,7 @@ static int __Pyx_PyGen_FetchStopIterationValue(PyObject **pvalue);
//@requires: ObjectHandling.c::PyObjectCallMethod1
//@requires: ObjectHandling.c::PyObjectGetAttrStr
//@requires: CommonTypes.c::FetchCommonType
//@requires: PatchGeneratorABC
static PyObject *__Pyx_Generator_Next(PyObject *self);
static PyObject *__Pyx_Generator_Send(PyObject *self, PyObject *value);
......@@ -793,7 +794,7 @@ static PyObject* __Pyx_Generator_patch_module(PyObject* module, const char* py_c
globals = PyDict_New(); if (unlikely(!globals)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_cython_generator_type", (PyObject*)__pyx_GeneratorType) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_module", module) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_b", $builtins_cname) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "__builtins__", $builtins_cname) < 0)) goto ignore;
result_obj = PyRun_String(py_code, Py_file_input, globals, globals);
if (unlikely(!result_obj)) goto ignore;
Py_DECREF(result_obj);
......@@ -815,6 +816,110 @@ ignore:
}
//////////////////// PatchGeneratorABC.proto ////////////////////
// patch 'collections.abc' if it lacks generator support
// see https://bugs.python.org/issue24018
static int __Pyx_patch_abc(void); /*proto*/
//////////////////// PatchGeneratorABC ////////////////////
//@requires: PatchModuleWithGenerator
static int __Pyx_patch_abc(void) {
#if defined(__Pyx_Generator_USED) && (!defined(CYTHON_PATCH_ABC) || CYTHON_PATCH_ABC)
static int abc_patched = 0;
if (!abc_patched) {
PyObject *module;
module = PyImport_ImportModule("collections.abc");
if (!module) {
PyErr_Clear();
module = PyImport_ImportModule("collections");
if (!module)
PyErr_Clear();
}
if (module) {
PyObject *abc = PyObject_GetAttrString(module, "Generator");
if (abc) {
abc_patched = 1;
Py_DECREF(abc);
} else {
PyErr_Clear();
module = __Pyx_Generator_patch_module(
module, CSTRING("""\
def mk_gen():
from abc import abstractmethod
required_methods = (
'__iter__', '__next__' if hasattr(iter(()), '__next__') else 'next',
'send', 'throw', 'close')
class Generator(_module.Iterator):
__slots__ = ()
if '__next__' in required_methods:
def __next__(self):
return self.send(None)
else:
def next(self):
return self.send(None)
@abstractmethod
def send(self, value):
raise StopIteration
def throw(self, typ, val=None, tb=None):
if val is None:
if tb is None:
raise typ
val = typ()
if tb is not None:
val = val.with_traceback(tb)
raise val
def close(self):
try:
self.throw(GeneratorExit)
except (GeneratorExit, StopIteration):
pass
else:
raise RuntimeError('generator ignored GeneratorExit')
@classmethod
def __subclasshook__(cls, C):
if cls is Generator:
mro = C.__mro__
for method in required_methods:
for base in mro:
if method in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented
generator = type((lambda: (yield))())
Generator.register(generator)
Generator.register(_cython_generator_type)
return Generator
_module.Generator = mk_gen()
""")
);
abc_patched = 1;
if (unlikely(!module))
return -1;
}
Py_DECREF(module);
}
}
#else
// avoid "unused" warning for __Pyx_Generator_patch_module()
if (0) __Pyx_Generator_patch_module(NULL, NULL);
#endif
return 0;
}
//////////////////// PatchAsyncIO.proto ////////////////////
// run after importing "asyncio" to patch Cython generator support into it
......@@ -833,10 +938,11 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) {
package = __Pyx_Import(PYIDENT("asyncio.coroutines"), NULL, 0);
if (package) {
patch_module = __Pyx_Generator_patch_module(
PyObject_GetAttrString(package, "coroutines"),
"old_types = _b.getattr(_module, '_COROUTINE_TYPES', None)\n"
"if old_types is not None and _cython_generator_type not in old_types:\n"
" _module._COROUTINE_TYPES = _b.type(old_types) (_b.tuple(old_types) + (_cython_generator_type,))\n"
PyObject_GetAttrString(package, "coroutines"), CSTRING("""\
old_types = getattr(_module, '_COROUTINE_TYPES', None)
if old_types is not None and _cython_generator_type not in old_types:
_module._COROUTINE_TYPES = type(old_types) (tuple(old_types) + (_cython_generator_type,))
""")
);
#if PY_VERSION_HEX < 0x03050000
} else {
......@@ -845,14 +951,15 @@ static PyObject* __Pyx_patch_asyncio(PyObject* module) {
package = __Pyx_Import(PYIDENT("asyncio.tasks"), NULL, 0);
if (unlikely(!package)) goto asyncio_done;
patch_module = __Pyx_Generator_patch_module(
PyObject_GetAttrString(package, "tasks"),
"if (_b.hasattr(_module, 'iscoroutine') and"
" _b.getattr(_module.iscoroutine, '_cython_generator_type', None) is not _cython_generator_type):\n"
" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=_b.type):\n"
" def cy_iscoroutine(obj): return type(obj) is cython_generator_type or orig_func(obj)\n"
" cy_iscoroutine._cython_generator_type = cython_generator_type\n"
" return cy_iscoroutine\n"
" _module.iscoroutine = cy_wrap(_module.iscoroutine)\n"
PyObject_GetAttrString(package, "tasks"), CSTRING("""\
if (hasattr(_module, 'iscoroutine') and
getattr(_module.iscoroutine, '_cython_generator_type', None) is not _cython_generator_type):
def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=type):
def cy_iscoroutine(obj): return type(obj) is cython_generator_type or orig_func(obj)
cy_iscoroutine._cython_generator_type = cython_generator_type
return cy_iscoroutine
_module.iscoroutine = cy_wrap(_module.iscoroutine)
""")
);
#endif
}
......@@ -909,13 +1016,14 @@ static PyObject* __Pyx_patch_inspect(PyObject* module) {
static int inspect_patched = 0;
if (unlikely((!inspect_patched) && module)) {
module = __Pyx_Generator_patch_module(
module,
"if _b.getattr(_module.isgenerator, '_cython_generator_type', None) is not _cython_generator_type:\n"
" def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=_b.type):\n"
" def cy_isgenerator(obj): return type(obj) is cython_generator_type or orig_func(obj)\n"
" cy_isgenerator._cython_generator_type = cython_generator_type\n"
" return cy_isgenerator\n"
" _module.isgenerator = cy_wrap(_module.isgenerator)\n"
module, CSTRING("""\
if getattr(_module.isgenerator, '_cython_generator_type', None) is not _cython_generator_type:
def cy_wrap(orig_func, cython_generator_type=_cython_generator_type, type=type):
def cy_isgenerator(obj): return type(obj) is cython_generator_type or orig_func(obj)
cy_isgenerator._cython_generator_type = cython_generator_type
return cy_isgenerator
_module.isgenerator = cy_wrap(_module.isgenerator)
""")
);
inspect_patched = 1;
}
......
......@@ -63,6 +63,15 @@ runloop(import_asyncio.wait3) # 2b)
runloop(from_asyncio_import.wait3) # 3a)
runloop(import_asyncio.wait3) # 3b)
try:
from collections.abc import Generator
except ImportError:
from collections import Generator
assert isinstance(from_asyncio_import.wait3(), Generator)
assert isinstance(import_asyncio.wait3(), Generator)
assert isinstance((lambda:(yield))(), Generator)
######## import_asyncio.pyx ########
# cython: binding=True
......
# mode: run
# tag: generators
try:
from collections.abc import Generator
except ImportError:
from collections import Generator
def very_simple():
"""
>>> x = very_simple()
......@@ -450,3 +456,21 @@ def test_double_with_gil_section():
pass
with gil:
pass
def test_generator_abc():
"""
>>> isinstance(test_generator_abc(), Generator)
True
>>> try:
... from collections.abc import Generator
... except ImportError:
... from collections import Generator
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
"""
yield 1
......@@ -4,10 +4,9 @@
import cython
try:
from builtins import next # Py3k
from collections.abc import Generator
except ImportError:
def next(it):
return it.next()
from collections import Generator
def very_simple():
......@@ -384,3 +383,23 @@ def test_yield_in_const_conditional_true():
"""
if True:
print((yield 1))
def test_generator_abc():
"""
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
>>> try:
... from collections.abc import Generator
... except ImportError:
... from collections import Generator
>>> isinstance(test_generator_abc(), Generator)
True
>>> isinstance((lambda:(yield))(), Generator)
True
"""
yield 1
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment