Commit ede3eb92 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge branch 'cdef_closure'

parents 78a0f1fd 914dc35d
...@@ -2395,8 +2395,8 @@ class CFuncDefNode(FuncDefNode): ...@@ -2395,8 +2395,8 @@ class CFuncDefNode(FuncDefNode):
def generate_argument_parsing_code(self, env, code): def generate_argument_parsing_code(self, env, code):
i = 0 i = 0
used = 0 used = 0
if self.type.optional_arg_count:
scope = self.local_scope scope = self.local_scope
if self.type.optional_arg_count:
code.putln('if (%s) {' % Naming.optional_args_cname) code.putln('if (%s) {' % Naming.optional_args_cname)
for arg in self.args: for arg in self.args:
if arg.default: if arg.default:
...@@ -2417,6 +2417,16 @@ class CFuncDefNode(FuncDefNode): ...@@ -2417,6 +2417,16 @@ class CFuncDefNode(FuncDefNode):
code.putln('}') code.putln('}')
code.putln('}') code.putln('}')
# Move arguments into closure if required
def put_into_closure(entry):
if entry.in_closure and not arg.default:
code.putln('%s = %s;' % (entry.cname, entry.original_cname))
code.put_var_incref(entry)
code.put_var_giveref(entry)
for arg in self.args:
put_into_closure(scope.lookup_here(arg.name))
def generate_argument_conversion_code(self, code): def generate_argument_conversion_code(self, code):
pass pass
......
...@@ -2261,8 +2261,8 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2261,8 +2261,8 @@ class MarkClosureVisitor(CythonTransform):
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node) self.visit_FuncDefNode(node)
if node.needs_closure: if node.needs_closure and node.overridable:
error(node.pos, "closures inside cdef functions not yet supported") error(node.pos, "closures inside cpdef functions not yet supported")
return node return node
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
...@@ -2401,6 +2401,9 @@ class CreateClosureClasses(CythonTransform): ...@@ -2401,6 +2401,9 @@ class CreateClosureClasses(CythonTransform):
return node return node
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
if not node.overridable:
return self.visit_FuncDefNode(node)
else:
self.visitchildren(node) self.visitchildren(node)
return node return node
......
...@@ -115,6 +115,77 @@ def unpatch_inspect_isfunction(): ...@@ -115,6 +115,77 @@ def unpatch_inspect_isfunction():
else: else:
inspect.isfunction = orig_isfunction inspect.isfunction = orig_isfunction
def def_to_cdef(source):
'''
Converts the module-level def methods into cdef methods, i.e.
@decorator
def foo([args]):
"""
[tests]
"""
[body]
becomes
def foo([args]):
"""
[tests]
"""
return foo_c([args])
cdef foo_c([args]):
[body]
'''
output = []
skip = False
def_node = re.compile(r'def (\w+)\(([^()*]*)\):').match
lines = iter(source.split('\n'))
for line in lines:
if not line.strip():
output.append(line)
continue
if skip:
if line[0] != ' ':
skip = False
else:
continue
if line[0] == '@':
skip = True
continue
m = def_node(line)
if m:
name = m.group(1)
args = m.group(2)
if args:
args_no_types = ", ".join(arg.split()[-1] for arg in args.split(','))
else:
args_no_types = ""
output.append("def %s(%s):" % (name, args_no_types))
line = next(lines)
if '"""' in line:
has_docstring = True
output.append(line)
for line in lines:
output.append(line)
if '"""' in line:
break
else:
has_docstring = False
output.append(" return %s_c(%s)" % (name, args_no_types))
output.append('')
output.append("cdef %s_c(%s):" % (name, args))
if not has_docstring:
output.append(line)
else:
output.append(line)
return '\n'.join(output)
def update_linetrace_extension(ext): def update_linetrace_extension(ext):
ext.define_macros.append(('CYTHON_TRACE', 1)) ext.define_macros.append(('CYTHON_TRACE', 1))
return ext return ext
...@@ -331,7 +402,7 @@ def parse_tags(filepath): ...@@ -331,7 +402,7 @@ def parse_tags(filepath):
if tag == 'tags': if tag == 'tags':
tag = 'tag' tag = 'tag'
print("WARNING: test tags use the 'tag' directive, not 'tags' (%s)" % filepath) print("WARNING: test tags use the 'tag' directive, not 'tags' (%s)" % filepath)
if tag not in ('mode', 'tag', 'ticket', 'cython', 'distutils'): if tag not in ('mode', 'tag', 'ticket', 'cython', 'distutils', 'preparse'):
print("WARNING: unknown test directive '%s' found (%s)" % (tag, filepath)) print("WARNING: unknown test directive '%s' found (%s)" % (tag, filepath))
values = values.split(',') values = values.split(',')
tags[tag].extend(filter(None, [value.strip() for value in values])) tags[tag].extend(filter(None, [value.strip() for value in values]))
...@@ -532,19 +603,25 @@ class TestBuilder(object): ...@@ -532,19 +603,25 @@ class TestBuilder(object):
elif 'no-cpp' in tags['tag'] and 'cpp' in self.languages: elif 'no-cpp' in tags['tag'] and 'cpp' in self.languages:
languages = list(languages) languages = list(languages)
languages.remove('cpp') languages.remove('cpp')
preparse_list = tags.get('preparse', ['id'])
tests = [ self.build_test(test_class, path, workdir, module, tags, tests = [ self.build_test(test_class, path, workdir, module, tags,
language, expect_errors, warning_errors) language, expect_errors, warning_errors, preparse)
for language in languages ] for language in languages
for preparse in preparse_list ]
return tests return tests
def build_test(self, test_class, path, workdir, module, tags, def build_test(self, test_class, path, workdir, module, tags,
language, expect_errors, warning_errors): language, expect_errors, warning_errors, preparse):
language_workdir = os.path.join(workdir, language) language_workdir = os.path.join(workdir, language)
if not os.path.exists(language_workdir): if not os.path.exists(language_workdir):
os.makedirs(language_workdir) os.makedirs(language_workdir)
workdir = os.path.join(language_workdir, module) workdir = os.path.join(language_workdir, module)
if preparse != 'id':
workdir += '_%s' % str(preparse)
return test_class(path, workdir, module, tags, return test_class(path, workdir, module, tags,
language=language, language=language,
preparse=preparse,
expect_errors=expect_errors, expect_errors=expect_errors,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir, cleanup_workdir=self.cleanup_workdir,
...@@ -556,7 +633,7 @@ class TestBuilder(object): ...@@ -556,7 +633,7 @@ class TestBuilder(object):
warning_errors=warning_errors) warning_errors=warning_errors)
class CythonCompileTestCase(unittest.TestCase): class CythonCompileTestCase(unittest.TestCase):
def __init__(self, test_directory, workdir, module, tags, language='c', def __init__(self, test_directory, workdir, module, tags, language='c', preparse='id',
expect_errors=False, annotate=False, cleanup_workdir=True, expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True, cleanup_failures=True, cython_only=False, cleanup_sharedlibs=True, cleanup_failures=True, cython_only=False,
fork=True, language_level=2, warning_errors=False): fork=True, language_level=2, warning_errors=False):
...@@ -565,6 +642,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -565,6 +642,8 @@ class CythonCompileTestCase(unittest.TestCase):
self.workdir = workdir self.workdir = workdir
self.module = module self.module = module
self.language = language self.language = language
self.preparse = preparse
self.name = module if self.preparse == "id" else "%s_%s" % (module, preparse)
self.expect_errors = expect_errors self.expect_errors = expect_errors
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
...@@ -577,7 +656,7 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -577,7 +656,7 @@ class CythonCompileTestCase(unittest.TestCase):
unittest.TestCase.__init__(self) unittest.TestCase.__init__(self)
def shortDescription(self): def shortDescription(self):
return "compiling (%s) %s" % (self.language, self.module) return "compiling (%s) %s" % (self.language, self.name)
def setUp(self): def setUp(self):
from Cython.Compiler import Options from Cython.Compiler import Options
...@@ -660,6 +739,11 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -660,6 +739,11 @@ class CythonCompileTestCase(unittest.TestCase):
if is_related(filename)] if is_related(filename)]
def copy_files(self, test_directory, target_directory, file_list): def copy_files(self, test_directory, target_directory, file_list):
if self.preparse and self.preparse != 'id':
preparse_func = globals()[self.preparse]
def copy(src, dest):
open(dest, 'w').write(preparse_func(open(src).read()))
else:
# use symlink on Unix, copy on Windows # use symlink on Unix, copy on Windows
try: try:
copy = os.symlink copy = os.symlink
...@@ -707,6 +791,12 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -707,6 +791,12 @@ class CythonCompileTestCase(unittest.TestCase):
include_dirs.append(incdir) include_dirs.append(incdir)
source = self.find_module_source_file( source = self.find_module_source_file(
os.path.join(test_directory, module + '.pyx')) os.path.join(test_directory, module + '.pyx'))
if self.preparse == 'id':
source = self.find_module_source_file(
os.path.join(test_directory, module + '.pyx'))
else:
self.copy_files(test_directory, targetdir, [module + '.pyx'])
source = os.path.join(targetdir, module + '.pyx')
target = os.path.join(targetdir, self.build_target_filename(module)) target = os.path.join(targetdir, self.build_target_filename(module))
if extra_compile_options is None: if extra_compile_options is None:
...@@ -903,7 +993,7 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -903,7 +993,7 @@ class CythonRunTestCase(CythonCompileTestCase):
if self.cython_only: if self.cython_only:
return CythonCompileTestCase.shortDescription(self) return CythonCompileTestCase.shortDescription(self)
else: else:
return "compiling (%s) and running %s" % (self.language, self.module) return "compiling (%s) and running %s" % (self.language, self.name)
def run(self, result=None): def run(self, result=None):
if result is None: if result is None:
...@@ -1105,7 +1195,7 @@ class PartialTestResult(_TextTestResult): ...@@ -1105,7 +1195,7 @@ class PartialTestResult(_TextTestResult):
class CythonUnitTestCase(CythonRunTestCase): class CythonUnitTestCase(CythonRunTestCase):
def shortDescription(self): def shortDescription(self):
return "compiling (%s) tests in %s" % (self.language, self.module) return "compiling (%s) tests in %s" % (self.language, self.name)
def run_tests(self, result, ext_so_path): def run_tests(self, result, ext_so_path):
module = import_ext(self.module, ext_so_path) module = import_ext(self.module, ext_so_path)
......
...@@ -6,7 +6,6 @@ unsignedbehaviour_T184 ...@@ -6,7 +6,6 @@ unsignedbehaviour_T184
missing_baseclass_in_predecl_T262 missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408 cfunc_call_tuple_args_T408
cpp_structs cpp_structs
closure_inside_cdef_T554
genexpr_iterable_lookup_T600 genexpr_iterable_lookup_T600
generator_expressions_in_class generator_expressions_in_class
for_from_pyvar_loop_T601 for_from_pyvar_loop_T601
......
# mode: error # mode: error
cdef cdef_yield():
def inner():
pass
cpdef cpdef_yield(): cpdef cpdef_yield():
def inner(): def inner():
pass pass
_ERRORS = u""" _ERRORS = u"""
3:5: closures inside cdef functions not yet supported 3:6: closures inside cpdef functions not yet supported
7:6: closures inside cdef functions not yet supported
""" """
...@@ -38,3 +38,19 @@ cdef class SelfInClosure(object): ...@@ -38,3 +38,19 @@ cdef class SelfInClosure(object):
def nested(): def nested():
return self.x, t.x return self.x, t.x
return nested return nested
def call_closure_method_cdef_attr_c(self, Test t):
"""
>>> o = SelfInClosure()
>>> o.call_closure_method_cdef_attr_c(Test())()
(1, 2)
"""
return self.closure_method_cdef_attr_c(t)
cdef closure_method_cdef_attr_c(self, Test t):
t.x = 2
self._t = t
self.x = 1
def nested():
return self.x, t.x
return nested
# mode: run # mode: run
# tag: closures # tag: closures
# preparse: id
# preparse: def_to_cdef
# #
# closure_tests_1.pyx # closure_tests_1.pyx
# #
......
# mode: run # mode: run
# tag: closures # tag: closures
# preparse: id
# preparse: def_to_cdef
# #
# closure_tests_2.pyx # closure_tests_2.pyx
# #
......
# mode: run # mode: run
# tag: closures # tag: closures
# preparse: id
# preparse: def_to_cdef
# #
# closure_tests_3.pyx # closure_tests_3.pyx
# #
......
# mode: run # mode: run
# tag: closures # tag: closures
# preparse: id
# preparse: def_to_cdef
# #
# closure_tests_4.pyx # closure_tests_4.pyx
# #
......
# mode: run # mode: run
# tag: closures # tag: closures
# ticket: 82 # ticket: 82
# preparse: id
# preparse: def_to_cdef
cimport cython cimport cython
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment