Commit 97e1b964 authored by Stefan Behnel's avatar Stefan Behnel

Reimplement the "assert" statement by delegating the exception raising to a RaiseStatNode.

This allows taking advantage of the automatic "with gil" block handling for raising exceptions, allows proper control flow analysis, etc.
parent b0aaba60
......@@ -908,6 +908,26 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.block = None
return node
def visit_AssertStatNode(self, node):
"""Essentially an if-condition that wraps a RaiseStatNode.
"""
self.mark_position(node)
next_block = self.flow.newblock()
parent = self.flow.block
# failure case
parent = self.flow.nextblock(parent)
self._visit(node.condition)
self.flow.nextblock()
self._visit(node.exception)
if self.flow.block:
self.flow.block.add_child(next_block)
parent.add_child(next_block)
if next_block.parents:
self.flow.block = next_block
else:
self.flow.block = None
return node
def visit_WhileStatNode(self, node):
condition_block = self.flow.nextblock()
next_block = self.flow.newblock()
......
......@@ -6350,6 +6350,8 @@ class RaiseStatNode(StatNode):
child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
is_terminator = True
builtin_exc_name = None
wrap_tuple_value = False
def analyse_expressions(self, env):
if self.exc_type:
......@@ -6357,6 +6359,12 @@ class RaiseStatNode(StatNode):
self.exc_type = exc_type.coerce_to_pyobject(env)
if self.exc_value:
exc_value = self.exc_value.analyse_types(env)
if self.wrap_tuple_value:
if exc_value.type is Builtin.tuple_type or not exc_value.type.is_builtin_type:
# prevent tuple values from being interpreted as argument value tuples
from .ExprNodes import TupleNode
exc_value = TupleNode(exc_value.pos, args=[exc_value.coerce_to_pyobject(env)], slow=True)
exc_value = exc_value.analyse_types(env, skip_children=True)
self.exc_value = exc_value.coerce_to_pyobject(env)
if self.exc_tb:
exc_tb = self.exc_tb.analyse_types(env)
......@@ -6365,7 +6373,6 @@ class RaiseStatNode(StatNode):
cause = self.cause.analyse_types(env)
self.cause = cause.coerce_to_pyobject(env)
# special cases for builtin exceptions
self.builtin_exc_name = None
if self.exc_type and not self.exc_value and not self.exc_tb:
exc = self.exc_type
from . import ExprNodes
......@@ -6478,66 +6485,50 @@ class ReraiseStatNode(StatNode):
class AssertStatNode(StatNode):
# assert statement
#
# cond ExprNode
# value ExprNode or None
# condition ExprNode
# value ExprNode or None
# exception (Raise/GIL)StatNode created from 'value' in PostParse transform
child_attrs = ["condition", "value", "exception"]
exception = None
child_attrs = ["cond", "value"]
def analyse_declarations(self, env):
assert self.value is None, "Message should have been replaced in PostParse()"
assert self.exception is not None, "Message should have been replaced in PostParse()"
self.condition.analyse_declarations(env)
self.exception.analyse_declarations(env)
def analyse_expressions(self, env):
self.cond = self.cond.analyse_boolean_expression(env)
if self.value:
value = self.value.analyse_types(env)
if value.type is Builtin.tuple_type or not value.type.is_builtin_type:
# prevent tuple values from being interpreted as argument value tuples
from .ExprNodes import TupleNode
value = TupleNode(value.pos, args=[value], slow=True)
self.value = value.analyse_types(env, skip_children=True).coerce_to_pyobject(env)
else:
self.value = value.coerce_to_pyobject(env)
self.condition = self.condition.analyse_boolean_expression(env)
self.exception = self.exception.analyse_expressions(env)
return self
def generate_execution_code(self, code):
code.putln("#ifndef CYTHON_WITHOUT_ASSERTIONS")
code.putln("if (unlikely(!Py_OptimizeFlag)) {")
code.mark_pos(self.pos)
self.cond.generate_evaluation_code(code)
code.putln(
"if (unlikely(!%s)) {" % self.cond.result())
in_nogil = not code.funcstate.gil_owned
if in_nogil:
# Apparently, evaluating condition and value does not require the GIL,
# but raising the exception now does.
code.put_ensure_gil()
if self.value:
self.value.generate_evaluation_code(code)
code.putln(
"PyErr_SetObject(PyExc_AssertionError, %s);" % self.value.py_result())
self.value.generate_disposal_code(code)
self.value.free_temps(code)
else:
code.putln(
"PyErr_SetNone(PyExc_AssertionError);")
if in_nogil:
code.put_release_ensured_gil()
self.condition.generate_evaluation_code(code)
code.putln(
code.error_goto(self.pos))
"if (unlikely(!%s)) {" % self.condition.result())
self.exception.generate_execution_code(code)
code.putln(
"}")
self.cond.generate_disposal_code(code)
self.cond.free_temps(code)
self.condition.generate_disposal_code(code)
self.condition.free_temps(code)
code.putln(
"}")
code.putln("#else")
# avoid unused labels etc.
code.putln("if ((1)); else %s" % code.error_goto(self.pos, used=False))
code.putln("#endif")
def generate_function_definitions(self, env, code):
self.cond.generate_function_definitions(env, code)
if self.value is not None:
self.value.generate_function_definitions(env, code)
self.condition.generate_function_definitions(env, code)
self.exception.generate_function_definitions(env, code)
def annotate(self, code):
self.cond.annotate(code)
if self.value:
self.value.annotate(code)
self.condition.annotate(code)
self.exception.annotate(code)
class IfStatNode(StatNode):
......
......@@ -346,6 +346,23 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node)
return node
def visit_AssertStatNode(self, node):
"""Extract the exception raising into a RaiseStatNode to simplify GIL handling.
"""
if node.exception is None:
node.exception = Nodes.RaiseStatNode(
node.pos,
exc_type=ExprNodes.NameNode(node.pos, name=EncodedString("AssertionError")),
exc_value=node.value,
exc_tb=None,
cause=None,
builtin_exc_name="AssertionError",
wrap_tuple_value=True,
)
node.value = None
self.visitchildren(node)
return node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once.
......@@ -2856,7 +2873,6 @@ class InjectGilHandling(VisitorTransform, SkipDeclarations):
visit_PrintStatNode = _inject_gil_in_nogil # sadly, not the function
# further candidates:
# def visit_AssertStatNode(self, node): # -> try to keep the condition GIL-free if possible
# def visit_ReraiseStatNode(self, node):
# nogil tracking
......
......@@ -1814,7 +1814,7 @@ def p_assert_statement(s):
value = p_test(s)
else:
value = None
return Nodes.AssertStatNode(pos, cond = cond, value = value)
return Nodes.AssertStatNode(pos, condition=cond, value=value)
statement_terminators = cython.declare(set, set([';', 'NEWLINE', 'EOF']))
......
......@@ -6,13 +6,14 @@ def nontrivial_assert_in_nogil(int a, obj):
# NOK
assert obj
assert a*obj
assert a, f"123{a}xyz"
assert obj, "abc"
# OK
assert a
assert a*a
assert a, "abc"
assert a, u"abc"
assert a, f"123{a}xyz"
_ERRORS = """
......@@ -20,6 +21,5 @@ _ERRORS = """
8:15: Converting to Python object not allowed without gil
8:16: Operation not allowed without gil
8:16: Truth-testing Python object not allowed without gil
9:18: String concatenation not allowed without gil
9:18: String formatting not allowed without gil
9:15: Truth-testing Python object not allowed without gil
"""
......@@ -2,6 +2,10 @@
cimport cython
@cython.test_assert_path_exists(
'//AssertStatNode',
'//AssertStatNode//RaiseStatNode',
)
def f(a, b, int i):
"""
>>> f(1, 2, 1)
......@@ -22,7 +26,9 @@ def f(a, b, int i):
@cython.test_assert_path_exists(
'//AssertStatNode',
'//AssertStatNode//TupleNode')
'//AssertStatNode//RaiseStatNode',
'//AssertStatNode//RaiseStatNode//TupleNode',
)
def g(a, b):
"""
>>> g(1, "works")
......@@ -38,7 +44,9 @@ def g(a, b):
@cython.test_assert_path_exists(
'//AssertStatNode',
'//AssertStatNode//TupleNode')
'//AssertStatNode//RaiseStatNode',
'//AssertStatNode//RaiseStatNode//TupleNode',
)
def g(a, b):
"""
>>> g(1, "works")
......@@ -54,8 +62,9 @@ def g(a, b):
@cython.test_assert_path_exists(
'//AssertStatNode',
'//AssertStatNode//TupleNode',
'//AssertStatNode//TupleNode//TupleNode')
'//AssertStatNode//RaiseStatNode',
'//AssertStatNode//RaiseStatNode//TupleNode',
'//AssertStatNode//RaiseStatNode//TupleNode//TupleNode',)
def assert_with_tuple_arg(a):
"""
>>> assert_with_tuple_arg(True)
......@@ -67,9 +76,12 @@ def assert_with_tuple_arg(a):
@cython.test_assert_path_exists(
'//AssertStatNode')
'//AssertStatNode',
'//AssertStatNode//RaiseStatNode',
)
@cython.test_fail_if_path_exists(
'//AssertStatNode//TupleNode')
'//AssertStatNode//TupleNode',
)
def assert_with_str_arg(a):
"""
>>> assert_with_str_arg(True)
......
......@@ -75,9 +75,8 @@ cpdef int test_raise_in_nogil_func(x) nogil except -1:
@cython.test_assert_path_exists(
"//GILStatNode",
"//GILStatNode//AssertStatNode",
)
@cython.test_fail_if_path_exists(
"//GILStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
)
def assert_in_nogil_section(int x):
"""
......@@ -93,9 +92,8 @@ def assert_in_nogil_section(int x):
@cython.test_assert_path_exists(
"//GILStatNode",
"//GILStatNode//AssertStatNode",
)
@cython.test_fail_if_path_exists(
"//GILStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
)
def assert_in_nogil_section_ustring(int x):
"""
......@@ -111,9 +109,8 @@ def assert_in_nogil_section_ustring(int x):
@cython.test_assert_path_exists(
"//GILStatNode",
"//GILStatNode//AssertStatNode",
)
@cython.test_fail_if_path_exists(
"//GILStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode",
"//GILStatNode//AssertStatNode//GILStatNode//RaiseStatNode",
)
def assert_in_nogil_section_string(int x):
"""
......@@ -126,8 +123,10 @@ def assert_in_nogil_section_string(int x):
assert x, "failed!"
@cython.test_fail_if_path_exists(
"//GILStatNode",
@cython.test_assert_path_exists(
"//AssertStatNode",
"//AssertStatNode//GILStatNode",
"//AssertStatNode//GILStatNode//RaiseStatNode",
)
cpdef int assert_in_nogil_func(int x) nogil except -1:
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment