Commit efa0d3cb authored by Robert Bradshaw's avatar Robert Bradshaw

Case statements and "x in [...]" flattening.

parent e42ddaa8
...@@ -4019,6 +4019,56 @@ class CloneNode(CoercionNode): ...@@ -4019,6 +4019,56 @@ class CloneNode(CoercionNode):
def release_temp(self, env): def release_temp(self, env):
pass pass
class PersistentNode(ExprNode):
# A PersistentNode is like a CloneNode except it handles the temporary
# allocation itself by keeping track of the number of times it has been
# used.
subexprs = ["arg"]
temp_counter = 0
generate_counter = 0
result_code = None
def __init__(self, arg, uses):
self.pos = arg.pos
self.arg = arg
self.uses = uses
def analyse_types(self, env):
self.arg.analyse_types(env)
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
self.is_temp = 1
def generate_evaluation_code(self, code):
if self.generate_counter == 0:
self.arg.generate_evaluation_code(code)
code.putln("%s = %s;" % (
self.result_code, self.arg.result_as(self.ctype())))
if self.type.is_pyobject:
code.put_incref(self.result_code, self.ctype())
self.arg.generate_disposal_code(code)
self.generate_counter += 1
def generate_disposal_code(self, code):
if self.generate_counter == self.uses:
if self.type.is_pyobject:
code.put_decref_clear(self.result_code, self.ctype())
def allocate_temps(self, env, result=None):
if self.temp_counter == 0:
self.arg.allocate_temps(env)
if result is None:
self.result_code = env.allocate_temp(self.type)
else:
self.result_code = result
self.arg.release_temp(env)
self.temp_counter += 1
def release_temp(self, env):
if self.temp_counter == self.uses:
env.release_temp(self.result_code)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
...@@ -357,6 +357,7 @@ def create_default_pipeline(context, options, result): ...@@ -357,6 +357,7 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform
from Buffer import BufferTransform from Buffer import BufferTransform
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
...@@ -364,12 +365,14 @@ def create_default_pipeline(context, options, result): ...@@ -364,12 +365,14 @@ def create_default_pipeline(context, options, result):
create_parse(context), create_parse(context),
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
FlattenInListTransform(),
WithTransform(context), WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
BufferTransform(context), BufferTransform(context),
SwitchTransform(),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
import Nodes import Nodes
import ExprNodes import ExprNodes
import PyrexTypes
import Visitor import Visitor
def unwrap_node(node):
while isinstance(node, ExprNodes.PersistentNode):
node = node.arg
return node
def is_common_value(a, b): def is_common_value(a, b):
a = unwrap_node(a)
b = unwrap_node(b)
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
...@@ -11,13 +18,20 @@ def is_common_value(a, b): ...@@ -11,13 +18,20 @@ def is_common_value(a, b):
return False return False
class SwitchTransformVisitor(Visitor.VisitorTransform): class SwitchTransform(Visitor.VisitorTransform):
"""
This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var
is common among all clauses and both var and value are not Python objects.
"""
def extract_conditions(self, cond): def extract_conditions(self, cond):
if isinstance(cond, ExprNodes.CoerceToTempNode): if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg cond = cond.arg
if isinstance(cond, ExprNodes.TypecastNode):
cond = cond.operand
if (isinstance(cond, ExprNodes.PrimaryCmpNode) if (isinstance(cond, ExprNodes.PrimaryCmpNode)
and cond.cascade is None and cond.cascade is None
and cond.operator == '==' and cond.operator == '=='
...@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): ...@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
return t1, c1+c2 return t1, c1+c2
return None, None return None, None
def is_common_value(self, a, b):
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return not a.is_py_attr and is_common_value(a.obj, b.obj)
return False
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
self.visitchildren(node)
if len(node.if_clauses) < 3: if len(node.if_clauses) < 3:
return node return node
common_var = None common_var = None
...@@ -56,7 +64,7 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): ...@@ -56,7 +64,7 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
var, conditions = self.extract_conditions(if_clause.condition) var, conditions = self.extract_conditions(if_clause.condition)
if var is None: if var is None:
return node return node
elif common_var is not None and not self.is_common_value(var, common_var): elif common_var is not None and not is_common_value(var, common_var):
return node return node
else: else:
common_var = var common_var = var
...@@ -68,7 +76,59 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): ...@@ -68,7 +76,59 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
cases = cases, cases = cases,
else_clause = node.else_clause) else_clause = node.else_clause)
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
return node return node
class FlattenInListTransform(Visitor.VisitorTransform):
"""
This transformation flattens "x in [val1, ..., valn]" into a sequential list
of comparisons.
"""
def visit_PrimaryCmpNode(self, node):
self.visitchildren(node)
if node.cascade is not None:
return node
elif node.operator == 'in':
conjunction = 'or'
eq_or_neq = '=='
elif node.operator == 'not_in':
conjunction = 'and'
eq_or_neq = '!='
else:
return node
args = node.operand2.args
if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
if len(args) == 0:
return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
else:
lhs = ExprNodes.PersistentNode(node.operand1, len(args))
conds = []
for arg in args:
cond = ExprNodes.PrimaryCmpNode(
pos = node.pos,
operand1 = lhs,
operator = eq_or_neq,
operand2 = arg,
cascade = None)
conds.append(ExprNodes.TypecastNode(
pos = node.pos,
operand = cond,
type = PyrexTypes.c_bint_type))
def concat(left, right):
return ExprNodes.BoolBinopNode(
pos = node.pos,
operator = conjunction,
operand1 = left,
operand2 = right)
return reduce(concat, conds)
else:
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
...@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext ...@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext
ext_modules=[ ext_modules=[
Extension("primes", ["primes.pyx"]), Extension("primes", ["primes.pyx"]),
Extension("spam", ["spam.pyx"]), Extension("spam", ["spam.pyx"]),
# Extension("optargs", ["optargs.pyx"], language = "c++"),
] ]
for file in glob.glob("*.pyx"): for file in glob.glob("*.pyx"):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment