Commit f18ad43a authored by Alok Singhal's avatar Alok Singhal

add a directive to disable SwitchTransform

SwitchTransform is unable to detect all cases of duplicate values, which
result in errors at compile time in the generated code.
parent 87049841
...@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform):
]) ])
class SwitchTransform(Visitor.VisitorTransform): class SwitchTransform(Visitor.CythonTransform):
""" """
This transformation tries to turn long if statements into C switch statements. This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var The requirement is that every clause be an (or of) var == value, where the var
...@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return False return False
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
common_var = None common_var = None
cases = [] cases = []
for if_clause in node.if_clauses: for if_clause in node.if_clauses:
...@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return switch_node return switch_node
def visit_CondExprNode(self, node): def visit_CondExprNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node.test, True) None, node.test, True)
if common_var is None \ if common_var is None \
...@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform):
node.true_val, node.false_val) node.true_val, node.false_val)
def visit_BoolBinopNode(self, node): def visit_BoolBinopNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node, True) None, node, True)
if common_var is None \ if common_var is None \
...@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform):
ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions( not_in, common_var, conditions = self.extract_common_conditions(
None, node, True) None, node, True)
if common_var is None \ if common_var is None \
...@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return replacement return replacement
def visit_EvalWithTempExprNode(self, node): def visit_EvalWithTempExprNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
# drop unused expression temp from FlattenInListTransform # drop unused expression temp from FlattenInListTransform
orig_expr = node.subexpression orig_expr = node.subexpression
temp_ref = node.lazy_temp temp_ref = node.lazy_temp
......
...@@ -129,6 +129,7 @@ directive_defaults = { ...@@ -129,6 +129,7 @@ directive_defaults = {
# optimizations # optimizations
'optimize.inline_defnode_calls': True, 'optimize.inline_defnode_calls': True,
'optimize.switchcase_transform': True,
# remove unreachable code # remove unreachable code
'remove_unreachable': True, 'remove_unreachable': True,
......
...@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()):
CalculateQualifiedNamesTransform(context), CalculateQualifiedNamesTransform(context),
ConsolidateOverflowCheck(context), ConsolidateOverflowCheck(context),
IterationTransform(context), IterationTransform(context),
SwitchTransform(), SwitchTransform(context),
DropRefcountingTransform(), DropRefcountingTransform(),
FinalOptimizePhase(context), FinalOptimizePhase(context),
GilCheck(), GilCheck(),
......
...@@ -399,6 +399,7 @@ class ErrorWriter(object): ...@@ -399,6 +399,7 @@ class ErrorWriter(object):
def _collect(self, collect_errors, collect_warnings): def _collect(self, collect_errors, collect_warnings):
s = ''.join(self.output) s = ''.join(self.output)
result = [] result = []
runtime_error = False
for line in s.split('\n'): for line in s.split('\n'):
match = self.match_error(line) match = self.match_error(line)
if match: if match:
...@@ -406,8 +407,10 @@ class ErrorWriter(object): ...@@ -406,8 +407,10 @@ class ErrorWriter(object):
if (is_warning and collect_warnings) or \ if (is_warning and collect_warnings) or \
(not is_warning and collect_errors): (not is_warning and collect_errors):
result.append( (int(line), int(column), message.strip()) ) result.append( (int(line), int(column), message.strip()) )
elif 'runtime error' in line:
runtime_error = True
result.sort() result.sort()
return [ "%d:%d: %s" % values for values in result ] return [ "%d:%d: %s" % values for values in result ], runtime_error
def geterrors(self): def geterrors(self):
return self._collect(True, False) return self._collect(True, False)
...@@ -701,7 +704,7 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -701,7 +704,7 @@ class CythonCompileTestCase(unittest.TestCase):
geterrors = out.geterrors geterrors = out.geterrors
except AttributeError: except AttributeError:
out.close() out.close()
return [] return [], False
else: else:
return geterrors() return geterrors()
...@@ -822,7 +825,7 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -822,7 +825,7 @@ class CythonCompileTestCase(unittest.TestCase):
expect_errors, annotate): expect_errors, annotate):
expected_errors = errors = () expected_errors = errors = ()
if expect_errors: if expect_errors:
expected_errors = self.split_source_and_output( expected_errors, runtime_error = self.split_source_and_output(
test_directory, module, workdir) test_directory, module, workdir)
test_directory = workdir test_directory = workdir
...@@ -831,7 +834,6 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -831,7 +834,6 @@ class CythonCompileTestCase(unittest.TestCase):
try: try:
sys.stderr = ErrorWriter() sys.stderr = ErrorWriter()
self.run_cython(test_directory, module, workdir, incdir, annotate) self.run_cython(test_directory, module, workdir, incdir, annotate)
errors = sys.stderr.geterrors()
finally: finally:
sys.stderr = old_stderr sys.stderr = old_stderr
...@@ -857,7 +859,13 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -857,7 +859,13 @@ class CythonCompileTestCase(unittest.TestCase):
if self.cython_only: if self.cython_only:
so_path = None so_path = None
else: else:
so_path = self.run_distutils(test_directory, module, workdir, incdir) try:
so_path = self.run_distutils(test_directory, module, workdir, incdir)
except:
if runtime_error:
return None
else:
raise
return so_path return so_path
class CythonRunTestCase(CythonCompileTestCase): class CythonRunTestCase(CythonCompileTestCase):
...@@ -882,7 +890,7 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -882,7 +890,7 @@ class CythonRunTestCase(CythonCompileTestCase):
self.success = False self.success = False
ext_so_path = self.runCompileTest() ext_so_path = self.runCompileTest()
failures, errors = len(result.failures), len(result.errors) failures, errors = len(result.failures), len(result.errors)
if not self.cython_only: if not self.cython_only and ext_so_path is not None:
self.run_tests(result, ext_so_path) self.run_tests(result, ext_so_path)
if failures == len(result.failures) and errors == len(result.errors): if failures == len(result.failures) and errors == len(result.errors):
# No new errors... # No new errors...
......
# cython: optimize.switchcase_transform=True
# mode: error
import cython
cdef extern from "includes/e_switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
return i != ONE and i != ONE_AGAIN
_ERRORS = u'''
runtime error
'''
#ifndef E_SWITCH_TRANSFORM_SUPPORT_H_
#define E_SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* E_SWITCH_TRANSFORM_SUPPORT_H_ */
#ifndef SWITCH_TRANSFORM_SUPPORT_H_
#define SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* SWITCH_TRANSFORM_SUPPORT_H_ */
# cython: optimize.switchcase_transform=False
cdef extern from "includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
"""
>>> is_not_one(1)
False
>>> is_not_one(2)
True
"""
return i != ONE and i != ONE_AGAIN
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment