Commit 1cae6142 authored by Stefan Behnel's avatar Stefan Behnel

rewrite constant folding for PrimaryCmpNode to properly support (and fix) cascaded comparisons

parent 9d9679d9
...@@ -9648,13 +9648,21 @@ class CmpNode(object): ...@@ -9648,13 +9648,21 @@ class CmpNode(object):
type(operand1_result) != type(operand2_result)): type(operand1_result) != type(operand2_result)):
# string comparison of different types isn't portable # string comparison of different types isn't portable
return return
result = func(operand1_result, operand2_result)
if self.cascade: if self.operator in ('in', 'not_in'):
self.cascade.calculate_cascaded_constant_result(operand2_result) if isinstance(self.operand2, (ListNode, TupleNode, SetNode)):
if self.cascade.has_constant_result(): if not self.operand2.args:
self.constant_result = result and self.cascade.constant_result self.constant_result = self.operator == 'not_in'
else: return
self.constant_result = result elif isinstance(self.operand2, ListNode) and not self.cascade:
# tuples are more efficient to store than lists
self.operand2 = self.operand2.as_tuple()
elif isinstance(self.operand2, DictNode):
if not self.operand2.key_value_pairs:
self.constant_result = self.operator == 'not_in'
return
self.constant_result = func(operand1_result, operand2_result)
def cascaded_compile_time_value(self, operand1, denv): def cascaded_compile_time_value(self, operand1, denv):
func = get_compile_time_binop(self) func = get_compile_time_binop(self)
...@@ -9966,6 +9974,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -9966,6 +9974,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return () return ()
def calculate_constant_result(self): def calculate_constant_result(self):
assert not self.cascade
self.calculate_cascaded_constant_result(self.operand1.constant_result) self.calculate_cascaded_constant_result(self.operand1.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
......
...@@ -1054,9 +1054,8 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): ...@@ -1054,9 +1054,8 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
args = node.operand2.args args = node.operand2.args
if len(args) == 0: if len(args) == 0:
constant_result = node.operator == 'not_in' # note: lhs may have side effects
return ExprNodes.BoolNode(pos = node.pos, value = constant_result, return node
constant_result = constant_result)
lhs = UtilNodes.ResultRefNode(node.operand1) lhs = UtilNodes.ResultRefNode(node.operand1)
...@@ -3285,19 +3284,89 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3285,19 +3284,89 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return sequence_node return sequence_node
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
if not node.cascade:
self._calculate_const(node) self._calculate_const(node)
if node.has_constant_result(): if node.has_constant_result():
return self._bool_node(node, node.constant_result) return self._bool_node(node, node.constant_result)
if node.operator in ('in', 'not_in') and not node.cascade: return node
if isinstance(node.operand2, (ExprNodes.ListNode, ExprNodes.TupleNode,
ExprNodes.SetNode)): # calculate constant partial results in the comparison cascade
if not node.operand2.args: left_node = node.operand1
return self._bool_node(node, node.operator == 'not_in') self._calculate_const(left_node)
if isinstance(node.operand2, ExprNodes.ListNode): cmp_node = node
node.operand2 = node.operand2.as_tuple() while cmp_node is not None:
elif isinstance(node.operand2, ExprNodes.DictNode): right_node = cmp_node.operand2
if not node.operand2.key_value_pairs: self._calculate_const(right_node)
return self._bool_node(node, node.operator == 'not_in') cmp_node.constant_result = not_a_constant
if left_node.has_constant_result() and right_node.has_constant_result():
try:
cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
pass # ignore all 'normal' errors here => no constant result
left_node = right_node
cmp_node = cmp_node.cascade
# collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
cascades = [[node.operand1]]
def split_cascades(cmp_node):
if cmp_node.has_constant_result():
if not cmp_node.constant_result:
# False => short-circuit
cascades.append([
self._bool_node(cmp_node, True),
ExprNodes.CascadedCmpNode(
cmp_node.pos,
operator='==',
operand2=self._bool_node(cmp_node, False),
constant_result=False)
])
return
else:
# True => discard and start new cascade
cascades.append([cmp_node.operand2])
else:
# not constant => append to current cascade
cascades[-1].append(cmp_node)
if cmp_node.cascade:
split_cascades(cmp_node.cascade)
split_cascades(node)
cmp_nodes = []
for cascade in cascades:
if len(cascade) < 2:
continue
cmp_node = cascade[1]
pcmp_node = ExprNodes.PrimaryCmpNode(
cmp_node.pos,
operand1=cascade[0],
operator=cmp_node.operator,
operand2=cmp_node.operand2,
constant_result=not_a_constant)
cmp_nodes.append(pcmp_node)
last_cmp_node = pcmp_node
for cmp_node in cascade[2:]:
last_cmp_node.cascade = cmp_node
last_cmp_node = cmp_node
last_cmp_node.cascade = None
if not cmp_nodes:
# only constants, but no False result
return self._bool_node(node, True)
node = cmp_nodes[0]
if len(cmp_nodes) == 1:
if node.has_constant_result():
return self._bool_node(node, node.constant_result)
else:
for cmp_node in cmp_nodes[1:]:
node = ExprNodes.BoolBinopNode(
node.pos,
operand1=node,
operator='and',
operand2=cmp_node,
constant_result=not_a_constant)
return node return node
def visit_CondExprNode(self, node): def visit_CondExprNode(self, node):
......
...@@ -339,8 +339,8 @@ def s(a): ...@@ -339,8 +339,8 @@ def s(a):
cdef int result = a in [1,2,3,4] in [[1,2,3],[2,3,4],[1,2,3,4]] cdef int result = a in [1,2,3,4] in [[1,2,3],[2,3,4],[1,2,3,4]]
return result return result
@cython.test_assert_path_exists("//ReturnStatNode//BoolNode") #@cython.test_assert_path_exists("//ReturnStatNode//BoolNode")
@cython.test_fail_if_path_exists("//SwitchStatNode") #@cython.test_fail_if_path_exists("//SwitchStatNode")
def constant_empty_sequence(a): def constant_empty_sequence(a):
""" """
>>> constant_empty_sequence(1) >>> constant_empty_sequence(1)
...@@ -350,6 +350,22 @@ def constant_empty_sequence(a): ...@@ -350,6 +350,22 @@ def constant_empty_sequence(a):
""" """
return a in () return a in ()
@cython.test_fail_if_path_exists("//ReturnStatNode//BoolNode")
@cython.test_assert_path_exists("//PrimaryCmpNode")
def constant_empty_sequence_side_effect(a):
"""
>>> l =[]
>>> def a():
... l.append(1)
... return 1
>>> constant_empty_sequence_side_effect(a)
False
>>> l
[1]
"""
return a() in ()
def test_error_non_iterable(x): def test_error_non_iterable(x):
""" """
>>> test_error_non_iterable(1) # doctest: +ELLIPSIS >>> test_error_non_iterable(1) # doctest: +ELLIPSIS
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment