Commit 97e1b964 authored by Stefan Behnel's avatar Stefan Behnel

Reimplement the "assert" statement by delegating the exception raising to a RaiseStatNode.

This allows taking advantage of the automatic "with gil" block handling for raising exceptions, allows proper control flow analysis, etc.
parent b0aaba60
...@@ -908,6 +908,26 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -908,6 +908,26 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.block = None self.flow.block = None
return node return node
def visit_AssertStatNode(self, node):
"""Essentially an if-condition that wraps a RaiseStatNode.
"""
self.mark_position(node)
next_block = self.flow.newblock()
parent = self.flow.block
# failure case
parent = self.flow.nextblock(parent)
self._visit(node.condition)
self.flow.nextblock()
self._visit(node.exception)
if self.flow.block:
self.flow.block.add_child(next_block)
parent.add_child(next_block)
if next_block.parents:
self.flow.block = next_block
else:
self.flow.block = None
return node
def visit_WhileStatNode(self, node): def visit_WhileStatNode(self, node):
condition_block = self.flow.nextblock() condition_block = self.flow.nextblock()
next_block = self.flow.newblock() next_block = self.flow.newblock()
......
...@@ -6350,6 +6350,8 @@ class RaiseStatNode(StatNode): ...@@ -6350,6 +6350,8 @@ class RaiseStatNode(StatNode):
child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"] child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
is_terminator = True is_terminator = True
builtin_exc_name = None
wrap_tuple_value = False
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.exc_type: if self.exc_type:
...@@ -6357,6 +6359,12 @@ class RaiseStatNode(StatNode): ...@@ -6357,6 +6359,12 @@ class RaiseStatNode(StatNode):
self.exc_type = exc_type.coerce_to_pyobject(env) self.exc_type = exc_type.coerce_to_pyobject(env)
if self.exc_value: if self.exc_value:
exc_value = self.exc_value.analyse_types(env) exc_value = self.exc_value.analyse_types(env)
if self.wrap_tuple_value:
if exc_value.type is Builtin.tuple_type or not exc_value.type.is_builtin_type:
# prevent tuple values from being interpreted as argument value tuples
from .ExprNodes import TupleNode
exc_value = TupleNode(exc_value.pos, args=[exc_value.coerce_to_pyobject(env)], slow=True)
exc_value = exc_value.analyse_types(env, skip_children=True)
self.exc_value = exc_value.coerce_to_pyobject(env) self.exc_value = exc_value.coerce_to_pyobject(env)
if self.exc_tb: if self.exc_tb:
exc_tb = self.exc_tb.analyse_types(env) exc_tb = self.exc_tb.analyse_types(env)
...@@ -6365,7 +6373,6 @@ class RaiseStatNode(StatNode): ...@@ -6365,7 +6373,6 @@ class RaiseStatNode(StatNode):
cause = self.cause.analyse_types(env) cause = self.cause.analyse_types(env)
self.cause = cause.coerce_to_pyobject(env) self.cause = cause.coerce_to_pyobject(env)
# special cases for builtin exceptions # special cases for builtin exceptions
self.builtin_exc_name = None
if self.exc_type and not self.exc_value and not self.exc_tb: if self.exc_type and not self.exc_value and not self.exc_tb:
exc = self.exc_type exc = self.exc_type
from . import ExprNodes from . import ExprNodes
...@@ -6478,66 +6485,50 @@ class ReraiseStatNode(StatNode): ...@@ -6478,66 +6485,50 @@ class ReraiseStatNode(StatNode):
class AssertStatNode(StatNode): class AssertStatNode(StatNode):
# assert statement # assert statement
# #
# cond ExprNode # condition ExprNode
# value ExprNode or None # value ExprNode or None
# exception (Raise/GIL)StatNode created from 'value' in PostParse transform
child_attrs = ["condition", "value", "exception"]
exception = None
child_attrs = ["cond", "value"] def analyse_declarations(self, env):
assert self.value is None, "Message should have been replaced in PostParse()"
assert self.exception is not None, "Message should have been replaced in PostParse()"
self.condition.analyse_declarations(env)
self.exception.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.cond = self.cond.analyse_boolean_expression(env) self.condition = self.condition.analyse_boolean_expression(env)
if self.value: self.exception = self.exception.analyse_expressions(env)
value = self.value.analyse_types(env)
if value.type is Builtin.tuple_type or not value.type.is_builtin_type:
# prevent tuple values from being interpreted as argument value tuples
from .ExprNodes import TupleNode
value = TupleNode(value.pos, args=[value], slow=True)
self.value = value.analyse_types(env, skip_children=True).coerce_to_pyobject(env)
else:
self.value = value.coerce_to_pyobject(env)
return self return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.putln("#ifndef CYTHON_WITHOUT_ASSERTIONS") code.putln("#ifndef CYTHON_WITHOUT_ASSERTIONS")
code.putln("if (unlikely(!Py_OptimizeFlag)) {") code.putln("if (unlikely(!Py_OptimizeFlag)) {")
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.cond.generate_evaluation_code(code) self.condition.generate_evaluation_code(code)
code.putln(
"if (unlikely(!%s)) {" % self.cond.result())
in_nogil = not code.funcstate.gil_owned
if in_nogil:
# Apparently, evaluating condition and value does not require the GIL,
# but raising the exception now does.
code.put_ensure_gil()
if self.value:
self.value.generate_evaluation_code(code)
code.putln(
"PyErr_SetObject(PyExc_AssertionError, %s);" % self.value.py_result())
self.value.generate_disposal_code(code)
self.value.free_temps(code)
else:
code.putln(
"PyErr_SetNone(PyExc_AssertionError);")
if in_nogil:
code.put_release_ensured_gil()
code.putln( code.putln(
code.error_goto(self.pos)) "if (unlikely(!%s)) {" % self.condition.result())
self.exception.generate_execution_code(code)
code.putln( code.putln(
"}") "}")
self.cond.generate_disposal_code(code) self.condition.generate_disposal_code(code)
self.cond.free_temps(code) self.condition.free_temps(code)
code.putln( code.putln(
"}") "}")
code.putln("#else")
# avoid unused labels etc.
code.putln("if ((1)); else %s" % code.error_goto(self.pos, used=False))
code.putln("#endif") code.putln("#endif")
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.cond.generate_function_definitions(env, code) self.condition.generate_function_definitions(env, code)
if self.value is not None: self.exception.generate_function_definitions(env, code)
self.value.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.cond.annotate(code) self.condition.annotate(code)
if self.value: self.exception.annotate(code)
self.value.annotate(code)
class IfStatNode(StatNode): class IfStatNode(StatNode):
......
...@@ -346,6 +346,23 @@ class PostParse(ScopeTrackingTransform): ...@@ -346,6 +346,23 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_AssertStatNode(self, node):
"""Extract the exception raising into a RaiseStatNode to simplify GIL handling.
"""
if node.exception is None:
node.exception = Nodes.RaiseStatNode(
node.pos,
exc_type=ExprNodes.NameNode(node.pos, name=EncodedString("AssertionError")),
exc_value=node.value,
exc_tb=None,
cause=None,
builtin_exc_name="AssertionError",
wrap_tuple_value=True,
)
node.value = None
self.visitchildren(node)
return node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once. """Replace rhs items by LetRefNodes if they appear more than once.
...@@ -2856,7 +2873,6 @@ class InjectGilHandling(VisitorTransform, SkipDeclarations): ...@@ -2856,7 +2873,6 @@ class InjectGilHandling(VisitorTransform, SkipDeclarations):
visit_PrintStatNode = _inject_gil_in_nogil # sadly, not the function visit_PrintStatNode = _inject_gil_in_nogil # sadly, not the function
# further candidates: # further candidates:
# def visit_AssertStatNode(self, node): # -> try to keep the condition GIL-free if possible
# def visit_ReraiseStatNode(self, node): # def visit_ReraiseStatNode(self, node):
# nogil tracking # nogil tracking
......
...@@ -1814,7 +1814,7 @@ def p_assert_statement(s): ...@@ -1814,7 +1814,7 @@ def p_assert_statement(s):
value = p_test(s) value = p_test(s)
else: else:
value = None value = None
return Nodes.AssertStatNode(pos, cond = cond, value = value) return Nodes.AssertStatNode(pos, condition=cond, value=value)
statement_terminators = cython.declare(set, set([';', 'NEWLINE', 'EOF'])) statement_terminators = cython.declare(set, set([';', 'NEWLINE', 'EOF']))
......
...@@ -6,13 +6,14 @@ def nontrivial_assert_in_nogil(int a, obj): ...@@ -6,13 +6,14 @@ def nontrivial_assert_in_nogil(int a, obj):
# NOK # NOK
assert obj assert obj
assert a*obj assert a*obj
assert a, f"123{a}xyz" assert obj, "abc"
# OK # OK
assert a assert a
assert a*a assert a*a
assert a, "abc" assert a, "abc"
assert a, u"abc" assert a, u"abc"
assert a, f"123{a}xyz"
_ERRORS = """ _ERRORS = """
...@@ -20,6 +21,5 @@ _ERRORS = """ ...@@ -20,6 +21,5 @@ _ERRORS = """
8:15: Converting to Python object not allowed without gil 8:15: Converting to Python object not allowed without gil
8:16: Operation not allowed without gil 8:16: Operation not allowed without gil
8:16: Truth-testing Python object not allowed without gil 8:16: Truth-testing Python object not allowed without gil
9:18: String concatenation not allowed without gil 9:15: Truth-testing Python object not allowed without gil
9:18: String formatting not allowed without gil
""" """
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
cimport cython cimport cython
@cython.test_assert_path_exists(
'//AssertStatNode',
'//AssertStatNode//RaiseStatNode',
)
def f(a, b, int i): def f(a, b, int i):
""" """
>>> f(1, 2, 1) >>> f(1, 2, 1)
...@@ -22,7 +26,9 @@ def f(a, b, int i): ...@@ -22,7 +26,9 @@ def f(a, b, int i):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//AssertStatNode', '//AssertStatNode',
'//AssertStatNode//TupleNode') '//AssertStatNode//RaiseStatNode',
'//AssertStatNode//RaiseStatNode//TupleNode',
)
def g(a, b): def g(a, b):
""" """
>>> g(1, "works") >>> g(1, "works")
...@@ -38,7 +44,9 @@ def g(a, b): ...@@ -38,7 +44,9 @@ def g(a, b):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//AssertStatNode', '//AssertStatNode',
'//AssertStatNode//TupleNode') '//AssertStatNode//RaiseStatNode',
'//AssertStatNode//RaiseStatNode//TupleNode',
)
def g(a, b): def g(a, b):
""" """
>>> g(1, "works") >>> g(1, "works")
...@@ -54,8 +62,9 @@ def g(a, b): ...@@ -54,8 +62,9 @@ def g(a, b):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//AssertStatNode', '//AssertStatNode',
'//AssertStatNode//TupleNode', '//AssertStatNode//RaiseStatNode',
'//AssertStatNode//TupleNode//TupleNode') '//AssertStatNode//RaiseStatNode//TupleNode',
'//AssertStatNode//RaiseStatNode//TupleNode//TupleNode',)
def assert_with_tuple_arg(a): def assert_with_tuple_arg(a):
""" """
>>> assert_with_tuple_arg(True) >>> assert_with_tuple_arg(True)
...@@ -67,9 +76,12 @@ def assert_with_tuple_arg(a): ...@@ -67,9 +76,12 @@ def assert_with_tuple_arg(a):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//AssertStatNode') '//AssertStatNode',
'//AssertStatNode//RaiseStatNode',
)
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
'//AssertStatNode//TupleNode') '//AssertStatNode//TupleNode',
)
def assert_with_str_arg(a): def assert_with_str_arg(a):
""" """
>>> assert_with_str_arg(True) >>> assert_with_str_arg(True)
......
...@@ -75,9 +75,8 @@ cpdef int test_raise_in_nogil_func(x) nogil except -1: ...@@ -75,9 +75,8 @@ cpdef int test_raise_in_nogil_func(x) nogil except -1:
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//GILStatNode", "//GILStatNode",
"//GILStatNode//AssertStatNode", "//GILStatNode//AssertStatNode",
) "//GILStatNode//AssertStatNode//GILStatNode",
@cython.test_fail_if_path_exists( "//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
"//GILStatNode//GILStatNode",
) )
def assert_in_nogil_section(int x): def assert_in_nogil_section(int x):
""" """
...@@ -93,9 +92,8 @@ def assert_in_nogil_section(int x): ...@@ -93,9 +92,8 @@ def assert_in_nogil_section(int x):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//GILStatNode", "//GILStatNode",
"//GILStatNode//AssertStatNode", "//GILStatNode//AssertStatNode",
) "//GILStatNode//AssertStatNode//GILStatNode",
@cython.test_fail_if_path_exists( "//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
"//GILStatNode//GILStatNode",
) )
def assert_in_nogil_section_ustring(int x): def assert_in_nogil_section_ustring(int x):
""" """
...@@ -111,9 +109,8 @@ def assert_in_nogil_section_ustring(int x): ...@@ -111,9 +109,8 @@ def assert_in_nogil_section_ustring(int x):
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//GILStatNode", "//GILStatNode",
"//GILStatNode//AssertStatNode", "//GILStatNode//AssertStatNode",
) "//GILStatNode//AssertStatNode//GILStatNode",
@cython.test_fail_if_path_exists( "//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
"//GILStatNode//GILStatNode",
) )
def assert_in_nogil_section_string(int x): def assert_in_nogil_section_string(int x):
""" """
...@@ -126,8 +123,10 @@ def assert_in_nogil_section_string(int x): ...@@ -126,8 +123,10 @@ def assert_in_nogil_section_string(int x):
assert x, "failed!" assert x, "failed!"
@cython.test_fail_if_path_exists( @cython.test_assert_path_exists(
"//GILStatNode", "//AssertStatNode",
"//AssertStatNode//GILStatNode",
"//AssertStatNode//GILStatNode//RaiseStatNode",
) )
cpdef int assert_in_nogil_func(int x) nogil except -1: cpdef int assert_in_nogil_func(int x) nogil except -1:
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment