Commit b8782c99 authored by Stefan Behnel's avatar Stefan Behnel

initial constant folding transform: calculate constant values in node.constant_result

parent fb4ea137
......@@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \
debug_coercion
try:
set
except NameError:
from sets import Set as set
not_a_constant = object()
constant_value_not_set = object()
class ExprNode(Node):
# subexprs [string] Class var holding names of subexpr node attrs
......@@ -172,6 +179,8 @@ class ExprNode(Node):
is_temp = 0
is_target = 0
constant_result = constant_value_not_set
def get_child_attrs(self):
return self.subexprs
child_attrs = property(fget=get_child_attrs)
......@@ -225,6 +234,16 @@ class ExprNode(Node):
# C type of the result_code expression).
return self.result_ctype or self.type
def calculate_constant_result(self):
# Calculate the constant result of this expression and store
# it in ``self.constant_result``. Does nothing by default,
# thus leaving ``self.constant_result`` unknown.
#
# This must only be called when it is assured that all
# sub-expressions have a valid constant_result value. The
# ConstantFolding transform will do this.
pass
def compile_time_value(self, denv):
# Return value of compile-time expression, or report error.
error(self.pos, "Invalid compile-time expression")
......@@ -737,6 +756,8 @@ class NoneNode(PyConstNode):
value = "Py_None"
constant_result = None
def compile_time_value(self, denv):
return None
......@@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode):
value = "Py_Ellipsis"
constant_result = Ellipsis
def compile_time_value(self, denv):
return Ellipsis
......@@ -776,6 +799,9 @@ class BoolNode(ConstNode):
type = PyrexTypes.c_bint_type
# The constant value True or False
def calculate_constant_result(self):
self.constant_result = self.value
def compile_time_value(self, denv):
return self.value
......@@ -785,11 +811,15 @@ class BoolNode(ConstNode):
class NullNode(ConstNode):
type = PyrexTypes.c_null_ptr_type
value = "NULL"
constant_result = 0
class CharNode(ConstNode):
type = PyrexTypes.c_char_type
def calculate_constant_result(self):
self.constant_result = ord(self.value)
def compile_time_value(self, denv):
return ord(self.value)
......@@ -830,6 +860,9 @@ class IntNode(ConstNode):
else:
return str(self.value) + self.unsigned + self.longness
def calculate_constant_result(self):
self.constant_result = int(self.value, 0)
def compile_time_value(self, denv):
return int(self.value, 0)
......@@ -954,6 +987,9 @@ class LongNode(AtomicNewTempExprNode):
#
# value string
def calculate_constant_result(self):
self.constant_result = long(self.value)
def compile_time_value(self, denv):
return long(self.value)
......@@ -979,6 +1015,9 @@ class ImagNode(AtomicNewTempExprNode):
#
# value float imaginary part
def calculate_constant_result(self):
self.constant_result = complex(0.0, self.value)
def compile_time_value(self, denv):
return complex(0.0, self.value)
......@@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode):
gil_message = "Backquote expression"
def calculate_constant_result(self):
self.constant_result = repr(self.arg.constant_result)
def generate_result_code(self, code):
code.putln(
"%s = PyObject_Repr(%s); %s" % (
......@@ -1583,6 +1625,10 @@ class IndexNode(ExprNode):
ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index
def calculate_constant_result(self):
self.constant_result = \
self.base.constant_result[self.index.constant_result]
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
index = self.index.compile_time_value(denv)
......@@ -1882,6 +1928,10 @@ class SliceIndexNode(ExprNode):
subexprs = ['base', 'start', 'stop']
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : self.stop.constant_result]
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
if self.start is None:
......@@ -2056,6 +2106,12 @@ class SliceNode(ExprNode):
# stop ExprNode
# step ExprNode
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : \
self.stop.constant_result : \
self.step.constant_result]
def compile_time_value(self, denv):
start = self.start.compile_time_value(denv)
if self.stop is None:
......@@ -2453,6 +2509,9 @@ class AsTupleNode(ExprNode):
subexprs = ['arg']
def calculate_constant_result(self):
self.constant_result = tuple(self.base.constant_result)
def compile_time_value(self, denv):
arg = self.arg.compile_time_value(denv)
try:
......@@ -2518,6 +2577,12 @@ class AttributeNode(ExprNode):
return self
return ExprNode.coerce_to(self, dst_type, env)
def calculate_constant_result(self):
attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"):
return
self.constant_result = getattr(self.obj.constant_result, attr)
def compile_time_value(self, denv):
attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"):
......@@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode):
else:
return Naming.empty_tuple
def calculate_constant_result(self):
self.constant_result = tuple([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv):
values = self.compile_time_value_list(denv)
try:
......@@ -3058,6 +3127,10 @@ class ListNode(SequenceNode):
else:
SequenceNode.release_temp(self, env)
def calculate_constant_result(self):
self.constant_result = [
arg.constant_result for arg in self.args]
def compile_time_value(self, denv):
return self.compile_time_value_list(denv)
......@@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode):
self.gil_check(env)
self.is_temp = 1
def calculate_constant_result(self):
self.constant_result = set([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv):
values = [arg.compile_time_value(denv) for arg in self.args]
try:
set
except NameError:
from sets import Set as set
try:
return set(values)
except Exception, e:
......@@ -3265,6 +3338,10 @@ class DictNode(ExprNode):
subexprs = ['key_value_pairs']
def calculate_constant_result(self):
self.constant_result = dict([
item.constant_result for item in self.key_value_pairs])
def compile_time_value(self, denv):
pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv))
for item in self.key_value_pairs]
......@@ -3367,6 +3444,10 @@ class DictItemNode(ExprNode):
# value ExprNode
subexprs = ['key', 'value']
def calculate_constant_result(self):
self.constant_result = (
self.key.constant_result, self.value.constant_result)
def analyse_types(self, env):
self.key.analyse_types(env)
self.value.analyse_types(env)
......@@ -3508,6 +3589,10 @@ class UnopNode(ExprNode):
subexprs = ['operand']
def calculate_constant_result(self):
func = compile_time_unary_operators[self.operator]
self.constant_result = func(self.operand.constant_result)
def compile_time_value(self, denv):
func = compile_time_unary_operators.get(self.operator)
if not func:
......@@ -3567,6 +3652,9 @@ class NotNode(ExprNode):
#
# operand ExprNode
def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result
def compile_time_value(self, denv):
operand = self.operand.compile_time_value(denv)
try:
......@@ -3898,6 +3986,12 @@ class BinopNode(NewTempExprNode):
subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
func = compile_time_binary_operators[self.operator]
self.constant_result = func(
self.operand1.constant_result,
self.operand2.constant_result)
def compile_time_value(self, denv):
func = get_compile_time_binop(self)
operand1 = self.operand1.compile_time_value(denv)
......@@ -4138,6 +4232,16 @@ class BoolBinopNode(NewTempExprNode):
subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
if self.operator == 'and':
self.constant_result = \
self.operand1.constant_result and \
self.operand2.constant_result
else:
self.constant_result = \
self.operand1.constant_result or \
self.operand2.constant_result
def compile_time_value(self, denv):
if self.operator == 'and':
return self.operand1.compile_time_value(denv) \
......@@ -4262,6 +4366,12 @@ class CondExprNode(ExprNode):
subexprs = ['test', 'true_val', 'false_val']
def calculate_constant_result(self):
if self.test.constant_result:
self.constant_result = self.true_val.constant_result
else:
self.constant_result = self.false_val.constant_result
def analyse_types(self, env):
self.test.analyse_types(env)
self.test = self.test.coerce_to_boolean(env)
......@@ -4351,6 +4461,15 @@ class CmpNode:
# Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes.
def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result
result = func(operand1_result, operand2_result)
if result and self.cascade:
result = result and \
self.cascade.cascaded_compile_time_value(operand2_result)
self.constant_result = result
def cascaded_compile_time_value(self, operand1, denv):
func = get_compile_time_binop(self)
operand2 = self.operand2.compile_time_value(denv)
......@@ -4362,6 +4481,7 @@ class CmpNode:
if result:
cascade = self.cascade
if cascade:
# FIXME: I bet this must call cascaded_compile_time_value()
result = result and cascade.compile_time_value(operand2, denv)
return result
......@@ -4469,6 +4589,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
cascade = None
def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result)
def compile_time_value(self, denv):
operand1 = self.operand1.compile_time_value(denv)
return self.cascaded_compile_time_value(operand1, denv)
......@@ -4598,6 +4722,7 @@ class CascadedCmpNode(Node, CmpNode):
child_attrs = ['operand2', 'cascade']
cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this?
def analyse_types(self, env, operand1):
self.operand2.analyse_types(env)
......
......@@ -83,7 +83,7 @@ class Context:
from ParseTreeTransforms import AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase
from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations
......@@ -123,6 +123,7 @@ class Context:
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
ConstantFolding(),
FlattenBuiltinTypeCreation(),
DictIterTransform(),
SwitchTransform(),
......
......@@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
"""Calculate the result of constant expressions to store it in
``expr_node.constant_result``, and replace trivial cases by their
constant result.
"""
def _calculate_const(self, node):
if node.constant_result is not ExprNodes.constant_value_not_set:
return
# make sure we always set the value
not_a_constant = ExprNodes.not_a_constant
node.constant_result = not_a_constant
# check if all children are constant
children = self.visitchildren(node)
for child_result in children.itervalues():
if type(child_result) is list:
for child in child_result:
if child.constant_result is not_a_constant:
return
elif child_result.constant_result is not_a_constant:
return
# now try to calculate the real constant value
try:
node.calculate_constant_result()
# if node.constant_result is not ExprNodes.not_a_constant:
# print node.__class__.__name__, node.constant_result
except (ValueError, TypeError, IndexError, AttributeError):
# ignore all 'normal' errors here => no constant result
pass
except Exception:
# this looks like a real error
import traceback, sys
traceback.print_exc(file=sys.stdout)
def visit_ExprNode(self, node):
self._calculate_const(node)
return node
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
def visit_Node(self, node):
self.visitchildren(node)
return node
class FinalOptimizePhase(Visitor.CythonTransform):
"""
This visitor handles several commuting optimizations, and is run
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment