Commit 7b4079c5 authored by Stefan Behnel's avatar Stefan Behnel

Check some child nodes against the correct nogil context when they are...

Check some child nodes against the correct nogil context when they are actually being evaluated in the outer scope (e.g default arguments or annotations of a nogil function).
parent 04114b9b
......@@ -207,6 +207,9 @@ class Node(object):
# can either contain a single node or a list of nodes. See Visitor.py.
child_attrs = None
# Subset of attributes that are evaluated in the outer scope (e.g. function default arguments).
outer_attrs = None
cf_state = None
# This may be an additional (or 'actual') type that will be checked when
......@@ -222,6 +225,7 @@ class Node(object):
gil_message = "Operation"
nogil_check = None
in_nogil_context = False # For use only during code generation.
def gil_error(self, env=None):
error(self.pos, "%s not allowed without gil" % self.gil_message)
......@@ -848,6 +852,7 @@ class CArgDeclNode(Node):
# is_dynamic boolean Non-literal arg stored inside CyFunction
child_attrs = ["base_type", "declarator", "default", "annotation"]
outer_attrs = ["default", "annotation"]
is_self_arg = 0
is_type_arg = 0
......@@ -1680,10 +1685,6 @@ class FuncDefNode(StatNode, BlockNode):
return None
if not env.directives['annotation_typing'] or annotation.analyse_as_type(env) is None:
annotation = annotation.analyse_types(env)
elif isinstance(self, CFuncDefNode):
# Discard invisible type annotations from cdef functions after applying them,
# as they might get in the way of @nogil declarations etc.
return None
return annotation
def analyse_annotations(self, env):
......@@ -2741,6 +2742,7 @@ class DefNode(FuncDefNode):
# decorator_indirection IndirectionNode Used to remove __Pyx_Method_ClassMethod for fused functions
child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators", "return_type_annotation"]
outer_attrs = ["decorators", "return_type_annotation"]
is_staticmethod = False
is_classmethod = False
......
......@@ -2861,24 +2861,33 @@ class GilCheck(VisitorTransform):
self.nogil_declarator_only = False
return super(GilCheck, self).__call__(root)
def _visit_scoped_children(self, node, gil_state):
was_nogil = self.nogil
outer_attrs = node.outer_attrs
if outer_attrs and len(self.env_stack) > 1:
self.nogil = self.env_stack[-2].nogil
self.visitchildren(node, outer_attrs)
self.nogil = gil_state
self.visitchildren(node, exclude=outer_attrs)
self.nogil = was_nogil
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
was_nogil = self.nogil
self.nogil = node.local_scope.nogil
inner_nogil = node.local_scope.nogil
if self.nogil:
if inner_nogil:
self.nogil_declarator_only = True
if self.nogil and node.nogil_check:
if inner_nogil and node.nogil_check:
node.nogil_check(node.local_scope)
self.visitchildren(node)
self._visit_scoped_children(node, inner_nogil)
# This cannot be nested, so it doesn't need backup/restore
self.nogil_declarator_only = False
self.env_stack.pop()
self.nogil = was_nogil
return node
def visit_GILStatNode(self, node):
......@@ -2886,9 +2895,9 @@ class GilCheck(VisitorTransform):
node.nogil_check()
was_nogil = self.nogil
self.nogil = (node.state == 'nogil')
is_nogil = (node.state == 'nogil')
if was_nogil == self.nogil and not self.nogil_declarator_only:
if was_nogil == is_nogil and not self.nogil_declarator_only:
if not was_nogil:
error(node.pos, "Trying to acquire the GIL while it is "
"already held.")
......@@ -2901,8 +2910,7 @@ class GilCheck(VisitorTransform):
# which is wrapped in a StatListNode. Just unpack that.
node.finally_clause, = node.finally_clause.stats
self.visitchildren(node)
self.nogil = was_nogil
self._visit_scoped_children(node, is_nogil)
return node
def visit_ParallelRangeNode(self, node):
......@@ -2949,8 +2957,12 @@ class GilCheck(VisitorTransform):
def visit_Node(self, node):
if self.env_stack and self.nogil and node.nogil_check:
node.nogil_check(self.env_stack[-1])
self.visitchildren(node)
node.in_nogil_context = self.nogil
if node.outer_attrs:
self._visit_scoped_children(node, self.nogil)
else:
self.visitchildren(node)
if self.nogil:
node.in_nogil_context = True
return node
......
......@@ -16,8 +16,9 @@ cdef class TreeVisitor:
cdef class VisitorTransform(TreeVisitor):
cdef dict _process_children(self, parent, attrs=*)
cpdef visitchildren(self, parent, attrs=*)
cpdef visitchildren(self, parent, attrs=*, exclude=*)
cdef list _flatten_list(self, list orig_list)
cdef list _select_attrs(self, attrs, exclude)
cdef class CythonTransform(VisitorTransform):
cdef public context
......
......@@ -244,10 +244,16 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.)
"""
def visitchildren(self, parent, attrs=None):
def visitchildren(self, parent, attrs=None, exclude=None):
# generic def entry point for calls from Python subclasses
if exclude is not None:
attrs = self._select_attrs(parent.child_attrs if attrs is None else attrs, exclude)
return self._process_children(parent, attrs)
@cython.final
def _select_attrs(self, attrs, exclude):
return [name for name in attrs if name not in exclude]
@cython.final
def _process_children(self, parent, attrs=None):
# fast cdef entry point for calls from Cython subclasses
......
......@@ -34,3 +34,21 @@ def two_dim(a: cython.double[:,:]):
"""
a[0,0] *= 3
return a[0,0], a[0,1], a.ndim
@cython.nogil
@cython.cfunc
def _one_dim_nogil_cfunc(a: cython.double[:]) -> cython.double:
a[0] *= 2
return a[0]
def one_dim_nogil_cfunc(a: cython.double[:]):
"""
>>> a = numpy.ones((10,), numpy.double)
>>> one_dim_nogil_cfunc(a)
2.0
"""
with cython.nogil:
result = _one_dim_nogil_cfunc(a)
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment