Commit f629adda authored by Stefan Behnel's avatar Stefan Behnel

rewrite of min()/max() optimisation, now correctly handling temps and types

parent de519dc7
......@@ -1803,50 +1803,49 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
if len(args) <= 1:
# leave this to Python
return node
unpacked_args = []
for arg in args:
if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
arg = arg.arg
unpacked_args.append(arg)
spanning_type = reduce(PyrexTypes.spanning_type,
[ arg.type for arg in unpacked_args ])
is_pycompare = spanning_type.is_pyobject
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=spanning_type)
stats = [
Nodes.SingleAssignmentNode(
node.pos,
lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
rhs = unpacked_args[0].coerce_to(spanning_type, self.current_env()),
first = True)
]
for arg in unpacked_args[1:]:
stats.append(Nodes.IfStatNode(
arg.pos,
else_clause = None,
if_clauses = [ Nodes.IfClauseNode(
arg_nodes = []
ref_nodes = []
spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type)
for arg in unpacked_args:
arg = arg.coerce_to(spanning_type, self.current_env())
if not isinstance(arg, ExprNodes.ConstNode):
arg = UtilNodes.LetRefNode(arg)
ref_nodes.append(arg)
arg_nodes.append(arg)
spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type)
last_result = arg_nodes[0]
for arg_node in arg_nodes[1:]:
last_result = last_result.coerce_to(arg_node.type, self.current_env())
is_py_compare = arg_node.type.is_pyobject
last_result = ExprNodes.CondExprNode(
arg_node.pos,
type = arg_node.type, # already coerced, so this is the target type
is_temp = True,
true_val = arg_node,
false_val = last_result,
test = ExprNodes.PrimaryCmpNode(
arg.pos,
condition = ExprNodes.PrimaryCmpNode(
arg.pos,
operand1 = arg.coerce_to(spanning_type, self.current_env()),
operator = operator,
operand2 = result_ref,
is_pycmp = is_pycompare,
is_temp = is_pycompare,
type = is_pycompare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type
).coerce_to_boolean(self.current_env()),
body = Nodes.SingleAssignmentNode(
arg.pos,
lhs = result_ref,
rhs = arg)
)]
))
operand1 = arg_node,
operator = operator,
operand2 = last_result,
is_pycmp = is_py_compare,
is_temp = is_py_compare,
type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type,
).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()),
)
for ref_node in ref_nodes[::-1]:
last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
return UtilNodes.TempResultFromStatNode(
result_ref, Nodes.StatListNode(node.pos, stats = stats)
).coerce_to(node.type, self.current_env())
return last_result.coerce_to(node.type, self.current_env())
### special methods
......
class loud_list(list):
def __len__(self):
print "calling __len__"
return super(loud_list, self).__len__()
# max()
def test_max2():
"""
>>> test_max2()
2
2
2
2
2
calling __len__
3
calling __len__
3
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print max(1, 2)
print max(2, my_int)
print max(my_int, 2)
print max(my_int, my_pyint)
print max(my_pyint, my_int)
print max(my_int, len(my_list))
print max(len(my_list), my_int)
def test_max3():
"""
>>> test_max3()
calling __len__
3
calling __len__
calling __len__
3
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print max(my_int, my_pyint, len(my_list))
print max(my_pyint, my_list.__len__(), len(my_list))
def test_maxN():
"""
>>> test_maxN()
calling __len__
3
calling __len__
3
calling __len__
3
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print max(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
print max(my_int, my_int, 0, my_pyint, my_int, len(my_list))
print max(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
# min()
def test_min2():
"""
>>> test_min2()
1
1
1
1
1
calling __len__
1
calling __len__
1
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print min(1, 2)
print min(2, my_int)
print min(my_int, 2)
print min(my_int, my_pyint)
print min(my_pyint, my_int)
print min(my_int, len(my_list))
print min(len(my_list), my_int)
def test_min3():
"""
>>> test_min3()
calling __len__
1
calling __len__
calling __len__
2
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print min(my_int, my_pyint, len(my_list))
print min(my_pyint, my_list.__len__(), len(my_list))
def test_minN():
"""
>>> test_minN()
calling __len__
0
calling __len__
0
calling __len__
0
"""
cdef int my_int = 1
cdef object my_pyint = 2
cdef object my_list = loud_list([1,2,3])
print min(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
print min(my_int, my_int, 0, my_pyint, my_int, len(my_list))
print min(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment