Commit ad1bed63 authored by Robert Bradshaw's avatar Robert Bradshaw

Switch statement omptimization

parent 17aca901
...@@ -149,7 +149,7 @@ class Node(object): ...@@ -149,7 +149,7 @@ class Node(object):
except AttributeError: except AttributeError:
flat = [] flat = []
for attr in self.child_attrs: for attr in self.child_attrs:
child = getattr(parent, attr) child = getattr(self, attr)
# Sometimes lists, sometimes nodes # Sometimes lists, sometimes nodes
if child is None: if child is None:
pass pass
...@@ -2850,7 +2850,50 @@ class IfClauseNode(Node): ...@@ -2850,7 +2850,50 @@ class IfClauseNode(Node):
self.condition.annotate(code) self.condition.annotate(code)
self.body.annotate(code) self.body.annotate(code)
class SwitchCaseNode(StatNode):
# Generated in the optimization of an if-elif-else node
#
# conditions [ExprNode]
# body StatNode
child_attrs = ['conditions', 'body']
def generate_execution_code(self, code):
for cond in self.conditions:
code.putln("case %s:" % cond.calculate_result_code())
self.body.generate_execution_code(code)
code.putln("break;")
def annotate(self, code):
for cond in self.conditions:
cond.annotate(code)
body.annotate(code)
class SwitchStatNode(StatNode):
# Generated in the optimization of an if-elif-else node
#
# test ExprNode
# cases [SwitchCaseNode]
# else_clause StatNode or None
child_attrs = ['test', 'cases', 'else_clause']
def generate_execution_code(self, code):
code.putln("switch (%s) {" % self.test.calculate_result_code())
for case in self.cases:
case.generate_execution_code(code)
if self.else_clause is not None:
code.putln("default:")
self.else_clause.generate_execution_code(code)
code.putln("}")
def annotate(self, code):
self.test.annotate(code)
for case in self.cases:
case.annotate(code)
self.else_clause.annotate(code)
class LoopNode: class LoopNode:
def analyse_control_flow(self, env): def analyse_control_flow(self, env):
......
import Nodes
import ExprNodes
import Visitor
def is_common_value(a, b):
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return not a.is_py_attr and is_common_value(a.obj, b.obj)
return False
class SwitchTransformVisitor(Visitor.VisitorTransform):
def extract_conditions(self, cond):
if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg
if (isinstance(cond, ExprNodes.PrimaryCmpNode)
and cond.cascade is None
and cond.operator == '=='
and not cond.is_python_comparison()):
if is_common_value(cond.operand1, cond.operand1):
if isinstance(cond.operand2, ExprNodes.ConstNode):
return cond.operand1, [cond.operand2]
elif hasattr(cond.operand2, 'entry') and cond.operand2.entry.is_const:
return cond.operand1, [cond.operand2]
if is_common_value(cond.operand2, cond.operand2):
if isinstance(cond.operand1, ExprNodes.ConstNode):
return cond.operand2, [cond.operand1]
elif hasattr(cond.operand1, 'entry') and cond.operand1.entry.is_const:
return cond.operand2, [cond.operand1]
elif (isinstance(cond, ExprNodes.BoolBinopNode)
and cond.operator == 'or'):
t1, c1 = self.extract_conditions(cond.operand1)
t2, c2 = self.extract_conditions(cond.operand2)
if is_common_value(t1, t2):
return t1, c1+c2
return None, None
def is_common_value(self, a, b):
if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
return a.name == b.name
if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
return not a.is_py_attr and is_common_value(a.obj, b.obj)
return False
def visit_IfStatNode(self, node):
if len(node.if_clauses) < 3:
return node
common_var = None
cases = []
for if_clause in node.if_clauses:
var, conditions = self.extract_conditions(if_clause.condition)
if var is None:
return node
elif common_var is not None and not self.is_common_value(var, common_var):
return node
else:
common_var = var
cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
conditions = conditions,
body = if_clause.body))
return Nodes.SwitchStatNode(pos = node.pos,
test = common_var,
cases = cases,
else_clause = node.else_clause)
def visit_Node(self, node):
self.visitchildren(node)
return node
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment