Commit b8782c99 authored by Stefan Behnel's avatar Stefan Behnel

initial constant folding transform: calculate constant values in node.constant_result

parent fb4ea137
...@@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain ...@@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \ from DebugFlags import debug_disposal_code, debug_temp_alloc, \
debug_coercion debug_coercion
try:
set
except NameError:
from sets import Set as set
not_a_constant = object()
constant_value_not_set = object()
class ExprNode(Node): class ExprNode(Node):
# subexprs [string] Class var holding names of subexpr node attrs # subexprs [string] Class var holding names of subexpr node attrs
...@@ -172,6 +179,8 @@ class ExprNode(Node): ...@@ -172,6 +179,8 @@ class ExprNode(Node):
is_temp = 0 is_temp = 0
is_target = 0 is_target = 0
constant_result = constant_value_not_set
def get_child_attrs(self): def get_child_attrs(self):
return self.subexprs return self.subexprs
child_attrs = property(fget=get_child_attrs) child_attrs = property(fget=get_child_attrs)
...@@ -224,7 +233,17 @@ class ExprNode(Node): ...@@ -224,7 +233,17 @@ class ExprNode(Node):
# Return the native C type of the result (i.e. the # Return the native C type of the result (i.e. the
# C type of the result_code expression). # C type of the result_code expression).
return self.result_ctype or self.type return self.result_ctype or self.type
def calculate_constant_result(self):
# Calculate the constant result of this expression and store
# it in ``self.constant_result``. Does nothing by default,
# thus leaving ``self.constant_result`` unknown.
#
# This must only be called when it is assured that all
# sub-expressions have a valid constant_result value. The
# ConstantFolding transform will do this.
pass
def compile_time_value(self, denv): def compile_time_value(self, denv):
# Return value of compile-time expression, or report error. # Return value of compile-time expression, or report error.
error(self.pos, "Invalid compile-time expression") error(self.pos, "Invalid compile-time expression")
...@@ -736,7 +755,9 @@ class NoneNode(PyConstNode): ...@@ -736,7 +755,9 @@ class NoneNode(PyConstNode):
# The constant value None # The constant value None
value = "Py_None" value = "Py_None"
constant_result = None
def compile_time_value(self, denv): def compile_time_value(self, denv):
return None return None
...@@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode): ...@@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode):
value = "Py_Ellipsis" value = "Py_Ellipsis"
constant_result = Ellipsis
def compile_time_value(self, denv): def compile_time_value(self, denv):
return Ellipsis return Ellipsis
...@@ -775,7 +798,10 @@ class ConstNode(AtomicNewTempExprNode): ...@@ -775,7 +798,10 @@ class ConstNode(AtomicNewTempExprNode):
class BoolNode(ConstNode): class BoolNode(ConstNode):
type = PyrexTypes.c_bint_type type = PyrexTypes.c_bint_type
# The constant value True or False # The constant value True or False
def calculate_constant_result(self):
self.constant_result = self.value
def compile_time_value(self, denv): def compile_time_value(self, denv):
return self.value return self.value
...@@ -785,10 +811,14 @@ class BoolNode(ConstNode): ...@@ -785,10 +811,14 @@ class BoolNode(ConstNode):
class NullNode(ConstNode): class NullNode(ConstNode):
type = PyrexTypes.c_null_ptr_type type = PyrexTypes.c_null_ptr_type
value = "NULL" value = "NULL"
constant_result = 0
class CharNode(ConstNode): class CharNode(ConstNode):
type = PyrexTypes.c_char_type type = PyrexTypes.c_char_type
def calculate_constant_result(self):
self.constant_result = ord(self.value)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return ord(self.value) return ord(self.value)
...@@ -830,6 +860,9 @@ class IntNode(ConstNode): ...@@ -830,6 +860,9 @@ class IntNode(ConstNode):
else: else:
return str(self.value) + self.unsigned + self.longness return str(self.value) + self.unsigned + self.longness
def calculate_constant_result(self):
self.constant_result = int(self.value, 0)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return int(self.value, 0) return int(self.value, 0)
...@@ -953,6 +986,9 @@ class LongNode(AtomicNewTempExprNode): ...@@ -953,6 +986,9 @@ class LongNode(AtomicNewTempExprNode):
# Python long integer literal # Python long integer literal
# #
# value string # value string
def calculate_constant_result(self):
self.constant_result = long(self.value)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return long(self.value) return long(self.value)
...@@ -978,6 +1014,9 @@ class ImagNode(AtomicNewTempExprNode): ...@@ -978,6 +1014,9 @@ class ImagNode(AtomicNewTempExprNode):
# Imaginary number literal # Imaginary number literal
# #
# value float imaginary part # value float imaginary part
def calculate_constant_result(self):
self.constant_result = complex(0.0, self.value)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return complex(0.0, self.value) return complex(0.0, self.value)
...@@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode): ...@@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode):
gil_message = "Backquote expression" gil_message = "Backquote expression"
def calculate_constant_result(self):
self.constant_result = repr(self.arg.constant_result)
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln( code.putln(
"%s = PyObject_Repr(%s); %s" % ( "%s = PyObject_Repr(%s); %s" % (
...@@ -1582,7 +1624,11 @@ class IndexNode(ExprNode): ...@@ -1582,7 +1624,11 @@ class IndexNode(ExprNode):
def __init__(self, pos, index, *args, **kw): def __init__(self, pos, index, *args, **kw):
ExprNode.__init__(self, pos, index=index, *args, **kw) ExprNode.__init__(self, pos, index=index, *args, **kw)
self._index = index self._index = index
def calculate_constant_result(self):
self.constant_result = \
self.base.constant_result[self.index.constant_result]
def compile_time_value(self, denv): def compile_time_value(self, denv):
base = self.base.compile_time_value(denv) base = self.base.compile_time_value(denv)
index = self.index.compile_time_value(denv) index = self.index.compile_time_value(denv)
...@@ -1881,7 +1927,11 @@ class SliceIndexNode(ExprNode): ...@@ -1881,7 +1927,11 @@ class SliceIndexNode(ExprNode):
# stop ExprNode or None # stop ExprNode or None
subexprs = ['base', 'start', 'stop'] subexprs = ['base', 'start', 'stop']
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : self.stop.constant_result]
def compile_time_value(self, denv): def compile_time_value(self, denv):
base = self.base.compile_time_value(denv) base = self.base.compile_time_value(denv)
if self.start is None: if self.start is None:
...@@ -2055,7 +2105,13 @@ class SliceNode(ExprNode): ...@@ -2055,7 +2105,13 @@ class SliceNode(ExprNode):
# start ExprNode # start ExprNode
# stop ExprNode # stop ExprNode
# step ExprNode # step ExprNode
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : \
self.stop.constant_result : \
self.step.constant_result]
def compile_time_value(self, denv): def compile_time_value(self, denv):
start = self.start.compile_time_value(denv) start = self.start.compile_time_value(denv)
if self.stop is None: if self.stop is None:
...@@ -2452,6 +2508,9 @@ class AsTupleNode(ExprNode): ...@@ -2452,6 +2508,9 @@ class AsTupleNode(ExprNode):
# arg ExprNode # arg ExprNode
subexprs = ['arg'] subexprs = ['arg']
def calculate_constant_result(self):
self.constant_result = tuple(self.base.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
arg = self.arg.compile_time_value(denv) arg = self.arg.compile_time_value(denv)
...@@ -2517,7 +2576,13 @@ class AttributeNode(ExprNode): ...@@ -2517,7 +2576,13 @@ class AttributeNode(ExprNode):
self.analyse_as_python_attribute(env) self.analyse_as_python_attribute(env)
return self return self
return ExprNode.coerce_to(self, dst_type, env) return ExprNode.coerce_to(self, dst_type, env)
def calculate_constant_result(self):
attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"):
return
self.constant_result = getattr(self.obj.constant_result, attr)
def compile_time_value(self, denv): def compile_time_value(self, denv):
attr = self.attribute attr = self.attribute
if attr.beginswith("__") and attr.endswith("__"): if attr.beginswith("__") and attr.endswith("__"):
...@@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode): ...@@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode):
else: else:
return Naming.empty_tuple return Naming.empty_tuple
def calculate_constant_result(self):
self.constant_result = tuple([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv): def compile_time_value(self, denv):
values = self.compile_time_value_list(denv) values = self.compile_time_value_list(denv)
try: try:
...@@ -3058,6 +3127,10 @@ class ListNode(SequenceNode): ...@@ -3058,6 +3127,10 @@ class ListNode(SequenceNode):
else: else:
SequenceNode.release_temp(self, env) SequenceNode.release_temp(self, env)
def calculate_constant_result(self):
self.constant_result = [
arg.constant_result for arg in self.args]
def compile_time_value(self, denv): def compile_time_value(self, denv):
return self.compile_time_value_list(denv) return self.compile_time_value_list(denv)
...@@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode): ...@@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode):
self.gil_check(env) self.gil_check(env)
self.is_temp = 1 self.is_temp = 1
def calculate_constant_result(self):
self.constant_result = set([
arg.constant_result for arg in self.args])
def compile_time_value(self, denv): def compile_time_value(self, denv):
values = [arg.compile_time_value(denv) for arg in self.args] values = [arg.compile_time_value(denv) for arg in self.args]
try:
set
except NameError:
from sets import Set as set
try: try:
return set(values) return set(values)
except Exception, e: except Exception, e:
...@@ -3264,6 +3337,10 @@ class DictNode(ExprNode): ...@@ -3264,6 +3337,10 @@ class DictNode(ExprNode):
# obj_conversion_errors [PyrexError] used internally # obj_conversion_errors [PyrexError] used internally
subexprs = ['key_value_pairs'] subexprs = ['key_value_pairs']
def calculate_constant_result(self):
self.constant_result = dict([
item.constant_result for item in self.key_value_pairs])
def compile_time_value(self, denv): def compile_time_value(self, denv):
pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv)) pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv))
...@@ -3366,6 +3443,10 @@ class DictItemNode(ExprNode): ...@@ -3366,6 +3443,10 @@ class DictItemNode(ExprNode):
# key ExprNode # key ExprNode
# value ExprNode # value ExprNode
subexprs = ['key', 'value'] subexprs = ['key', 'value']
def calculate_constant_result(self):
self.constant_result = (
self.key.constant_result, self.value.constant_result)
def analyse_types(self, env): def analyse_types(self, env):
self.key.analyse_types(env) self.key.analyse_types(env)
...@@ -3507,6 +3588,10 @@ class UnopNode(ExprNode): ...@@ -3507,6 +3588,10 @@ class UnopNode(ExprNode):
# - Allocate temporary for result if needed. # - Allocate temporary for result if needed.
subexprs = ['operand'] subexprs = ['operand']
def calculate_constant_result(self):
func = compile_time_unary_operators[self.operator]
self.constant_result = func(self.operand.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
func = compile_time_unary_operators.get(self.operator) func = compile_time_unary_operators.get(self.operator)
...@@ -3566,7 +3651,10 @@ class NotNode(ExprNode): ...@@ -3566,7 +3651,10 @@ class NotNode(ExprNode):
# 'not' operator # 'not' operator
# #
# operand ExprNode # operand ExprNode
def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result
def compile_time_value(self, denv): def compile_time_value(self, denv):
operand = self.operand.compile_time_value(denv) operand = self.operand.compile_time_value(denv)
try: try:
...@@ -3897,7 +3985,13 @@ class BinopNode(NewTempExprNode): ...@@ -3897,7 +3985,13 @@ class BinopNode(NewTempExprNode):
# - Allocate temporary for result if needed. # - Allocate temporary for result if needed.
subexprs = ['operand1', 'operand2'] subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
func = compile_time_binary_operators[self.operator]
self.constant_result = func(
self.operand1.constant_result,
self.operand2.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
func = get_compile_time_binop(self) func = get_compile_time_binop(self)
operand1 = self.operand1.compile_time_value(denv) operand1 = self.operand1.compile_time_value(denv)
...@@ -4137,6 +4231,16 @@ class BoolBinopNode(NewTempExprNode): ...@@ -4137,6 +4231,16 @@ class BoolBinopNode(NewTempExprNode):
# operand2 ExprNode # operand2 ExprNode
subexprs = ['operand1', 'operand2'] subexprs = ['operand1', 'operand2']
def calculate_constant_result(self):
if self.operator == 'and':
self.constant_result = \
self.operand1.constant_result and \
self.operand2.constant_result
else:
self.constant_result = \
self.operand1.constant_result or \
self.operand2.constant_result
def compile_time_value(self, denv): def compile_time_value(self, denv):
if self.operator == 'and': if self.operator == 'and':
...@@ -4261,7 +4365,13 @@ class CondExprNode(ExprNode): ...@@ -4261,7 +4365,13 @@ class CondExprNode(ExprNode):
false_val = None false_val = None
subexprs = ['test', 'true_val', 'false_val'] subexprs = ['test', 'true_val', 'false_val']
def calculate_constant_result(self):
if self.test.constant_result:
self.constant_result = self.true_val.constant_result
else:
self.constant_result = self.false_val.constant_result
def analyse_types(self, env): def analyse_types(self, env):
self.test.analyse_types(env) self.test.analyse_types(env)
self.test = self.test.coerce_to_boolean(env) self.test = self.test.coerce_to_boolean(env)
...@@ -4350,6 +4460,15 @@ richcmp_constants = { ...@@ -4350,6 +4460,15 @@ richcmp_constants = {
class CmpNode: class CmpNode:
# Mixin class containing code common to PrimaryCmpNodes # Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes. # and CascadedCmpNodes.
def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result
result = func(operand1_result, operand2_result)
if result and self.cascade:
result = result and \
self.cascade.cascaded_compile_time_value(operand2_result)
self.constant_result = result
def cascaded_compile_time_value(self, operand1, denv): def cascaded_compile_time_value(self, operand1, denv):
func = get_compile_time_binop(self) func = get_compile_time_binop(self)
...@@ -4362,6 +4481,7 @@ class CmpNode: ...@@ -4362,6 +4481,7 @@ class CmpNode:
if result: if result:
cascade = self.cascade cascade = self.cascade
if cascade: if cascade:
# FIXME: I bet this must call cascaded_compile_time_value()
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.compile_time_value(operand2, denv)
return result return result
...@@ -4468,6 +4588,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): ...@@ -4468,6 +4588,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
child_attrs = ['operand1', 'operand2', 'cascade'] child_attrs = ['operand1', 'operand2', 'cascade']
cascade = None cascade = None
def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result)
def compile_time_value(self, denv): def compile_time_value(self, denv):
operand1 = self.operand1.compile_time_value(denv) operand1 = self.operand1.compile_time_value(denv)
...@@ -4598,7 +4722,8 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -4598,7 +4722,8 @@ class CascadedCmpNode(Node, CmpNode):
child_attrs = ['operand2', 'cascade'] child_attrs = ['operand2', 'cascade']
cascade = None cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this?
def analyse_types(self, env, operand1): def analyse_types(self, env, operand1):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
......
...@@ -83,7 +83,7 @@ class Context: ...@@ -83,7 +83,7 @@ class Context:
from ParseTreeTransforms import AlignFunctionDefinitions from ParseTreeTransforms import AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
from Buffer import IntroduceBufferAuxiliaryVars from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations from ModuleNode import check_c_declarations
...@@ -123,6 +123,7 @@ class Context: ...@@ -123,6 +123,7 @@ class Context:
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
ConstantFolding(),
FlattenBuiltinTypeCreation(), FlattenBuiltinTypeCreation(),
DictIterTransform(), DictIterTransform(),
SwitchTransform(), SwitchTransform(),
......
...@@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform): ...@@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
return node return node
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
"""Calculate the result of constant expressions to store it in
``expr_node.constant_result``, and replace trivial cases by their
constant result.
"""
def _calculate_const(self, node):
if node.constant_result is not ExprNodes.constant_value_not_set:
return
# make sure we always set the value
not_a_constant = ExprNodes.not_a_constant
node.constant_result = not_a_constant
# check if all children are constant
children = self.visitchildren(node)
for child_result in children.itervalues():
if type(child_result) is list:
for child in child_result:
if child.constant_result is not_a_constant:
return
elif child_result.constant_result is not_a_constant:
return
# now try to calculate the real constant value
try:
node.calculate_constant_result()
# if node.constant_result is not ExprNodes.not_a_constant:
# print node.__class__.__name__, node.constant_result
except (ValueError, TypeError, IndexError, AttributeError):
# ignore all 'normal' errors here => no constant result
pass
except Exception:
# this looks like a real error
import traceback, sys
traceback.print_exc(file=sys.stdout)
def visit_ExprNode(self, node):
self._calculate_const(node)
return node
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
def visit_Node(self, node):
self.visitchildren(node)
return node
class FinalOptimizePhase(Visitor.CythonTransform): class FinalOptimizePhase(Visitor.CythonTransform):
""" """
This visitor handles several commuting optimizations, and is run This visitor handles several commuting optimizations, and is run
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment