Commit 74006e33 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

merge

parents b127d487 fbc20b57
...@@ -2348,9 +2348,6 @@ class SimpleCallNode(CallNode): ...@@ -2348,9 +2348,6 @@ class SimpleCallNode(CallNode):
arg_list_code.append(actual_arg.result()) arg_list_code.append(actual_arg.result())
result = "%s(%s)" % (self.function.result(), result = "%s(%s)" % (self.function.result(),
', '.join(arg_list_code)) ', '.join(arg_list_code))
# if self.wrapper_call or \
# self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable:
# result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result)
return result return result
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3271,6 +3268,8 @@ class ListNode(SequenceNode): ...@@ -3271,6 +3268,8 @@ class ListNode(SequenceNode):
# obj_conversion_errors [PyrexError] used internally # obj_conversion_errors [PyrexError] used internally
# orignial_args [ExprNode] used internally # orignial_args [ExprNode] used internally
obj_conversion_errors = []
gil_message = "Constructing Python list" gil_message = "Constructing Python list"
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -3404,11 +3403,12 @@ class ComprehensionAppendNode(ExprNode): ...@@ -3404,11 +3403,12 @@ class ComprehensionAppendNode(ExprNode):
# target must not be in child_attrs/subexprs # target must not be in child_attrs/subexprs
subexprs = ['expr'] subexprs = ['expr']
type = PyrexTypes.c_int_type
def analyse_types(self, env): def analyse_types(self, env):
self.expr.analyse_types(env) self.expr.analyse_types(env)
if not self.expr.type.is_pyobject: if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env) self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3437,7 +3437,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -3437,7 +3437,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
self.value_expr.analyse_types(env) self.value_expr.analyse_types(env)
if not self.value_expr.type.is_pyobject: if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env) self.value_expr = self.value_expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3504,6 +3503,9 @@ class DictNode(ExprNode): ...@@ -3504,6 +3503,9 @@ class DictNode(ExprNode):
is_temp = 1 is_temp = 1
type = dict_type type = dict_type
type = dict_type
obj_conversion_errors = []
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = dict([ self.constant_result = dict([
item.constant_result for item in self.key_value_pairs]) item.constant_result for item in self.key_value_pairs])
......
...@@ -163,9 +163,14 @@ class Context(object): ...@@ -163,9 +163,14 @@ class Context(object):
module_node.scope.utility_code_list.extend(scope.utility_code_list) module_node.scope.utility_code_list.extend(scope.utility_code_list)
return module_node return module_node
test_support = []
if options.evaluate_tree_assertions:
from Cython.TestUtils import TreeAssertVisitor
test_support.append(TreeAssertVisitor())
return ([ return ([
create_parse(self), create_parse(self),
] + self.create_pipeline(pxd=False, py=py) + [ ] + self.create_pipeline(pxd=False, py=py) + test_support + [
inject_pxd_code, inject_pxd_code,
abort_on_errors, abort_on_errors,
generate_pyx_code, generate_pyx_code,
...@@ -592,6 +597,7 @@ class CompilationOptions(object): ...@@ -592,6 +597,7 @@ class CompilationOptions(object):
verbose boolean Always print source names being compiled verbose boolean Always print source names being compiled
quiet boolean Don't print source names in recursive mode quiet boolean Don't print source names in recursive mode
compiler_directives dict Overrides for pragma options (see Options.py) compiler_directives dict Overrides for pragma options (see Options.py)
evaluate_tree_assertions boolean Test support: evaluate parse tree assertions
Following options are experimental and only used on MacOSX: Following options are experimental and only used on MacOSX:
...@@ -780,6 +786,7 @@ default_options = dict( ...@@ -780,6 +786,7 @@ default_options = dict(
verbose = 0, verbose = 0,
quiet = 0, quiet = 0,
compiler_directives = {}, compiler_directives = {},
evaluate_tree_assertions = False,
emit_linenums = False, emit_linenums = False,
) )
if sys.platform == "mac": if sys.platform == "mac":
......
...@@ -1668,13 +1668,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1668,13 +1668,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Initialize various global constants etc. ---*/") code.putln("/*--- Initialize various global constants etc. ---*/")
code.putln(code.error_goto_if_neg("__Pyx_InitGlobals()", self.pos)) code.putln(code.error_goto_if_neg("__Pyx_InitGlobals()", self.pos))
__main__name = code.globalstate.get_py_string_const(
EncodedString("__main__"), identifier=True)
code.putln("if (%s%s) {" % (Naming.module_is_main, self.full_module_name.replace('.', '__')))
code.putln(
'if (__Pyx_SetAttrString(%s, "__name__", %s) < 0) %s;' % (
env.module_cname,
__main__name.cname,
code.error_goto(self.pos)))
code.putln("}")
if Options.cache_builtins: if Options.cache_builtins:
code.putln("/*--- Builtin init code ---*/") code.putln("/*--- Builtin init code ---*/")
code.putln(code.error_goto_if_neg("__Pyx_InitCachedBuiltins()", code.putln(code.error_goto_if_neg("__Pyx_InitCachedBuiltins()",
self.pos)) self.pos))
code.putln("%s = 0;" % Naming.skip_dispatch_cname);
code.putln("/*--- Global init code ---*/") code.putln("/*--- Global init code ---*/")
self.generate_global_init_code(env, code) self.generate_global_init_code(env, code)
...@@ -1840,17 +1848,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1840,17 +1848,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
env.module_cname, env.module_cname,
Naming.builtins_cname, Naming.builtins_cname,
code.error_goto(self.pos))) code.error_goto(self.pos)))
__main__name = code.globalstate.get_py_string_const(
EncodedString("__main__"), identifier=True)
code.putln("if (%s%s) {" % (Naming.module_is_main, self.full_module_name.replace('.', '__')))
code.putln(
'if (__Pyx_SetAttrString(%s, "__name__", %s) < 0) %s;' % (
env.module_cname,
__main__name.cname,
code.error_goto(self.pos)))
code.putln("}")
if Options.pre_import is not None: if Options.pre_import is not None:
code.putln( code.putln(
'%s = PyImport_AddModule(__Pyx_NAMESTR("%s"));' % ( '%s = PyImport_AddModule(__Pyx_NAMESTR("%s"));' % (
......
...@@ -3971,6 +3971,8 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -3971,6 +3971,8 @@ class ForFromStatNode(LoopNode, StatNode):
# depend on whether or not the loop is a python type. # depend on whether or not the loop is a python type.
self.py_loopvar_node.generate_evaluation_code(code) self.py_loopvar_node.generate_evaluation_code(code)
self.target.generate_assignment_code(self.py_loopvar_node, code) self.target.generate_assignment_code(self.py_loopvar_node, code)
if from_range:
code.funcstate.release_temp(loopvar_name)
break_label = code.break_label break_label = code.break_label
code.set_loop_labels(old_loop_labels) code.set_loop_labels(old_loop_labels)
if self.else_clause: if self.else_clause:
...@@ -4746,11 +4748,7 @@ utility_function_predeclarations = \ ...@@ -4746,11 +4748,7 @@ utility_function_predeclarations = \
typedef struct {PyObject **p; char *s; long n; char is_unicode; char intern; char is_identifier;} __Pyx_StringTabEntry; /*proto*/ typedef struct {PyObject **p; char *s; long n; char is_unicode; char intern; char is_identifier;} __Pyx_StringTabEntry; /*proto*/
""" + """ """
static int %(skip_dispatch_cname)s = 0;
""" % { 'skip_dispatch_cname': Naming.skip_dispatch_cname }
if Options.gcc_branch_hints: if Options.gcc_branch_hints:
branch_prediction_macros = \ branch_prediction_macros = \
......
...@@ -34,11 +34,11 @@ def is_common_value(a, b): ...@@ -34,11 +34,11 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False return False
class IterationTransform(Visitor.VisitorTransform): class IterationTransform(Visitor.VisitorTransform):
"""Transform some common for-in loop patterns into efficient C loops: """Transform some common for-in loop patterns into efficient C loops:
- for-in-dict loop becomes a while loop calling PyDict_Next() - for-in-dict loop becomes a while loop calling PyDict_Next()
- for-in-enumerate is replaced by an external counter variable
- for-in-range loop becomes a plain C for loop - for-in-range loop becomes a plain C for loop
""" """
PyDict_Next_func_type = PyrexTypes.CFuncType( PyDict_Next_func_type = PyrexTypes.CFuncType(
...@@ -224,6 +224,13 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -224,6 +224,13 @@ class IterationTransform(Visitor.VisitorTransform):
bound2 = args[1].coerce_to_integer(self.current_scope) bound2 = args[1].coerce_to_integer(self.current_scope)
step = step.coerce_to_integer(self.current_scope) step = step.coerce_to_integer(self.current_scope)
if not isinstance(bound2, ExprNodes.ConstNode):
# stop bound must be immutable => keep it in a temp var
bound2_is_temp = True
bound2 = UtilNodes.LetRefNode(bound2)
else:
bound2_is_temp = False
for_node = Nodes.ForFromStatNode( for_node = Nodes.ForFromStatNode(
node.pos, node.pos,
target=node.target, target=node.target,
...@@ -232,6 +239,10 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -232,6 +239,10 @@ class IterationTransform(Visitor.VisitorTransform):
step=step, body=node.body, step=step, body=node.body,
else_clause=node.else_clause, else_clause=node.else_clause,
from_range=True) from_range=True)
if bound2_is_temp:
for_node = UtilNodes.LetNode(bound2, for_node)
return for_node return for_node
def _transform_dict_iteration(self, node, dict_obj, keys, values): def _transform_dict_iteration(self, node, dict_obj, keys, values):
...@@ -613,24 +624,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -613,24 +624,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
]) ])
def _handle_simple_function_dict(self, node, pos_args): def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict(some_dict) by PyDict_Copy(some_dict). """Replace dict(some_dict) by PyDict_Copy(some_dict) and
dict([ (a,b) for ... ]) by a literal { a:b for ... }.
""" """
if len(pos_args.args) != 1: if len(pos_args.args) != 1:
return node return node
dict_arg = pos_args.args[0] arg = pos_args.args[0]
if dict_arg.type is not Builtin.dict_type: if arg.type is Builtin.dict_type:
return node arg = ExprNodes.NoneCheckNode(
arg, "PyExc_TypeError", "'NoneType' is not iterable")
dict_arg = ExprNodes.NoneCheckNode(
dict_arg, "PyExc_TypeError", "'NoneType' is not iterable")
return ExprNodes.PythonCapiCallNode( return ExprNodes.PythonCapiCallNode(
node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
args = [dict_arg], args = [arg],
is_temp = node.is_temp is_temp = node.is_temp
) )
elif isinstance(arg, ExprNodes.ComprehensionNode) and \
arg.type is Builtin.list_type:
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[], is_temp=1)
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node,
is_temp=1)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
def _handle_simple_function_set(self, node, pos_args): def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...}. """Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
""" """
arg_count = len(pos_args.args) arg_count = len(pos_args.args)
if arg_count == 0: if arg_count == 0:
...@@ -881,12 +909,6 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -881,12 +909,6 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
self.PyUnicode_AsXyzString_func_type, self.PyUnicode_AsXyzString_func_type,
'encode', is_unbound_method, [string_node]) 'encode', is_unbound_method, [string_node])
return self._substitute_method_call(
node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method,
[string_node, encoding_node, null_node])
return self._substitute_method_call( return self._substitute_method_call(
node, "PyUnicode_AsEncodedString", node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type, self.PyUnicode_AsEncodedString_func_type,
......
...@@ -68,7 +68,11 @@ option_defaults = { ...@@ -68,7 +68,11 @@ option_defaults = {
'c99_complex' : False, # Don't use macro wrappers for complex arith, not sure what to name this... 'c99_complex' : False, # Don't use macro wrappers for complex arith, not sure what to name this...
'callspec' : "", 'callspec' : "",
'profile': False, 'profile': False,
'autotestdict': True 'autotestdict': True,
# test support
'test_assert_path_exists' : [],
'test_fail_if_path_exists' : [],
} }
# Override types possibilities above, if needed # Override types possibilities above, if needed
...@@ -80,7 +84,9 @@ for key, val in option_defaults.items(): ...@@ -80,7 +84,9 @@ for key, val in option_defaults.items():
option_scopes = { # defaults to available everywhere option_scopes = { # defaults to available everywhere
# 'module', 'function', 'class', 'with statement' # 'module', 'function', 'class', 'with statement'
'autotestdict' : ('module',) 'autotestdict' : ('module',),
'test_assert_path_exists' : ('function',),
'test_fail_if_path_exists' : ('function',),
} }
def parse_option_value(name, value): def parse_option_value(name, value):
......
...@@ -457,6 +457,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -457,6 +457,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
raise PostParseError(dec.function.pos, raise PostParseError(dec.function.pos,
'The %s option takes no prepositional arguments' % optname) 'The %s option takes no prepositional arguments' % optname)
return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs]) return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
elif optiontype is list:
if kwds and len(kwds) != 0:
raise PostParseError(dec.function.pos,
'The %s option takes no keyword arguments' % optname)
return optname, [ str(arg.value) for arg in args ]
else: else:
assert False assert False
...@@ -499,10 +504,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -499,10 +504,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
legal_scopes = Options.option_scopes.get(name, None) legal_scopes = Options.option_scopes.get(name, None)
if not self.check_directive_scope(node.pos, name, 'function'): if not self.check_directive_scope(node.pos, name, 'function'):
continue continue
if name in optdict and isinstance(optdict[name], dict): if name in optdict:
# only keywords can be merged, everything else old_value = optdict[name]
# overrides completely # keywords and arg lists can be merged, everything
optdict[name].update(value) # else overrides completely
if isinstance(old_value, dict):
old_value.update(value)
elif isinstance(old_value, list):
old_value.extend(value)
else:
optdict[name] = value
else: else:
optdict[name] = value optdict[name] = value
body = StatListNode(node.pos, stats=[node]) body = StatListNode(node.pos, stats=[node])
......
...@@ -2,6 +2,7 @@ import unittest ...@@ -2,6 +2,7 @@ import unittest
from Cython.Compiler.Visitor import PrintTree from Cython.Compiler.Visitor import PrintTree
from Cython.TestUtils import TransformTest from Cython.TestUtils import TransformTest
from Cython.Compiler.TreePath import find_first, find_all from Cython.Compiler.TreePath import find_first, find_all
from Cython.Compiler import Nodes, ExprNodes
class TestTreePath(TransformTest): class TestTreePath(TransformTest):
_tree = None _tree = None
...@@ -24,6 +25,12 @@ class TestTreePath(TransformTest): ...@@ -24,6 +25,12 @@ class TestTreePath(TransformTest):
self.assertEquals(1, len(find_all(t, "//ReturnStatNode"))) self.assertEquals(1, len(find_all(t, "//ReturnStatNode")))
self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode"))) self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode")))
def test_node_path_star(self):
t = self._build_tree()
self.assertEquals(10, len(find_all(t, "//*")))
self.assertEquals(8, len(find_all(t, "//DefNode//*")))
self.assertEquals(0, len(find_all(t, "//NameNode//*")))
def test_node_path_attribute(self): def test_node_path_attribute(self):
t = self._build_tree() t = self._build_tree()
self.assertEquals(2, len(find_all(t, "//NameNode/@name"))) self.assertEquals(2, len(find_all(t, "//NameNode/@name")))
...@@ -34,18 +41,49 @@ class TestTreePath(TransformTest): ...@@ -34,18 +41,49 @@ class TestTreePath(TransformTest):
self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode"))) self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode")))
self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode"))) self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode")))
def test_node_path_node_predicate(self):
t = self._build_tree()
self.assertEquals(0, len(find_all(t, "//DefNode[.//ForInStatNode]")))
self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode]")))
self.assertEquals(1, len(find_all(t, "//ReturnStatNode[./NameNode]")))
self.assertEquals(Nodes.ReturnStatNode,
type(find_first(t, "//ReturnStatNode[./NameNode]")))
def test_node_path_node_predicate_step(self):
t = self._build_tree()
self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode]")))
self.assertEquals(8, len(find_all(t, "//DefNode[.//NameNode]//*")))
self.assertEquals(1, len(find_all(t, "//DefNode[.//NameNode]//ReturnStatNode")))
self.assertEquals(Nodes.ReturnStatNode,
type(find_first(t, "//DefNode[.//NameNode]//ReturnStatNode")))
def test_node_path_attribute_exists(self): def test_node_path_attribute_exists(self):
t = self._build_tree() t = self._build_tree()
self.assertEquals(2, len(find_all(t, "//NameNode[@name]"))) self.assertEquals(2, len(find_all(t, "//NameNode[@name]")))
self.assertEquals(ExprNodes.NameNode,
type(find_first(t, "//NameNode[@name]")))
def test_node_path_attribute_exists_not(self): def test_node_path_attribute_exists_not(self):
t = self._build_tree() t = self._build_tree()
self.assertEquals(0, len(find_all(t, "//NameNode[not(@name)]"))) self.assertEquals(0, len(find_all(t, "//NameNode[not(@name)]")))
self.assertEquals(2, len(find_all(t, "//NameNode[not(@honking)]"))) self.assertEquals(2, len(find_all(t, "//NameNode[not(@honking)]")))
def test_node_path_and(self):
t = self._build_tree()
self.assertEquals(1, len(find_all(t, "//DefNode[.//ReturnStatNode and .//NameNode]")))
self.assertEquals(0, len(find_all(t, "//NameNode[@honking and @name]")))
self.assertEquals(0, len(find_all(t, "//NameNode[@name and @honking]")))
self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode[@name] and @name]")))
def test_node_path_attribute_string_predicate(self): def test_node_path_attribute_string_predicate(self):
t = self._build_tree() t = self._build_tree()
self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']"))) self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']")))
def test_node_path_recursive_predicate(self):
t = self._build_tree()
self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode[@name]]")))
self.assertEquals(1, len(find_all(t, "//DefNode[.//NameNode[@name = 'decorator']]")))
self.assertEquals(1, len(find_all(t, "//DefNode[.//ReturnStatNode[./NameNode[@name = 'fun']]/NameNode]")))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -54,7 +54,7 @@ def parse_func(next, token): ...@@ -54,7 +54,7 @@ def parse_func(next, token):
def handle_func_not(next, token): def handle_func_not(next, token):
""" """
func(...) not(...)
""" """
name, predicate = parse_func(next, token) name, predicate = parse_func(next, token)
...@@ -167,14 +167,20 @@ def handle_attribute(next, token): ...@@ -167,14 +167,20 @@ def handle_attribute(next, token):
def parse_path_value(next): def parse_path_value(next):
token = next() token = next()
value = token[0] value = token[0]
if value:
if value[:1] == "'" or value[:1] == '"': if value[:1] == "'" or value[:1] == '"':
value = value[1:-1] return value[1:-1]
else:
try: try:
value = int(value) return int(value)
except ValueError: except ValueError:
pass
else:
name = token[1].lower()
if name == 'true':
return True
elif name == 'false':
return False
raise ValueError("Invalid attribute predicate: '%s'" % value) raise ValueError("Invalid attribute predicate: '%s'" % value)
return value
def handle_predicate(next, token): def handle_predicate(next, token):
token = next() token = next()
...@@ -189,6 +195,9 @@ def handle_predicate(next, token): ...@@ -189,6 +195,9 @@ def handle_predicate(next, token):
if token[0] == "/": if token[0] == "/":
token = next() token = next()
if not token[0] and token[1] == 'and':
return logical_and(selector, handle_predicate(next, token))
def select(result): def select(result):
for node in result: for node in result:
subresult = iter((node,)) subresult = iter((node,))
...@@ -196,9 +205,23 @@ def handle_predicate(next, token): ...@@ -196,9 +205,23 @@ def handle_predicate(next, token):
subresult = select(subresult) subresult = select(subresult)
predicate_result = _get_first_or_none(subresult) predicate_result = _get_first_or_none(subresult)
if predicate_result is not None: if predicate_result is not None:
yield predicate_result yield node
return select return select
def logical_and(lhs_selects, rhs_select):
def select(result):
for node in result:
subresult = iter((node,))
for select in lhs_selects:
subresult = select(subresult)
predicate_result = _get_first_or_none(subresult)
subresult = iter((node,))
if predicate_result is not None:
for result_node in rhs_select(subresult):
yield node
return select
operations = { operations = {
"@": handle_attribute, "@": handle_attribute,
"": handle_name, "": handle_name,
......
...@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform): ...@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform):
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
return self.visit_scope(node, 'struct') return self.visit_scope(node, 'struct')
class RecursiveNodeReplacer(VisitorTransform):
"""
Recursively replace all occurrences of a node in a subtree by
another node.
"""
def __init__(self, orig_node, new_node):
super(RecursiveNodeReplacer, self).__init__()
self.orig_node, self.new_node = orig_node, new_node
def visit_Node(self, node):
self.visitchildren(node)
if node is self.orig_node:
return self.new_node
else:
return node
......
...@@ -4,7 +4,8 @@ import unittest ...@@ -4,7 +4,8 @@ import unittest
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
from Cython.Compiler import TreePath
class NodeTypeWriter(TreeVisitor): class NodeTypeWriter(TreeVisitor):
def __init__(self): def __init__(self):
...@@ -74,6 +75,10 @@ class CythonTest(unittest.TestCase): ...@@ -74,6 +75,10 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines), self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def assertNodeExists(self, path, result_tree):
self.assertNotEqual(TreePath.find_first(result_tree, path), None,
"Path '%s' not found in result tree" % path)
def fragment(self, code, pxds={}, pipeline=[]): def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors." "Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id() name = self.id()
...@@ -136,3 +141,26 @@ class TransformTest(CythonTest): ...@@ -136,3 +141,26 @@ class TransformTest(CythonTest):
tree = T(tree) tree = T(tree)
return tree return tree
class TreeAssertVisitor(VisitorTransform):
# actually, a TreeVisitor would be enough, but this needs to run
# as part of the compiler pipeline
def visit_CompilerDirectivesNode(self, node):
directives = node.directives
if 'test_assert_path_exists' in directives:
for path in directives['test_assert_path_exists']:
if TreePath.find_first(node, path) is None:
Errors.error(
node.pos,
"Expected path '%s' not found in result tree" % path)
if 'test_fail_if_path_exists' in directives:
for path in directives['test_fail_if_path_exists']:
if TreePath.find_first(node, path) is not None:
Errors.error(
node.pos,
"Unexpected path '%s' found in result tree" % path)
self.visitchildren(node)
return node
visit_Node = VisitorTransform.recurse_to_children
...@@ -279,7 +279,9 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -279,7 +279,9 @@ class CythonCompileTestCase(unittest.TestCase):
annotate = annotate, annotate = annotate,
use_listing_file = False, use_listing_file = False,
cplus = self.language == 'cpp', cplus = self.language == 'cpp',
generate_pxi = False) generate_pxi = False,
evaluate_tree_assertions = True,
)
cython_compile(source, options=options, cython_compile(source, options=options,
full_module_name=module) full_module_name=module)
......
cimport cython
@cython.test_assert_path_exists(
"//SingleAssignmentNode",
"//SingleAssignmentNode[./NameNode[@name = 'a']]",
"//SingleAssignmentNode[./NameNode[@name = 'a'] and @first = True]",
)
def test_cdef():
cdef int a = 1
@cython.test_assert_path_exists(
"//SingleAssignmentNode",
"//SingleAssignmentNode[./NameNode[@name = 'a']]",
# FIXME: currently not working
# "//SingleAssignmentNode[./NameNode[@name = 'a'] and @first = True]",
)
def test_py():
a = 1
@cython.test_assert_path_exists(
"//SingleAssignmentNode",
"//SingleAssignmentNode[./NameNode[@name = 'a']]",
# FIXME: currently not working
# "//SingleAssignmentNode[./NameNode[@name = 'a'] and @first = True]",
)
def test_cond():
if True:
a = 1
else:
a = 2
cimport cython
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//NameNode")
@cython.test_assert_path_exists("//ComprehensionNode",
"//ComprehensionNode//FuncDefNode")
def test():
object()
_ERRORS = u"""
8:0: Expected path '//ComprehensionNode' not found in result tree
8:0: Expected path '//ComprehensionNode//FuncDefNode' not found in result tree
8:0: Unexpected path '//NameNode' found in result tree
8:0: Unexpected path '//SimpleCallNode' found in result tree
"""
__doc__ = u""" __doc__ = u"""
>>> type(smoketest()) is dict >>> type(smoketest_dict()) is dict
True
>>> type(smoketest_list()) is dict
True True
>>> sorted(smoketest().items()) >>> sorted(smoketest_dict().items())
[(2, 0), (4, 4), (6, 8)]
>>> sorted(smoketest_list().items())
[(2, 0), (4, 4), (6, 8)] [(2, 0), (4, 4), (6, 8)]
>>> list(typed().items()) >>> list(typed().items())
[(A, 1), (A, 1), (A, 1)] [(A, 1), (A, 1), (A, 1)]
>>> sorted(iterdict().items()) >>> sorted(iterdict().items())
[(1, 'a'), (2, 'b'), (3, 'c')] [(1, 'a'), (2, 'b'), (3, 'c')]
""" """
def smoketest(): cimport cython
return {x+2:x*2 for x in range(5) if x % 2 == 0}
def smoketest_dict():
return { x+2:x*2
for x in range(5)
if x % 2 == 0 }
@cython.test_fail_if_path_exists(
"//ComprehensionNode//ComprehensionAppendNode",
"//SimpleCallNode//ComprehensionNode")
@cython.test_assert_path_exists(
"//ComprehensionNode",
"//ComprehensionNode//DictComprehensionAppendNode")
def smoketest_list():
return dict([ (x+2,x*2)
for x in range(5)
if x % 2 == 0 ])
cdef class A: cdef class A:
def __repr__(self): return u"A" def __repr__(self): return u"A"
......
...@@ -63,20 +63,26 @@ __doc__ = u""" ...@@ -63,20 +63,26 @@ __doc__ = u"""
""" """
cimport cython
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def go_py_enumerate(): def go_py_enumerate():
for i,k in enumerate(range(1,5)): for i,k in enumerate(range(1,5)):
print i, k print i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def go_c_enumerate(): def go_c_enumerate():
cdef int i,k cdef int i,k
for i,k in enumerate(range(1,5)): for i,k in enumerate(range(1,5)):
print i, k print i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def go_c_enumerate_step(): def go_c_enumerate_step():
cdef int i,k cdef int i,k
for i,k in enumerate(range(1,7,2)): for i,k in enumerate(range(1,7,2)):
print i, k print i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def py_enumerate_dict(dict d): def py_enumerate_dict(dict d):
cdef int i = 55 cdef int i = 55
k = 99 k = 99
...@@ -84,6 +90,7 @@ def py_enumerate_dict(dict d): ...@@ -84,6 +90,7 @@ def py_enumerate_dict(dict d):
print i, k print i, k
print u"::", i, k print u"::", i, k
@cython.test_fail_if_path_exists("//SimpleCallNode")
def py_enumerate_break(*t): def py_enumerate_break(*t):
i,k = 55,99 i,k = 55,99
for i,k in enumerate(t): for i,k in enumerate(t):
...@@ -91,6 +98,7 @@ def py_enumerate_break(*t): ...@@ -91,6 +98,7 @@ def py_enumerate_break(*t):
break break
print u"::", i, k print u"::", i, k
@cython.test_fail_if_path_exists("//SimpleCallNode")
def py_enumerate_return(*t): def py_enumerate_return(*t):
i,k = 55,99 i,k = 55,99
for i,k in enumerate(t): for i,k in enumerate(t):
...@@ -98,6 +106,7 @@ def py_enumerate_return(*t): ...@@ -98,6 +106,7 @@ def py_enumerate_return(*t):
return return
print u"::", i, k print u"::", i, k
@cython.test_fail_if_path_exists("//SimpleCallNode")
def py_enumerate_continue(*t): def py_enumerate_continue(*t):
i,k = 55,99 i,k = 55,99
for i,k in enumerate(t): for i,k in enumerate(t):
...@@ -105,20 +114,24 @@ def py_enumerate_continue(*t): ...@@ -105,20 +114,24 @@ def py_enumerate_continue(*t):
continue continue
print u"::", i, k print u"::", i, k
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def empty_c_enumerate(): def empty_c_enumerate():
cdef int i = 55, k = 99 cdef int i = 55, k = 99
for i,k in enumerate(range(0)): for i,k in enumerate(range(0)):
print i, k print i, k
return i, k return i, k
# not currently optimised
def single_target_enumerate(): def single_target_enumerate():
for t in enumerate(range(1,5)): for t in enumerate(range(1,5)):
print t[0], t[1] print t[0], t[1]
@cython.test_fail_if_path_exists("//SimpleCallNode//NameNode[@name = 'enumerate']")
def multi_enumerate(): def multi_enumerate():
for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)))): for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)))):
print a,b,c,d print a,b,c,d
@cython.test_fail_if_path_exists("//SimpleCallNode")
def multi_c_enumerate(): def multi_c_enumerate():
cdef int a,b,c,d cdef int a,b,c,d
for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)))): for a,(b,(c,d)) in enumerate(enumerate(enumerate(range(1,5)))):
......
__doc__ = u"""
>>> test_modify()
0 1 2 3 4
(4, 0)
>>> test_fix()
0 1 2 3 4
4
>>> test_break()
0 1 2
(2, 0)
>>> test_return()
0 1 2
(2, 0)
"""
cimport cython
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def test_modify():
cdef int i, n = 5
for i in range(n):
print i,
n = 0
print
return i,n
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def test_fix():
cdef int i
for i in range(5):
print i,
print
return i
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def test_break():
cdef int i, n = 5
for i in range(n):
print i,
n = 0
if i == 2:
break
print
return i,n
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def test_return():
cdef int i, n = 5
for i in range(n):
print i,
n = 0
if i == 2:
return i,n
print
return "FAILED!"
__doc__ = u""" __doc__ = u"""
>>> type(smoketest()) is not list >>> type(smoketest_set()) is not list
True True
>>> type(smoketest()) is _set >>> type(smoketest_set()) is _set
True
>>> type(smoketest_list()) is _set
True True
>>> sorted(smoketest()) >>> sorted(smoketest_set())
[0, 4, 8]
>>> sorted(smoketest_list())
[0, 4, 8] [0, 4, 8]
>>> list(typed()) >>> list(typed())
[A, A, A] [A, A, A]
>>> sorted(iterdict()) >>> sorted(iterdict())
[1, 2, 3] [1, 2, 3]
""" """
cimport cython
# Py2.3 doesn't have the set type, but Cython does :) # Py2.3 doesn't have the set type, but Cython does :)
_set = set _set = set
def smoketest(): def smoketest_set():
return {x*2 for x in range(5) if x % 2 == 0} return { x*2
for x in range(5)
if x % 2 == 0 }
@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
@cython.test_assert_path_exists("//ComprehensionNode",
"//ComprehensionNode//ComprehensionAppendNode")
def smoketest_list():
return set([ x*2
for x in range(5)
if x % 2 == 0 ])
cdef class A: cdef class A:
def __repr__(self): return u"A" def __repr__(self): return u"A"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment