Commit 2e0c27f6 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge pull request #287 from gandalf013/switch_transform

add a directive to disable SwitchTransform
parents f3daed10 edb10a47
......@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform):
])
class SwitchTransform(Visitor.VisitorTransform):
class SwitchTransform(Visitor.CythonTransform):
"""
This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var
......@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return False
def visit_IfStatNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
common_var = None
cases = []
for if_clause in node.if_clauses:
......@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return switch_node
def visit_CondExprNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node.test, True)
if common_var is None \
......@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform):
node.true_val, node.false_val)
def visit_BoolBinopNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node, True)
if common_var is None \
......@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform):
ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def visit_PrimaryCmpNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node, True)
if common_var is None \
......@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return replacement
def visit_EvalWithTempExprNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
# drop unused expression temp from FlattenInListTransform
orig_expr = node.subexpression
temp_ref = node.lazy_temp
......
......@@ -129,6 +129,7 @@ directive_defaults = {
# optimizations
'optimize.inline_defnode_calls': True,
'optimize.use_switch': True,
# remove unreachable code
'remove_unreachable': True,
......
......@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()):
CalculateQualifiedNamesTransform(context),
ConsolidateOverflowCheck(context),
IterationTransform(context),
SwitchTransform(),
SwitchTransform(context),
DropRefcountingTransform(),
FinalOptimizePhase(context),
GilCheck(),
......
......@@ -693,10 +693,14 @@ class CythonCompileTestCase(unittest.TestCase):
if line.startswith("_ERRORS"):
out.close()
out = ErrorWriter()
elif line.startswith('_FAIL_C_COMPILE'):
out.close()
return '_FAIL_C_COMPILE'
else:
out.write(line)
finally:
source_and_output.close()
try:
geterrors = out.geterrors
except AttributeError:
......@@ -835,7 +839,14 @@ class CythonCompileTestCase(unittest.TestCase):
finally:
sys.stderr = old_stderr
if errors or expected_errors:
if expected_errors == '_FAIL_C_COMPILE':
if errors:
print("\n=== Expected C compile error ===")
print("\n\n=== Got Cython errors: ===")
print('\n'.join(errors))
print('\n')
raise RuntimeError('should have generated extension code')
elif errors or expected_errors:
try:
for expected, error in zip(expected_errors, errors):
self.assertEquals(expected, error)
......@@ -854,10 +865,13 @@ class CythonCompileTestCase(unittest.TestCase):
raise
return None
if self.cython_only:
so_path = None
else:
if not self.cython_only:
try:
so_path = self.run_distutils(test_directory, module, workdir, incdir)
except:
if expected_errors != '_FAIL_C_COMPILE':
raise
return so_path
class CythonRunTestCase(CythonCompileTestCase):
......@@ -882,7 +896,7 @@ class CythonRunTestCase(CythonCompileTestCase):
self.success = False
ext_so_path = self.runCompileTest()
failures, errors = len(result.failures), len(result.errors)
if not self.cython_only:
if not self.cython_only and ext_so_path is not None:
self.run_tests(result, ext_so_path)
if failures == len(result.failures) and errors == len(result.errors):
# No new errors...
......
# cython: optimize.use_switch=True
# mode: error
import cython
cdef extern from "../run/includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
return i != ONE and i != ONE_AGAIN
_FAIL_C_COMPILE = True
#ifndef SWITCH_TRANSFORM_SUPPORT_H_
#define SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* SWITCH_TRANSFORM_SUPPORT_H_ */
# cython: optimize.use_switch=False
cdef extern from "includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
"""
>>> is_not_one(1)
False
>>> is_not_one(2)
True
"""
return i != ONE and i != ONE_AGAIN
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment