Commit b8782c99 authored by Stefan Behnel's avatar Stefan Behnel

initial constant folding transform: calculate constant values in node.constant_result

parent fb4ea137
......@@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \
debug_coercion
try:
set
except NameError:
from sets import Set as set
not_a_constant = object()
constant_value_not_set = object()
class ExprNode(Node):
# subexprs [string] Class var holding names of subexpr node attrs
......@@ -172,6 +179,8 @@ class ExprNode(Node):
is_temp = 0
is_target = 0
constant_result = constant_value_not_set
def get_child_attrs(self):
return self.subexprs
child_attrs = property(fget=get_child_attrs)
......@@ -224,7 +233,17 @@ class ExprNode(Node):
# Return the native C type of the result (i.e. the
# C type of the result_code expression).
return self.result_ctype or self.type
def calculate_constant_result(self):
# Calculate the constant result of this expression and store
# it in ``self.constant_result``. Does nothing by default,
# thus leaving ``self.constant_result`` unknown.
#
# This must only be called when it is assured that all
# sub-expressions have a valid constant_result value. The
# ConstantFolding transform will do this.
pass
def compile_time_value(self, denv):
# Return value of compile-time expression, or report error.
error(self.pos, "Invalid compile-time expression")
......@@ -736,7 +755,9 @@ class NoneNode(PyConstNode):
# The constant value None
value = "Py_None"
constant_result = None
def compile_time_value(self, denv):
return None
......@@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode):
value = "Py_Ellipsis"
constant_result = Ellipsis
def compile_time_value(self, denv):
return Ellipsis
......@@ -775,7 +798,10 @@ class ConstNode(AtomicNewTempExprNode):
class BoolNode(ConstNode):
type = PyrexTypes.c_bint_type
# The constant value True or False
def calculate_constant_result(self):
self.constant_result = self.value
def compile_time_value(self, denv):
return self.value
......@@ -785,10 +811,14 @@ class BoolNode(ConstNode):
class NullNode(ConstNode):
type = PyrexTypes.c_null_ptr_type
value = "NULL"
constant_result = 0
class CharNode(ConstNode):
type = PyrexTypes.c_char_type
def calculate_constant_result(self):
self.constant_result = ord(self.value)
def compile_time_value(self, denv):
return ord(self.value)
......@@ -830,6 +860,9 @@ class IntNode(ConstNode):
else:
return str(self.value) + self.unsigned + self.longness
def calculate_constant_result(self):
self.constant_result = int(self.value, 0)
def compile_time_value(self, denv):
return int(self.value, 0)
......@@ -953,6 +986,9 @@ class LongNode(AtomicNewTempExprNode):
# Python long integer literal
#
# value string
def calculate_constant_result(self):
self.constant_result = long(self.value)
def compile_time_value(self, denv):
return long(self.value)
......@@ -978,6 +1014,9 @@ class ImagNode(AtomicNewTempExprNode):
# Imaginary number literal
#
# value float imaginary part
def calculate_constant_result(self):
self.constant_result = complex(0.0, self.value)
def compile_time_value(self, denv):
return complex(0.0, self.value)
......@@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode):
gil_message = "Backquote expression"
def calculate_constant_result(self):
self.constant_result = repr(self.arg.constant_result)
def generate_result_code(self, code):
code.putln(
"%s = PyObject_Repr(%s); %s" % (
......@@ -1582,7 +1624,11 @@ class IndexNode(ExprNode):
def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index
def calculate_constant_result(self):
self.constant_result = \
self.base.constant_result[self.index.constant_result]
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
index = self.index.compile_time_value(denv)
......@@ -1881,7 +1927,11 @@ class SliceIndexNode(ExprNode):
# stop ExprNode or None
subexprs = ['base', 'start', 'stop']
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : self.stop.constant_result]
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
if self.start is None:
......@@ -2055,7 +2105,13 @@ class SliceNode(ExprNode):
# start ExprNode
# stop ExprNode
# step ExprNode
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : \
self.stop.constant_result : \
self.step.constant_result]
def compile_time_value(self, denv):
start = self.start.compile_time_value(denv)
if self.stop is None:
......@@ -2452,6 +2508,9 @@ class AsTupleNode(ExprNode):
# arg ExprNode
subexprs = ['arg']
def calculate_constant_result(self):
self.constant_result = tuple(self.base.constant_result)
def compile_time_value(self, denv):
arg = self.arg.compile_time_value(denv)
......@@ -2517,7 +2576,13 @@ class AttributeNode(ExprNode):
self.analyse_as_python_attribute(env)
return self
return ExprNode.coerce_to(self, dst_type, env)
def calculate_constant_result(self):
attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"):
return
self.constant_result = getattr(self.obj.constant_result, attr)
def compile_time_value(self, denv):
attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"):
......@@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode):
else:
return Naming.empty_tuple
def calculate_constant_result(self):
self.constant_result = tuple([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv):
values = self.compile_time_value_list(denv)
try:
......@@ -3058,6 +3127,10 @@ class ListNode(SequenceNode):
else:
SequenceNode.release_temp(self, env)
def calculate_constant_result(self):
self.constant_result = [
arg.constant_result for arg in self.args]
def compile_time_value(self, denv):
return self.compile_time_value_list(denv)
......@@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode):
self.gil_check(env)
self.is_temp = 1
def calculate_constant_result(self):
self.constant_result = set([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv):
values = [arg.compile_time_value(denv) for arg in self.args]
try:
set
except NameError:
from sets import Set as set
try:
return set(values)
except Exception, e:
......@@ -3264,6 +3337,10 @@ class DictNode(ExprNode):
# obj_conversion_errors [PyrexError] used internally
subexprs = ['key_value_pairs']
def calculate_constant_result(self):
self.constant_result = dict([
item.constant_result for item in self.key_value_pairs])
def compile_time_value(self, denv):
pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv))
......@@ -3366,6 +3443,10 @@ class DictItemNode(ExprNode):
# key ExprNode
# value ExprNode
subexprs = ['key', 'value']
def calculate_constant_result(self):
self.constant_result = (
self.key.constant_result, self.value.constant_result)
def analyse_types(self, env):
self.key.analyse_types(env)
......@@ -3507,6 +3588,10 @@ class UnopNode(ExprNode):
# - Allocate temporary for result if needed.
subexprs = ['operand']
def calculate_constant_result(self):
func = compile_time_unary_operators[self.operator]
self.constant_result = func(self.operand.constant_result)
def compile_time_value(self, denv):
func = compile_time_unary_operators.get(self.operator)
......@@ -3566,7 +3651,10 @@ class NotNode(ExprNode):
# 'not' operator
#
# operand ExprNode
def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result
def compile_time_value(self, denv):
operand = self.operand.compile_time_value(denv)
try:
......@@ -3897,7 +3985,13 @@ class BinopNode(NewTempExprNode):
# - Allocate temporary for result if needed.
subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
func = compile_time_binary_operators[self.operator]
self.constant_result = func(
self.operand1.constant_result,
self.operand2.constant_result)
def compile_time_value(self, denv):
func = get_compile_time_binop(self)
operand1 = self.operand1.compile_time_value(denv)
......@@ -4137,6 +4231,16 @@ class BoolBinopNode(NewTempExprNode):
# operand2 ExprNode
subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
if self.operator == 'and':
self.constant_result = \
self.operand1.constant_result and \
self.operand2.constant_result
else:
self.constant_result = \
self.operand1.constant_result or \
self.operand2.constant_result
def compile_time_value(self, denv):
if self.operator == 'and':
......@@ -4261,7 +4365,13 @@ class CondExprNode(ExprNode):
false_val = None
subexprs = ['test', 'true_val', 'false_val']
def calculate_constant_result(self):
if self.test.constant_result:
self.constant_result = self.true_val.constant_result
else:
self.constant_result = self.false_val.constant_result
def analyse_types(self, env):
self.test.analyse_types(env)
self.test = self.test.coerce_to_boolean(env)
......@@ -4350,6 +4460,15 @@ richcmp_constants = {
class CmpNode:
# Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes.
def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result
result = func(operand1_result, operand2_result)
if result and self.cascade:
result = result and \
self.cascade.cascaded_compile_time_value(operand2_result)
self.constant_result = result
def cascaded_compile_time_value(self, operand1, denv):
func = get_compile_time_binop(self)
......@@ -4362,6 +4481,7 @@ class CmpNode:
if result:
cascade = self.cascade
if cascade:
# FIXME: I bet this must call cascaded_compile_time_value()
result = result and cascade.compile_time_value(operand2, denv)
return result
......@@ -4468,6 +4588,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
child_attrs = ['operand1', 'operand2', 'cascade']
cascade = None
def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result)
def compile_time_value(self, denv):
operand1 = self.operand1.compile_time_value(denv)
......@@ -4598,7 +4722,8 @@ class CascadedCmpNode(Node, CmpNode):
child_attrs = ['operand2', 'cascade']
cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this?
def analyse_types(self, env, operand1):
self.operand2.analyse_types(env)
if self.cascade:
......
......@@ -83,7 +83,7 @@ class Context:
from ParseTreeTransforms import AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase
from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations
......@@ -123,6 +123,7 @@ class Context:
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
ConstantFolding(),
FlattenBuiltinTypeCreation(),
DictIterTransform(),
SwitchTransform(),
......
......@@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
"""Calculate the result of constant expressions to store it in
``expr_node.constant_result``, and replace trivial cases by their
constant result.
"""
def _calculate_const(self, node):
if node.constant_result is not ExprNodes.constant_value_not_set:
return
# make sure we always set the value
not_a_constant = ExprNodes.not_a_constant
node.constant_result = not_a_constant
# check if all children are constant
children = self.visitchildren(node)
for child_result in children.itervalues():
if type(child_result) is list:
for child in child_result:
if child.constant_result is not_a_constant:
return
elif child_result.constant_result is not_a_constant:
return
# now try to calculate the real constant value
try:
node.calculate_constant_result()
# if node.constant_result is not ExprNodes.not_a_constant:
# print node.__class__.__name__, node.constant_result
except (ValueError, TypeError, IndexError, AttributeError):
# ignore all 'normal' errors here => no constant result
pass
except Exception:
# this looks like a real error
import traceback, sys
traceback.print_exc(file=sys.stdout)
def visit_ExprNode(self, node):
self._calculate_const(node)
return node
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
def visit_Node(self, node):
self.visitchildren(node)
return node
class FinalOptimizePhase(Visitor.CythonTransform):
"""
This visitor handles several commuting optimizations, and is run
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment