Commit cb01a073 authored by Stefan Behnel's avatar Stefan Behnel

code cleanup in ConstantFolding transform to make boolean handling less error prone

parent 60f24e7f
...@@ -2970,12 +2970,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2970,12 +2970,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
self._calculate_const(node) self._calculate_const(node)
return node return node
def visit_UnaryMinusNode(self, node): def visit_UnopNode(self, node):
self._calculate_const(node) self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant: if node.constant_result is ExprNodes.not_a_constant:
return node return node
if not node.operand.is_literal: if not node.operand.is_literal:
return node return node
if isinstance(node.operand, ExprNodes.BoolNode):
return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
type = PyrexTypes.c_int_type,
constant_result = node.constant_result)
if node.operator == '+':
return self._handle_UnaryPlusNode(node)
elif node.operator == '-':
return self._handle_UnaryMinusNode(node)
return node
def _handle_UnaryMinusNode(self, node):
if isinstance(node.operand, ExprNodes.LongNode): if isinstance(node.operand, ExprNodes.LongNode):
return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value, return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
constant_result = node.constant_result) constant_result = node.constant_result)
...@@ -2983,11 +2994,6 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2983,11 +2994,6 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
# this is a safe operation # this is a safe operation
return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value, return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
constant_result = node.constant_result) constant_result = node.constant_result)
if isinstance(node.operand, ExprNodes.BoolNode):
# not important at all, but simplifies the code below
return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
type = PyrexTypes.c_int_type,
constant_result = node.constant_result)
node_type = node.operand.type node_type = node.operand.type
if node_type.is_int and node_type.signed or \ if node_type.is_int and node_type.signed or \
isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject: isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
...@@ -2997,10 +3003,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2997,10 +3003,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
constant_result = node.constant_result) constant_result = node.constant_result)
return node return node
def visit_UnaryPlusNode(self, node): def _handle_UnaryPlusNode(self, node):
self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant:
return node
if node.constant_result == node.operand.constant_result: if node.constant_result == node.operand.constant_result:
return node.operand return node.operand
return node return node
...@@ -3026,12 +3029,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3026,12 +3029,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
return node return node
if isinstance(node.constant_result, float): if isinstance(node.constant_result, float):
return node return node
if not node.operand1.is_literal or not node.operand2.is_literal: operand1, operand2 = node.operand1, node.operand2
if not operand1.is_literal or not operand2.is_literal:
return node return node
# now inject a new constant node with the calculated value # now inject a new constant node with the calculated value
try: try:
type1, type2 = node.operand1.type, node.operand2.type type1, type2 = operand1.type, operand2.type
if type1 is None or type2 is None: if type1 is None or type2 is None:
return node return node
except AttributeError: except AttributeError:
...@@ -3041,14 +3045,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3041,14 +3045,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
widest_type = PyrexTypes.widest_numeric_type(type1, type2) widest_type = PyrexTypes.widest_numeric_type(type1, type2)
else: else:
widest_type = PyrexTypes.py_object_type widest_type = PyrexTypes.py_object_type
target_class = self._widest_node_class(node.operand1, node.operand2) target_class = self._widest_node_class(operand1, operand2)
if target_class is None: if target_class is None:
return node return node
elif target_class is ExprNodes.IntNode: elif target_class is ExprNodes.IntNode:
unsigned = getattr(node.operand1, 'unsigned', '') and \ unsigned = getattr(operand1, 'unsigned', '') and \
getattr(node.operand2, 'unsigned', '') getattr(operand2, 'unsigned', '')
longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')), longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
len(getattr(node.operand2, 'longness', '')))] len(getattr(operand2, 'longness', '')))]
new_node = ExprNodes.IntNode(pos=node.pos, new_node = ExprNodes.IntNode(pos=node.pos,
unsigned = unsigned, longness = longness, unsigned = unsigned, longness = longness,
value = str(node.constant_result), value = str(node.constant_result),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment