Commit 5c618f29 authored by Stefan Behnel's avatar Stefan Behnel

reimplement min()/max() optimisation before type analysis

parent f629adda
...@@ -1200,6 +1200,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1200,6 +1200,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node.pos, loop = exec_code, result_node = result_ref, gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<')
def _handle_simple_function_max(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '>')
def _optimise_min_max(self, node, args, operator):
"""Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
"""
if len(args) <= 1:
# leave this to Python
return node
cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
last_result = args[0]
for arg_node in cascaded_nodes:
result_ref = UtilNodes.ResultRefNode(last_result)
last_result = ExprNodes.CondExprNode(
arg_node.pos,
true_val = arg_node,
false_val = result_ref,
test = ExprNodes.PrimaryCmpNode(
arg_node.pos,
operand1 = arg_node,
operator = operator,
operand2 = result_ref,
)
)
last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
for ref_node in cascaded_nodes[::-1]:
last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
return last_result
def _DISABLED_handle_simple_function_tuple(self, node, pos_args): def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
if len(pos_args) == 0: if len(pos_args) == 0:
return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
...@@ -1791,62 +1827,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1791,62 +1827,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
is_temp = False) is_temp = False)
return ExprNodes.CastNode(node, PyrexTypes.py_object_type) return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<')
def _handle_simple_function_max(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '>')
def _optimise_min_max(self, node, args, operator):
"""Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
"""
if len(args) <= 1:
# leave this to Python
return node
unpacked_args = []
for arg in args:
if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
arg = arg.arg
unpacked_args.append(arg)
arg_nodes = []
ref_nodes = []
spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type)
for arg in unpacked_args:
arg = arg.coerce_to(spanning_type, self.current_env())
if not isinstance(arg, ExprNodes.ConstNode):
arg = UtilNodes.LetRefNode(arg)
ref_nodes.append(arg)
arg_nodes.append(arg)
spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type)
last_result = arg_nodes[0]
for arg_node in arg_nodes[1:]:
last_result = last_result.coerce_to(arg_node.type, self.current_env())
is_py_compare = arg_node.type.is_pyobject
last_result = ExprNodes.CondExprNode(
arg_node.pos,
type = arg_node.type, # already coerced, so this is the target type
is_temp = True,
true_val = arg_node,
false_val = last_result,
test = ExprNodes.PrimaryCmpNode(
arg.pos,
operand1 = arg_node,
operator = operator,
operand2 = last_result,
is_pycmp = is_py_compare,
is_temp = is_py_compare,
type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type,
).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()),
)
for ref_node in ref_nodes[::-1]:
last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
return last_result.coerce_to(node.type, self.current_env())
### special methods ### special methods
Pyx_tp_new_func_type = PyrexTypes.CFuncType( Pyx_tp_new_func_type = PyrexTypes.CFuncType(
......
cimport cython
class loud_list(list): class loud_list(list):
def __len__(self): def __len__(self):
print "calling __len__" print "calling __len__"
...@@ -6,6 +8,11 @@ class loud_list(list): ...@@ -6,6 +8,11 @@ class loud_list(list):
# max() # max()
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_max2(): def test_max2():
""" """
>>> test_max2() >>> test_max2()
...@@ -33,6 +40,11 @@ def test_max2(): ...@@ -33,6 +40,11 @@ def test_max2():
print max(my_int, len(my_list)) print max(my_int, len(my_list))
print max(len(my_list), my_int) print max(len(my_list), my_int)
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_max3(): def test_max3():
""" """
>>> test_max3() >>> test_max3()
...@@ -49,6 +61,11 @@ def test_max3(): ...@@ -49,6 +61,11 @@ def test_max3():
print max(my_int, my_pyint, len(my_list)) print max(my_int, my_pyint, len(my_list))
print max(my_pyint, my_list.__len__(), len(my_list)) print max(my_pyint, my_list.__len__(), len(my_list))
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_maxN(): def test_maxN():
""" """
>>> test_maxN() >>> test_maxN()
...@@ -71,6 +88,11 @@ def test_maxN(): ...@@ -71,6 +88,11 @@ def test_maxN():
# min() # min()
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_min2(): def test_min2():
""" """
>>> test_min2() >>> test_min2()
...@@ -99,6 +121,11 @@ def test_min2(): ...@@ -99,6 +121,11 @@ def test_min2():
print min(len(my_list), my_int) print min(len(my_list), my_int)
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_min3(): def test_min3():
""" """
>>> test_min3() >>> test_min3()
...@@ -116,6 +143,11 @@ def test_min3(): ...@@ -116,6 +143,11 @@ def test_min3():
print min(my_pyint, my_list.__len__(), len(my_list)) print min(my_pyint, my_list.__len__(), len(my_list))
@cython.test_assert_path_exists(
'//PrintStatNode//CondExprNode')
@cython.test_fail_if_path_exists(
'//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
'//PrintStatNode//SimpleCallNode//ConstNode')
def test_minN(): def test_minN():
""" """
>>> test_minN() >>> test_minN()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment