Commit 6d2ccb85 authored by Stefan Behnel's avatar Stefan Behnel

fix constant folding in PrimaryCmpNode/CascadedCmpNode

parent 3171ec18
...@@ -5475,9 +5475,11 @@ class CmpNode(object): ...@@ -5475,9 +5475,11 @@ class CmpNode(object):
func = compile_time_binary_operators[self.operator] func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result operand2_result = self.operand2.constant_result
result = func(operand1_result, operand2_result) result = func(operand1_result, operand2_result)
if result and self.cascade: if self.cascade:
result = result and \ self.cascade.calculate_cascaded_constant_result(operand2_result)
self.cascade.cascaded_compile_time_value(operand2_result) if self.cascade.constant_result:
self.constant_result = result and self.cascade.constant_result
else:
self.constant_result = result self.constant_result = result
def cascaded_compile_time_value(self, operand1, denv): def cascaded_compile_time_value(self, operand1, denv):
...@@ -5492,7 +5494,7 @@ class CmpNode(object): ...@@ -5492,7 +5494,7 @@ class CmpNode(object):
cascade = self.cascade cascade = self.cascade
if cascade: if cascade:
# FIXME: I bet this must call cascaded_compile_time_value() # FIXME: I bet this must call cascaded_compile_time_value()
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.cascaded_compile_time_value(operand2, denv)
return result return result
def is_cpp_comparison(self): def is_cpp_comparison(self):
...@@ -5787,8 +5789,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5787,8 +5789,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return () return ()
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result( self.calculate_cascaded_constant_result(self.operand1.constant_result)
self.operand1.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
operand1 = self.operand1.compile_time_value(denv) operand1 = self.operand1.compile_time_value(denv)
...@@ -5966,6 +5967,10 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -5966,6 +5967,10 @@ class CascadedCmpNode(Node, CmpNode):
def type_dependencies(self, env): def type_dependencies(self, env):
return () return ()
def has_constant_result(self):
return self.constant_result is not constant_value_not_set and \
self.constant_result is not not_a_constant
def analyse_types(self, env): def analyse_types(self, env):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
......
import sys import sys
IS_PY3 = sys.version_info[0] >= 3 IS_PY3 = sys.version_info[0] >= 3
cimport cython
DEF INT_VAL = 1
def _func(a,b,c): def _func(a,b,c):
return a+b+c return a+b+c
...@@ -76,3 +80,74 @@ def lists(): ...@@ -76,3 +80,74 @@ def lists():
True True
""" """
return [1,2,3] + [4,5,6] return [1,2,3] + [4,5,6]
def int_bool_result():
"""
>>> int_bool_result()
True
"""
if 5:
return True
else:
return False
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def if_compare_true():
"""
>>> if_compare_true()
True
"""
if 0 == 0:
return True
else:
return False
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def if_compare_false():
"""
>>> if_compare_false()
False
"""
if 0 == 1 or 1 == 0:
return True
else:
return False
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def if_compare_cascaded():
"""
>>> if_compare_cascaded()
True
"""
if 0 < 1 < 2 < 3:
return True
else:
return False
def list_bool_result():
"""
>>> list_bool_result()
True
"""
if [1,2,3]:
return True
else:
return False
def compile_time_DEF():
"""
>>> compile_time_DEF()
(1, False, True, True, False)
"""
return INT_VAL, INT_VAL == 0, INT_VAL != 0, INT_VAL == 1, INT_VAL != 1
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def compile_time_DEF_if():
"""
>>> compile_time_DEF_if()
True
"""
if INT_VAL != 0:
return True
else:
return False
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment