Commit f18ad43a authored by Alok Singhal's avatar Alok Singhal

add a directive to disable SwitchTransform

SwitchTransform is unable to detect all cases of duplicate values, which
result in errors at compile time in the generated code.
parent 87049841
......@@ -802,7 +802,7 @@ class IterationTransform(Visitor.EnvTransform):
])
class SwitchTransform(Visitor.VisitorTransform):
class SwitchTransform(Visitor.CythonTransform):
"""
This transformation tries to turn long if statements into C switch statements.
The requirement is that every clause be an (or of) var == value, where the var
......@@ -917,6 +917,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return False
def visit_IfStatNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
common_var = None
cases = []
for if_clause in node.if_clauses:
......@@ -946,6 +950,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return switch_node
def visit_CondExprNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node.test, True)
if common_var is None \
......@@ -958,6 +966,10 @@ class SwitchTransform(Visitor.VisitorTransform):
node.true_val, node.false_val)
def visit_BoolBinopNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node, True)
if common_var is None \
......@@ -972,6 +984,10 @@ class SwitchTransform(Visitor.VisitorTransform):
ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def visit_PrimaryCmpNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
not_in, common_var, conditions = self.extract_common_conditions(
None, node, True)
if common_var is None \
......@@ -1015,6 +1031,10 @@ class SwitchTransform(Visitor.VisitorTransform):
return replacement
def visit_EvalWithTempExprNode(self, node):
if not self.current_directives.get('optimize.switchcase_transform'):
self.visitchildren(node)
return node
# drop unused expression temp from FlattenInListTransform
orig_expr = node.subexpression
temp_ref = node.lazy_temp
......
......@@ -129,6 +129,7 @@ directive_defaults = {
# optimizations
'optimize.inline_defnode_calls': True,
'optimize.switchcase_transform': True,
# remove unreachable code
'remove_unreachable': True,
......
......@@ -202,7 +202,7 @@ def create_pipeline(context, mode, exclude_classes=()):
CalculateQualifiedNamesTransform(context),
ConsolidateOverflowCheck(context),
IterationTransform(context),
SwitchTransform(),
SwitchTransform(context),
DropRefcountingTransform(),
FinalOptimizePhase(context),
GilCheck(),
......
......@@ -399,6 +399,7 @@ class ErrorWriter(object):
def _collect(self, collect_errors, collect_warnings):
s = ''.join(self.output)
result = []
runtime_error = False
for line in s.split('\n'):
match = self.match_error(line)
if match:
......@@ -406,8 +407,10 @@ class ErrorWriter(object):
if (is_warning and collect_warnings) or \
(not is_warning and collect_errors):
result.append( (int(line), int(column), message.strip()) )
elif 'runtime error' in line:
runtime_error = True
result.sort()
return [ "%d:%d: %s" % values for values in result ]
return [ "%d:%d: %s" % values for values in result ], runtime_error
def geterrors(self):
return self._collect(True, False)
......@@ -701,7 +704,7 @@ class CythonCompileTestCase(unittest.TestCase):
geterrors = out.geterrors
except AttributeError:
out.close()
return []
return [], False
else:
return geterrors()
......@@ -822,7 +825,7 @@ class CythonCompileTestCase(unittest.TestCase):
expect_errors, annotate):
expected_errors = errors = ()
if expect_errors:
expected_errors = self.split_source_and_output(
expected_errors, runtime_error = self.split_source_and_output(
test_directory, module, workdir)
test_directory = workdir
......@@ -831,7 +834,6 @@ class CythonCompileTestCase(unittest.TestCase):
try:
sys.stderr = ErrorWriter()
self.run_cython(test_directory, module, workdir, incdir, annotate)
errors = sys.stderr.geterrors()
finally:
sys.stderr = old_stderr
......@@ -857,7 +859,13 @@ class CythonCompileTestCase(unittest.TestCase):
if self.cython_only:
so_path = None
else:
so_path = self.run_distutils(test_directory, module, workdir, incdir)
try:
so_path = self.run_distutils(test_directory, module, workdir, incdir)
except:
if runtime_error:
return None
else:
raise
return so_path
class CythonRunTestCase(CythonCompileTestCase):
......@@ -882,7 +890,7 @@ class CythonRunTestCase(CythonCompileTestCase):
self.success = False
ext_so_path = self.runCompileTest()
failures, errors = len(result.failures), len(result.errors)
if not self.cython_only:
if not self.cython_only and ext_so_path is not None:
self.run_tests(result, ext_so_path)
if failures == len(result.failures) and errors == len(result.errors):
# No new errors...
......
# cython: optimize.switchcase_transform=True
# mode: error
import cython
cdef extern from "includes/e_switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
return i != ONE and i != ONE_AGAIN
_ERRORS = u'''
runtime error
'''
#ifndef E_SWITCH_TRANSFORM_SUPPORT_H_
#define E_SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* E_SWITCH_TRANSFORM_SUPPORT_H_ */
#ifndef SWITCH_TRANSFORM_SUPPORT_H_
#define SWITCH_TRANSFORM_SUPPORT_H_ 1
enum {
ONE=1,
ONE_AGAIN=1
};
#endif /* SWITCH_TRANSFORM_SUPPORT_H_ */
# cython: optimize.switchcase_transform=False
cdef extern from "includes/switch_transform_support.h":
enum:
ONE
ONE_AGAIN
def is_not_one(int i):
"""
>>> is_not_one(1)
False
>>> is_not_one(2)
True
"""
return i != ONE and i != ONE_AGAIN
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment