Commit ab7bd194 authored by Stefan Behnel's avatar Stefan Behnel

extend switch statement transformation to arbitrary 'in' tests: shorter, more readable C code

parent 1c5fd0c9
...@@ -540,34 +540,101 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -540,34 +540,101 @@ class SwitchTransform(Visitor.VisitorTransform):
if is_common_value(t1, t2): if is_common_value(t1, t2):
return t1, c1+c2 return t1, c1+c2
return None, None return None, None
def extract_common_conditions(self, common_var, condition):
var, conditions = self.extract_conditions(condition)
if var is None:
return None, None
elif common_var is not None and not is_common_value(var, common_var):
return None, None
elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
return None, None
return var, conditions
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
self.visitchildren(node)
common_var = None common_var = None
case_count = 0
cases = [] cases = []
for if_clause in node.if_clauses: for if_clause in node.if_clauses:
var, conditions = self.extract_conditions(if_clause.condition) common_var, conditions = self.extract_common_conditions(
if var is None: common_var, if_clause.condition)
return node if common_var is None:
elif common_var is not None and not is_common_value(var, common_var):
return node return node
elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]): cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
return node conditions = conditions,
else: body = if_clause.body))
common_var = var
case_count += len(conditions) if sum([ len(case.conditions) for case in cases ]) < 2:
cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, return node
conditions = conditions,
body = if_clause.body)) common_var = unwrap_node(common_var)
if case_count < 2: switch_node = Nodes.SwitchStatNode(pos = node.pos,
return node test = common_var,
cases = cases,
else_clause = node.else_clause)
self.visitchildren(switch_node)
return switch_node
def visit_CondExprNode(self, node):
common_var, conditions = self.extract_common_conditions(None, node.test)
if common_var is None:
return node
if len(conditions) < 2:
return node
result_ref = UtilNodes.ResultRefNode(node)
true_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = node.true_val,
first = True)
false_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = node.false_val,
first = True)
cases = [Nodes.SwitchCaseNode(pos = node.pos,
conditions = conditions,
body = true_body)]
common_var = unwrap_node(common_var) common_var = unwrap_node(common_var)
return Nodes.SwitchStatNode(pos = node.pos, switch_node = Nodes.SwitchStatNode(pos = node.pos,
test = common_var, test = common_var,
cases = cases, cases = cases,
else_clause = node.else_clause) else_clause = false_body)
self.visitchildren(switch_node)
return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
def visit_BoolBinopNode(self, node):
common_var, conditions = self.extract_common_conditions(None, node)
if common_var is None:
return node
if len(conditions) < 2:
return node
result_ref = UtilNodes.ResultRefNode(node)
true_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(node.pos, value=True),
first = True)
false_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(node.pos, value=False),
first = True)
cases = [Nodes.SwitchCaseNode(pos = node.pos,
conditions = conditions,
body = true_body)]
common_var = unwrap_node(common_var)
switch_node = Nodes.SwitchStatNode(pos = node.pos,
test = common_var,
cases = cases,
else_clause = false_body)
self.visitchildren(switch_node)
return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
visit_Node = Visitor.VisitorTransform.recurse_to_children visit_Node = Visitor.VisitorTransform.recurse_to_children
...@@ -941,6 +1008,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -941,6 +1008,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return node return node
return node.arg return node.arg
def visit_TypecastNode(self, node):
"""
Drop redundant type casts.
"""
self.visitchildren(node)
if node.type == node.operand.type:
return node.operand
return node
def visit_CoerceToBooleanNode(self, node): def visit_CoerceToBooleanNode(self, node):
"""Drop redundant conversion nodes after tree changes. """Drop redundant conversion nodes after tree changes.
""" """
......
...@@ -148,7 +148,8 @@ class ResultRefNode(AtomicExprNode): ...@@ -148,7 +148,8 @@ class ResultRefNode(AtomicExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
if self.type.is_pyobject: if self.type.is_pyobject:
rhs.make_owned_reference(code) rhs.make_owned_reference(code)
code.put_decref(self.result(), self.ctype()) if not self.lhs_of_first_assignment:
code.put_decref(self.result(), self.ctype())
code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype()))) code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
rhs.free_temps(code) rhs.free_temps(code)
...@@ -250,3 +251,26 @@ class LetNode(Nodes.StatNode, LetNodeMixin): ...@@ -250,3 +251,26 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.setup_temp_expr(code) self.setup_temp_expr(code)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
self.teardown_temp_expr(code) self.teardown_temp_expr(code)
class TempResultFromStatNode(ExprNodes.ExprNode):
# An ExprNode wrapper around a StatNode that executes the StatNode
# body. Requires a ResultRefNode that it sets up to refer to its
# own temp result. The StatNode must assign a value to the result
# node, which then becomes the result of this node.
#
# This can only be used in/after type analysis.
#
subexprs = []
child_attrs = ['body']
def __init__(self, result_ref, body):
self.result_ref = result_ref
self.pos = body.pos
self.body = body
self.type = result_ref.type
self.is_temp = 1
def generate_result_code(self, code):
self.result_ref.result_code = self.result()
self.body.generate_execution_code(code)
cimport cython
def f(a,b): def f(a,b):
""" """
>>> f(1,[1,2,3]) >>> f(1,[1,2,3])
...@@ -42,6 +45,7 @@ def j(b): ...@@ -42,6 +45,7 @@ def j(b):
cdef int result = 2 in b cdef int result = 2 in b
return result return result
@cython.test_fail_if_path_exists("//SwitchStatNode")
def k(a): def k(a):
""" """
>>> k(1) >>> k(1)
...@@ -52,6 +56,8 @@ def k(a): ...@@ -52,6 +56,8 @@ def k(a):
cdef int result = a in [1,2,3,4] cdef int result = a in [1,2,3,4]
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_list(int a): def m_list(int a):
""" """
>>> m_list(2) >>> m_list(2)
...@@ -62,6 +68,8 @@ def m_list(int a): ...@@ -62,6 +68,8 @@ def m_list(int a):
cdef int result = a in [1,2,3,4] cdef int result = a in [1,2,3,4]
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_tuple(int a): def m_tuple(int a):
""" """
>>> m_tuple(2) >>> m_tuple(2)
...@@ -72,6 +80,8 @@ def m_tuple(int a): ...@@ -72,6 +80,8 @@ def m_tuple(int a):
cdef int result = a in (1,2,3,4) cdef int result = a in (1,2,3,4)
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_set(int a): def m_set(int a):
""" """
>>> m_set(2) >>> m_set(2)
...@@ -82,6 +92,44 @@ def m_set(int a): ...@@ -82,6 +92,44 @@ def m_set(int a):
cdef int result = a in {1,2,3,4} cdef int result = a in {1,2,3,4}
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_int(int a):
"""
>>> conditional_int(1)
1
>>> conditional_int(0)
2
>>> conditional_int(5)
2
"""
return 1 if a in (1,2,3,4) else 2
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_object(int a):
"""
>>> conditional_object(1)
1
>>> conditional_object(0)
'2'
>>> conditional_object(5)
'2'
"""
return 1 if a in (1,2,3,4) else '2'
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_none(int a):
"""
>>> conditional_none(1)
>>> conditional_none(0)
1
>>> conditional_none(5)
1
"""
return None if a in {1,2,3,4} else 1
def n(a): def n(a):
""" """
>>> n('d *') >>> n('d *')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment