Commit efa0d3cb authored by Robert Bradshaw's avatar Robert Bradshaw

Case statements and "x in [...]" flattening.

parent e42ddaa8
......@@ -4018,6 +4018,56 @@ class CloneNode(CoercionNode):
def release_temp(self, env):
pass
class PersistentNode(ExprNode):
# A PersistentNode is like a CloneNode except it handles the temporary
# allocation itself by keeping track of the number of times it has been
# used.
subexprs = ["arg"]
temp_counter = 0
generate_counter = 0
result_code = None
def __init__(self, arg, uses):
self.pos = arg.pos
self.arg = arg
self.uses = uses
def analyse_types(self, env):
self.arg.analyse_types(env)
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
self.is_temp = 1
def generate_evaluation_code(self, code):
if self.generate_counter == 0:
self.arg.generate_evaluation_code(code)
code.putln("%s = %s;" % (
self.result_code, self.arg.result_as(self.ctype())))
if self.type.is_pyobject:
code.put_incref(self.result_code, self.ctype())
self.arg.generate_disposal_code(code)
self.generate_counter += 1
def generate_disposal_code(self, code):
if self.generate_counter == self.uses:
if self.type.is_pyobject:
code.put_decref_clear(self.result_code, self.ctype())
def allocate_temps(self, env, result=None):
if self.temp_counter == 0:
self.arg.allocate_temps(env)
if result is None:
self.result_code = env.allocate_temp(self.type)
else:
self.result_code = result
self.arg.release_temp(env)
self.temp_counter += 1
def release_temp(self, env):
if self.temp_counter == self.uses:
env.release_temp(self.result_code)
#------------------------------------------------------------------------------------
#
......
......@@ -357,6 +357,7 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform
from ModuleNode import check_c_classes
......@@ -364,12 +365,14 @@ def create_default_pipeline(context, options, result):
create_parse(context),
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
SwitchTransform(),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
......
import Nodes
import ExprNodes
import PyrexTypes
import Visitor
def unwrap_node(node):
while isinstance(node, ExprNodes.PersistentNode):
node = node.arg
return node
def is_common_value(a, b):
a = unwrap_node(a)
b = unwrap_node(b)
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
......@@ -11,13 +18,20 @@ def is_common_value(a, b):
return False
class SwitchTransformVisitor(Visitor.VisitorTransform):
class SwitchTransform(Visitor.VisitorTransform):
"""
This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var
is common among all clauses and both var and value are not Python objects.
"""
def extract_conditions(self, cond):
if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg
if isinstance(cond, ExprNodes.TypecastNode):
cond = cond.operand
if (isinstance(cond, ExprNodes.PrimaryCmpNode)
and cond.cascade is None
and cond.operator == '=='
......@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
return t1, c1+c2
return None, None
def is_common_value(self, a, b):
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return not a.is_py_attr and is_common_value(a.obj, b.obj)
return False
def visit_IfStatNode(self, node):
self.visitchildren(node)
if len(node.if_clauses) < 3:
return node
common_var = None
......@@ -56,7 +64,7 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
var, conditions = self.extract_conditions(if_clause.condition)
if var is None:
return node
elif common_var is not None and not self.is_common_value(var, common_var):
elif common_var is not None and not is_common_value(var, common_var):
return node
else:
common_var = var
......@@ -67,8 +75,60 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
test = common_var,
cases = cases,
else_clause = node.else_clause)
def visit_Node(self, node):
self.visitchildren(node)
return node
class FlattenInListTransform(Visitor.VisitorTransform):
"""
This transformation flattens "x in [val1, ..., valn]" into a sequential list
of comparisons.
"""
def visit_PrimaryCmpNode(self, node):
self.visitchildren(node)
if node.cascade is not None:
return node
elif node.operator == 'in':
conjunction = 'or'
eq_or_neq = '=='
elif node.operator == 'not_in':
conjunction = 'and'
eq_or_neq = '!='
else:
return node
args = node.operand2.args
if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
if len(args) == 0:
return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
else:
lhs = ExprNodes.PersistentNode(node.operand1, len(args))
conds = []
for arg in args:
cond = ExprNodes.PrimaryCmpNode(
pos = node.pos,
operand1 = lhs,
operator = eq_or_neq,
operand2 = arg,
cascade = None)
conds.append(ExprNodes.TypecastNode(
pos = node.pos,
operand = cond,
type = PyrexTypes.c_bint_type))
def concat(left, right):
return ExprNodes.BoolBinopNode(
pos = node.pos,
operator = conjunction,
operand1 = left,
operand2 = right)
return reduce(concat, conds)
else:
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
......@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext
ext_modules=[
Extension("primes", ["primes.pyx"]),
Extension("spam", ["spam.pyx"]),
# Extension("optargs", ["optargs.pyx"], language = "c++"),
]
for file in glob.glob("*.pyx"):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment