Commit ede3eb92 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge branch 'cdef_closure'

parents 78a0f1fd 914dc35d
......@@ -2395,8 +2395,8 @@ class CFuncDefNode(FuncDefNode):
def generate_argument_parsing_code(self, env, code):
i = 0
used = 0
scope = self.local_scope
if self.type.optional_arg_count:
scope = self.local_scope
code.putln('if (%s) {' % Naming.optional_args_cname)
for arg in self.args:
if arg.default:
......@@ -2417,6 +2417,16 @@ class CFuncDefNode(FuncDefNode):
code.putln('}')
code.putln('}')
# Move arguments into closure if required
def put_into_closure(entry):
if entry.in_closure and not arg.default:
code.putln('%s = %s;' % (entry.cname, entry.original_cname))
code.put_var_incref(entry)
code.put_var_giveref(entry)
for arg in self.args:
put_into_closure(scope.lookup_here(arg.name))
def generate_argument_conversion_code(self, code):
pass
......
......@@ -2261,8 +2261,8 @@ class MarkClosureVisitor(CythonTransform):
def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node)
if node.needs_closure:
error(node.pos, "closures inside cdef functions not yet supported")
if node.needs_closure and node.overridable:
error(node.pos, "closures inside cpdef functions not yet supported")
return node
def visit_LambdaNode(self, node):
......@@ -2401,8 +2401,11 @@ class CreateClosureClasses(CythonTransform):
return node
def visit_CFuncDefNode(self, node):
self.visitchildren(node)
return node
if not node.overridable:
return self.visit_FuncDefNode(node)
else:
self.visitchildren(node)
return node
class GilCheck(VisitorTransform):
......
......@@ -115,6 +115,77 @@ def unpatch_inspect_isfunction():
else:
inspect.isfunction = orig_isfunction
def def_to_cdef(source):
'''
Converts the module-level def methods into cdef methods, i.e.
@decorator
def foo([args]):
"""
[tests]
"""
[body]
becomes
def foo([args]):
"""
[tests]
"""
return foo_c([args])
cdef foo_c([args]):
[body]
'''
output = []
skip = False
def_node = re.compile(r'def (\w+)\(([^()*]*)\):').match
lines = iter(source.split('\n'))
for line in lines:
if not line.strip():
output.append(line)
continue
if skip:
if line[0] != ' ':
skip = False
else:
continue
if line[0] == '@':
skip = True
continue
m = def_node(line)
if m:
name = m.group(1)
args = m.group(2)
if args:
args_no_types = ", ".join(arg.split()[-1] for arg in args.split(','))
else:
args_no_types = ""
output.append("def %s(%s):" % (name, args_no_types))
line = next(lines)
if '"""' in line:
has_docstring = True
output.append(line)
for line in lines:
output.append(line)
if '"""' in line:
break
else:
has_docstring = False
output.append(" return %s_c(%s)" % (name, args_no_types))
output.append('')
output.append("cdef %s_c(%s):" % (name, args))
if not has_docstring:
output.append(line)
else:
output.append(line)
return '\n'.join(output)
def update_linetrace_extension(ext):
ext.define_macros.append(('CYTHON_TRACE', 1))
return ext
......@@ -331,7 +402,7 @@ def parse_tags(filepath):
if tag == 'tags':
tag = 'tag'
print("WARNING: test tags use the 'tag' directive, not 'tags' (%s)" % filepath)
if tag not in ('mode', 'tag', 'ticket', 'cython', 'distutils'):
if tag not in ('mode', 'tag', 'ticket', 'cython', 'distutils', 'preparse'):
print("WARNING: unknown test directive '%s' found (%s)" % (tag, filepath))
values = values.split(',')
tags[tag].extend(filter(None, [value.strip() for value in values]))
......@@ -532,19 +603,25 @@ class TestBuilder(object):
elif 'no-cpp' in tags['tag'] and 'cpp' in self.languages:
languages = list(languages)
languages.remove('cpp')
preparse_list = tags.get('preparse', ['id'])
tests = [ self.build_test(test_class, path, workdir, module, tags,
language, expect_errors, warning_errors)
for language in languages ]
language, expect_errors, warning_errors, preparse)
for language in languages
for preparse in preparse_list ]
return tests
def build_test(self, test_class, path, workdir, module, tags,
language, expect_errors, warning_errors):
language, expect_errors, warning_errors, preparse):
language_workdir = os.path.join(workdir, language)
if not os.path.exists(language_workdir):
os.makedirs(language_workdir)
workdir = os.path.join(language_workdir, module)
if preparse != 'id':
workdir += '_%s' % str(preparse)
return test_class(path, workdir, module, tags,
language=language,
preparse=preparse,
expect_errors=expect_errors,
annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir,
......@@ -556,7 +633,7 @@ class TestBuilder(object):
warning_errors=warning_errors)
class CythonCompileTestCase(unittest.TestCase):
def __init__(self, test_directory, workdir, module, tags, language='c',
def __init__(self, test_directory, workdir, module, tags, language='c', preparse='id',
expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True, cleanup_failures=True, cython_only=False,
fork=True, language_level=2, warning_errors=False):
......@@ -565,6 +642,8 @@ class CythonCompileTestCase(unittest.TestCase):
self.workdir = workdir
self.module = module
self.language = language
self.preparse = preparse
self.name = module if self.preparse == "id" else "%s_%s" % (module, preparse)
self.expect_errors = expect_errors
self.annotate = annotate
self.cleanup_workdir = cleanup_workdir
......@@ -577,7 +656,7 @@ class CythonCompileTestCase(unittest.TestCase):
unittest.TestCase.__init__(self)
def shortDescription(self):
return "compiling (%s) %s" % (self.language, self.module)
return "compiling (%s) %s" % (self.language, self.name)
def setUp(self):
from Cython.Compiler import Options
......@@ -660,11 +739,16 @@ class CythonCompileTestCase(unittest.TestCase):
if is_related(filename)]
def copy_files(self, test_directory, target_directory, file_list):
# use symlink on Unix, copy on Windows
try:
copy = os.symlink
except AttributeError:
copy = shutil.copy
if self.preparse and self.preparse != 'id':
preparse_func = globals()[self.preparse]
def copy(src, dest):
open(dest, 'w').write(preparse_func(open(src).read()))
else:
# use symlink on Unix, copy on Windows
try:
copy = os.symlink
except AttributeError:
copy = shutil.copy
join = os.path.join
for filename in file_list:
......@@ -707,6 +791,12 @@ class CythonCompileTestCase(unittest.TestCase):
include_dirs.append(incdir)
source = self.find_module_source_file(
os.path.join(test_directory, module + '.pyx'))
if self.preparse == 'id':
source = self.find_module_source_file(
os.path.join(test_directory, module + '.pyx'))
else:
self.copy_files(test_directory, targetdir, [module + '.pyx'])
source = os.path.join(targetdir, module + '.pyx')
target = os.path.join(targetdir, self.build_target_filename(module))
if extra_compile_options is None:
......@@ -903,7 +993,7 @@ class CythonRunTestCase(CythonCompileTestCase):
if self.cython_only:
return CythonCompileTestCase.shortDescription(self)
else:
return "compiling (%s) and running %s" % (self.language, self.module)
return "compiling (%s) and running %s" % (self.language, self.name)
def run(self, result=None):
if result is None:
......@@ -1105,7 +1195,7 @@ class PartialTestResult(_TextTestResult):
class CythonUnitTestCase(CythonRunTestCase):
def shortDescription(self):
return "compiling (%s) tests in %s" % (self.language, self.module)
return "compiling (%s) tests in %s" % (self.language, self.name)
def run_tests(self, result, ext_so_path):
module = import_ext(self.module, ext_so_path)
......
......@@ -6,7 +6,6 @@ unsignedbehaviour_T184
missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408
cpp_structs
closure_inside_cdef_T554
genexpr_iterable_lookup_T600
generator_expressions_in_class
for_from_pyvar_loop_T601
......
# mode: error
cdef cdef_yield():
def inner():
pass
cpdef cpdef_yield():
def inner():
pass
_ERRORS = u"""
3:5: closures inside cdef functions not yet supported
7:6: closures inside cdef functions not yet supported
3:6: closures inside cpdef functions not yet supported
"""
......@@ -38,3 +38,19 @@ cdef class SelfInClosure(object):
def nested():
return self.x, t.x
return nested
def call_closure_method_cdef_attr_c(self, Test t):
"""
>>> o = SelfInClosure()
>>> o.call_closure_method_cdef_attr_c(Test())()
(1, 2)
"""
return self.closure_method_cdef_attr_c(t)
cdef closure_method_cdef_attr_c(self, Test t):
t.x = 2
self._t = t
self.x = 1
def nested():
return self.x, t.x
return nested
# mode: run
# tag: closures
# preparse: id
# preparse: def_to_cdef
#
# closure_tests_1.pyx
#
......
# mode: run
# tag: closures
# preparse: id
# preparse: def_to_cdef
#
# closure_tests_2.pyx
#
......
# mode: run
# tag: closures
# preparse: id
# preparse: def_to_cdef
#
# closure_tests_3.pyx
#
......
# mode: run
# tag: closures
# preparse: id
# preparse: def_to_cdef
#
# closure_tests_4.pyx
#
......
# mode: run
# tag: closures
# ticket: 82
# preparse: id
# preparse: def_to_cdef
cimport cython
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment