From d1dd80fd2abf283a7e4cd238aa3d002801b91c03 Mon Sep 17 00:00:00 2001 From: Stefan Behnel <scoder@users.berlios.de> Date: Sun, 7 Sep 2008 20:57:19 +0200 Subject: [PATCH] enable the switch transform also for long 'or' expressions in a single 'if' statement --- Cython/Compiler/Optimize.py | 6 +++-- tests/run/switch.pyx | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 6036a5ade..2d6a16c63 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -56,9 +56,8 @@ class SwitchTransform(Visitor.VisitorTransform): def visit_IfStatNode(self, node): self.visitchildren(node) - if len(node.if_clauses) < 3: - return node common_var = None + case_count = 0 cases = [] for if_clause in node.if_clauses: var, conditions = self.extract_conditions(if_clause.condition) @@ -70,9 +69,12 @@ class SwitchTransform(Visitor.VisitorTransform): return node else: common_var = var + case_count += len(conditions) cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, conditions = conditions, body = if_clause.body)) + if case_count < 2: + return node common_var = unwrap_node(common_var) return Nodes.SwitchStatNode(pos = node.pos, diff --git a/tests/run/switch.pyx b/tests/run/switch.pyx index da5bed14b..80f16f7d4 100644 --- a/tests/run/switch.pyx +++ b/tests/run/switch.pyx @@ -62,6 +62,33 @@ __doc__ = u""" 12 >>> switch_c(13) 0 + +>>> switch_or(0) +0 +>>> switch_or(1) +1 +>>> switch_or(2) +1 +>>> switch_or(3) +1 +>>> switch_or(4) +0 + +>>> switch_short(0) +0 +>>> switch_short(1) +1 +>>> switch_short(2) +2 +>>> switch_short(3) +0 + +>>> switch_off(0) +0 +>>> switch_off(1) +1 +>>> switch_off(2) +0 """ def switch_simple_py(x): @@ -123,3 +150,26 @@ def switch_c(int x): else: return 0 return -1 + +def switch_or(int x): + if x == 1 or x == 2 or x == 3: + return 1 + else: + return 0 + return -1 + +def switch_short(int x): + if x == 1: + return 1 + elif 2 == x: + return 2 + else: + return 0 + return -1 + +def switch_off(int x): + if x == 1: + return 1 + else: + return 0 + return -1 -- 2.30.9