Commit 2e0c27f6 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge pull request #287 from gandalf013/switch_transform

add a directive to disable SwitchTransform
parents f3daed10 edb10a47
...@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform):
]) ])
class SwitchTransform(Visitor.VisitorTransform): class SwitchTransform(Visitor.CythonTransform):
""" """
This transformation tries to turn long if statements into C switch statements. This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var The requirement is that every clause be an (or of) var == value, where the var
...@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return False return False
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
common_var = None common_var = None
cases = [] cases = []
for if_clause in node.if_clauses: for if_clause in node.if_clauses:
...@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return switch_node return switch_node
def visit_CondExprNode(self, node): def visit_CondExprNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node.test, True) None, node.test, True)
if common_var is None \ if common_var is None \
...@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform):
node.true_val, node.false_val) node.true_val, node.false_val)
def visit_BoolBinopNode(self, node): def visit_BoolBinopNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node, True) None, node, True)
if common_var is None \ if common_var is None \
...@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform):
ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node, True) None, node, True)
if common_var is None \ if common_var is None \
...@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return replacement return replacement
def visit_EvalWithTempExprNode(self, node): def visit_EvalWithTempExprNode(self, node):
if not self.current_directives.get('optimize.use_switch'):
self.visitchildren(node)
return node
# drop unused expression temp from FlattenInListTransform # drop unused expression temp from FlattenInListTransform
orig_expr = node.subexpression orig_expr = node.subexpression
temp_ref = node.lazy_temp temp_ref = node.lazy_temp
......
...@@ -129,6 +129,7 @@ directive_defaults = { ...@@ -129,6 +129,7 @@ directive_defaults = {
# optimizations # optimizations
'optimize.inline_defnode_calls': True, 'optimize.inline_defnode_calls': True,
'optimize.use_switch': True,
# remove unreachable code # remove unreachable code
'remove_unreachable': True, 'remove_unreachable': True,
......
...@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()):
CalculateQualifiedNamesTransform(context), CalculateQualifiedNamesTransform(context),
ConsolidateOverflowCheck(context), ConsolidateOverflowCheck(context),
IterationTransform(context), IterationTransform(context),
SwitchTransform(), SwitchTransform(context),
DropRefcountingTransform(), DropRefcountingTransform(),
FinalOptimizePhase(context), FinalOptimizePhase(context),
GilCheck(), GilCheck(),
......
...@@ -693,10 +693,14 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -693,10 +693,14 @@ class CythonCompileTestCase(unittest.TestCase):
if line.startswith("_ERRORS"): if line.startswith("_ERRORS"):
out.close() out.close()
out = ErrorWriter() out = ErrorWriter()
elif line.startswith('_FAIL_C_COMPILE'):
out.close()
return '_FAIL_C_COMPILE'
else: else:
out.write(line) out.write(line)
finally: finally:
source_and_output.close() source_and_output.close()
try: try:
geterrors = out.geterrors geterrors = out.geterrors
except AttributeError: except AttributeError:
...@@ -835,7 +839,14 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -835,7 +839,14 @@ class CythonCompileTestCase(unittest.TestCase):
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
if errors or expected_errors: if expected_errors == '_FAIL_C_COMPILE':
if errors:
print("\n=== Expected C compile error ===")
print("\n\n=== Got Cython errors: ===")
print('\n'.join(errors))
print('\n')
raise RuntimeError('should have generated extension code')
elif errors or expected_errors:
try: try:
for expected, error in zip(expected_errors, errors): for expected, error in zip(expected_errors, errors):
self.assertEquals(expected, error) self.assertEquals(expected, error)
...@@ -854,10 +865,13 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -854,10 +865,13 @@ class CythonCompileTestCase(unittest.TestCase):
raise raise
return None return None
if self.cython_only: so_path = None
so_path = None if not self.cython_only:
else: try:
so_path = self.run_distutils(test_directory, module, workdir, incdir) so_path = self.run_distutils(test_directory, module, workdir, incdir)
except:
if expected_errors != '_FAIL_C_COMPILE':
raise
return so_path return so_path
class CythonRunTestCase(CythonCompileTestCase): class CythonRunTestCase(CythonCompileTestCase):
...@@ -882,7 +896,7 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -882,7 +896,7 @@ class CythonRunTestCase(CythonCompileTestCase):
self.success = False self.success = False
ext_so_path = self.runCompileTest() ext_so_path = self.runCompileTest()
failures, errors = len(result.failures), len(result.errors) failures, errors = len(result.failures), len(result.errors)
if not self.cython_only: if not self.cython_only and ext_so_path is not None:
self.run_tests(result, ext_so_path) self.run_tests(result, ext_so_path)
if failures == len(result.failures) and errors == len(result.errors): if failures == len(result.failures) and errors == len(result.errors):
# No new errors... # No new errors...
......
# cython: optimize.use_switch=True
# mode: error
import cython
cdef extern from "../run/includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
return i != ONE and i != ONE_AGAIN
_FAIL_C_COMPILE = True
#ifndef SWITCH_TRANSFORM_SUPPORT_H_
#define SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* SWITCH_TRANSFORM_SUPPORT_H_ */
# cython: optimize.use_switch=False
cdef extern from "includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
"""
>>> is_not_one(1)
False
>>> is_not_one(2)
True
"""
return i != ONE and i != ONE_AGAIN
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment