Commit 5db23f0e authored by Mark Florisson's avatar Mark Florisson

Nogil try/finally around 'with gil' blocks

parent f0cc8c42
...@@ -1423,6 +1423,11 @@ class CCodeWriter(object): ...@@ -1423,6 +1423,11 @@ class CCodeWriter(object):
self.putln("#endif") self.putln("#endif")
self.putln("Py_UNBLOCK_THREADS") self.putln("Py_UNBLOCK_THREADS")
def declare_gilstate(self):
self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_STATE _save;")
self.putln("#endif")
# error handling # error handling
def put_error_if_neg(self, pos, value): def put_error_if_neg(self, pos, value):
......
...@@ -5373,6 +5373,8 @@ class TryFinallyStatNode(StatNode): ...@@ -5373,6 +5373,8 @@ class TryFinallyStatNode(StatNode):
# continue in the try block, since we have no problem # continue in the try block, since we have no problem
# handling it. # handling it.
is_try_finally_in_nogil = False
def create_analysed(pos, env, body, finally_clause): def create_analysed(pos, env, body, finally_clause):
node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause) node = TryFinallyStatNode(pos, body=body, finally_clause=finally_clause)
return node return node
...@@ -5404,20 +5406,24 @@ class TryFinallyStatNode(StatNode): ...@@ -5404,20 +5406,24 @@ class TryFinallyStatNode(StatNode):
if not self.handle_error_case: if not self.handle_error_case:
code.error_label = old_error_label code.error_label = old_error_label
catch_label = code.new_label() catch_label = code.new_label()
code.putln(
"/*try:*/ {") code.putln("/*try:*/ {")
if self.disallow_continue_in_try_finally: if self.disallow_continue_in_try_finally:
was_in_try_finally = code.funcstate.in_try_finally was_in_try_finally = code.funcstate.in_try_finally
code.funcstate.in_try_finally = 1 code.funcstate.in_try_finally = 1
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if self.disallow_continue_in_try_finally: if self.disallow_continue_in_try_finally:
code.funcstate.in_try_finally = was_in_try_finally code.funcstate.in_try_finally = was_in_try_finally
code.putln(
"}") code.putln("}")
temps_to_clean_up = code.funcstate.all_free_managed_temps() temps_to_clean_up = code.funcstate.all_free_managed_temps()
code.mark_pos(self.finally_clause.pos) code.mark_pos(self.finally_clause.pos)
code.putln( code.putln("/*finally:*/ {")
"/*finally:*/ {")
cases_used = [] cases_used = []
error_label_used = 0 error_label_used = 0
for i, new_label in enumerate(new_labels): for i, new_label in enumerate(new_labels):
...@@ -5426,22 +5432,25 @@ class TryFinallyStatNode(StatNode): ...@@ -5426,22 +5432,25 @@ class TryFinallyStatNode(StatNode):
if new_label == new_error_label: if new_label == new_error_label:
error_label_used = 1 error_label_used = 1
error_label_case = i error_label_case = i
if cases_used: if cases_used:
code.putln( code.putln("int __pyx_why;")
"int __pyx_why;")
if error_label_used and self.preserve_exception: if error_label_used and self.preserve_exception:
code.putln( code.putln("PyObject *%s, *%s, *%s;" % Naming.exc_vars)
"PyObject *%s, *%s, *%s;" % Naming.exc_vars) code.putln("int %s;" % Naming.exc_lineno_name)
code.putln( exc_var_init_zero = ''.join(
"int %s;" % Naming.exc_lineno_name) ["%s = 0; " % var for var in Naming.exc_vars])
exc_var_init_zero = ''.join(["%s = 0; " % var for var in Naming.exc_vars])
exc_var_init_zero += '%s = 0;' % Naming.exc_lineno_name exc_var_init_zero += '%s = 0;' % Naming.exc_lineno_name
code.putln(exc_var_init_zero) code.putln(exc_var_init_zero)
if self.is_try_finally_in_nogil:
code.declare_gilstate()
else: else:
exc_var_init_zero = None exc_var_init_zero = None
code.use_label(catch_label) code.use_label(catch_label)
code.putln( code.putln("__pyx_why = 0; goto %s;" % catch_label)
"__pyx_why = 0; goto %s;" % catch_label)
for i in cases_used: for i in cases_used:
new_label = new_labels[i] new_label = new_labels[i]
#if new_label and new_label != "<try>": #if new_label and new_label != "<try>":
...@@ -5452,19 +5461,20 @@ class TryFinallyStatNode(StatNode): ...@@ -5452,19 +5461,20 @@ class TryFinallyStatNode(StatNode):
code.put('%s: ' % new_label) code.put('%s: ' % new_label)
if exc_var_init_zero: if exc_var_init_zero:
code.putln(exc_var_init_zero) code.putln(exc_var_init_zero)
code.putln("__pyx_why = %s; goto %s;" % ( code.putln("__pyx_why = %s; goto %s;" % (i+1, catch_label))
i+1,
catch_label))
code.put_label(catch_label) code.put_label(catch_label)
code.set_all_labels(old_labels) code.set_all_labels(old_labels)
if error_label_used: if error_label_used:
code.new_error_label() code.new_error_label()
finally_error_label = code.error_label finally_error_label = code.error_label
self.finally_clause.generate_execution_code(code) self.finally_clause.generate_execution_code(code)
if error_label_used: if error_label_used:
if finally_error_label in code.labels_used and self.preserve_exception: if finally_error_label in code.labels_used and self.preserve_exception:
over_label = code.new_label() over_label = code.new_label()
code.put_goto(over_label); code.put_goto(over_label)
code.put_label(finally_error_label) code.put_label(finally_error_label)
code.putln("if (__pyx_why == %d) {" % (error_label_case + 1)) code.putln("if (__pyx_why == %d) {" % (error_label_case + 1))
for var in Naming.exc_vars: for var in Naming.exc_vars:
...@@ -5472,81 +5482,87 @@ class TryFinallyStatNode(StatNode): ...@@ -5472,81 +5482,87 @@ class TryFinallyStatNode(StatNode):
code.putln("}") code.putln("}")
code.put_goto(old_error_label) code.put_goto(old_error_label)
code.put_label(over_label) code.put_label(over_label)
code.error_label = old_error_label code.error_label = old_error_label
if cases_used: if cases_used:
code.putln( code.putln("switch (__pyx_why) {")
"switch (__pyx_why) {")
for i in cases_used: for i in cases_used:
old_label = old_labels[i] old_label = old_labels[i]
if old_label == old_error_label and self.preserve_exception: if old_label == old_error_label and self.preserve_exception:
self.put_error_uncatcher(code, i+1, old_error_label) self.put_error_uncatcher(code, i+1, old_error_label)
else: else:
code.use_label(old_label) code.use_label(old_label)
code.putln( code.putln("case %s: goto %s;" % (i+1, old_label))
"case %s: goto %s;" % (
i+1, code.putln("}")
old_label)) code.putln("}")
code.putln(
"}")
code.putln(
"}")
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code) self.body.generate_function_definitions(env, code)
self.finally_clause.generate_function_definitions(env, code) self.finally_clause.generate_function_definitions(env, code)
def put_error_catcher(self, code, error_label, i, catch_label, temps_to_clean_up): def put_error_catcher(self, code, error_label, i, catch_label,
temps_to_clean_up):
code.globalstate.use_utility_code(restore_exception_utility_code) code.globalstate.use_utility_code(restore_exception_utility_code)
code.putln( code.putln("%s: {" % error_label)
"%s: {" % code.putln("__pyx_why = %s;" % i)
error_label)
code.putln( if self.is_try_finally_in_nogil:
"__pyx_why = %s;" % code.put_ensure_gil(declare_gilstate=False)
i)
for temp_name, type in temps_to_clean_up: for temp_name, type in temps_to_clean_up:
code.put_xdecref_clear(temp_name, type) code.put_xdecref_clear(temp_name, type)
code.putln(
"__Pyx_ErrFetch(&%s, &%s, &%s);" % code.putln("__Pyx_ErrFetch(&%s, &%s, &%s);" % Naming.exc_vars)
Naming.exc_vars) code.putln("%s = %s;" % (Naming.exc_lineno_name, Naming.lineno_cname))
code.putln(
"%s = %s;" % ( if self.is_try_finally_in_nogil:
Naming.exc_lineno_name, Naming.lineno_cname)) code.put_release_ensured_gil()
code.put_goto(catch_label) code.put_goto(catch_label)
code.putln("}") code.putln("}")
def put_error_uncatcher(self, code, i, error_label): def put_error_uncatcher(self, code, i, error_label):
code.globalstate.use_utility_code(restore_exception_utility_code) code.globalstate.use_utility_code(restore_exception_utility_code)
code.putln( code.putln("case %s: {" % i)
"case %s: {" %
i) if self.is_try_finally_in_nogil:
code.putln( code.put_ensure_gil(declare_gilstate=False)
"__Pyx_ErrRestore(%s, %s, %s);" %
Naming.exc_vars) code.putln("__Pyx_ErrRestore(%s, %s, %s);" % Naming.exc_vars)
code.putln( code.putln("%s = %s;" % (Naming.lineno_cname, Naming.exc_lineno_name))
"%s = %s;" % (
Naming.lineno_cname, Naming.exc_lineno_name)) if self.is_try_finally_in_nogil:
code.put_release_ensured_gil()
for var in Naming.exc_vars: for var in Naming.exc_vars:
code.putln( code.putln("%s = 0;" % var)
"%s = 0;" %
var)
code.put_goto(error_label) code.put_goto(error_label)
code.putln( code.putln("}")
"}")
def annotate(self, code): def annotate(self, code):
self.body.annotate(code) self.body.annotate(code)
self.finally_clause.annotate(code) self.finally_clause.annotate(code)
class GILStatNode(TryFinallyStatNode): class NogilTryFinallyStatNode(TryFinallyStatNode):
"""
A try/finally statement that may be used in nogil code sections.
"""
preserve_exception = False
nogil_check = None
class GILStatNode(NogilTryFinallyStatNode):
# 'with gil' or 'with nogil' statement # 'with gil' or 'with nogil' statement
# #
# state string 'gil' or 'nogil' # state string 'gil' or 'nogil'
# child_attrs = [] # child_attrs = []
preserve_exception = 0
def __init__(self, pos, state, body): def __init__(self, pos, state, body):
self.state = state self.state = state
TryFinallyStatNode.__init__(self, pos, TryFinallyStatNode.__init__(self, pos,
...@@ -5557,6 +5573,7 @@ class GILStatNode(TryFinallyStatNode): ...@@ -5557,6 +5573,7 @@ class GILStatNode(TryFinallyStatNode):
env._in_with_gil_block = (self.state == 'gil') env._in_with_gil_block = (self.state == 'gil')
if self.state == 'gil': if self.state == 'gil':
env.has_with_gil_block = True env.has_with_gil_block = True
return super(GILStatNode, self).analyse_declarations(env) return super(GILStatNode, self).analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -5566,11 +5583,10 @@ class GILStatNode(TryFinallyStatNode): ...@@ -5566,11 +5583,10 @@ class GILStatNode(TryFinallyStatNode):
TryFinallyStatNode.analyse_expressions(self, env) TryFinallyStatNode.analyse_expressions(self, env)
env.nogil = was_nogil env.nogil = was_nogil
nogil_check = None
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
code.begin_block() code.begin_block()
if self.state == 'gil': if self.state == 'gil':
code.put_ensure_gil() code.put_ensure_gil()
else: else:
...@@ -5581,9 +5597,11 @@ class GILStatNode(TryFinallyStatNode): ...@@ -5581,9 +5597,11 @@ class GILStatNode(TryFinallyStatNode):
class GILExitNode(StatNode): class GILExitNode(StatNode):
# Used as the 'finally' block in a GILStatNode """
# Used as the 'finally' block in a GILStatNode
# state string 'gil' or 'nogil'
state string 'gil' or 'nogil'
"""
child_attrs = [] child_attrs = []
......
...@@ -1365,15 +1365,14 @@ if VALUE is not None: ...@@ -1365,15 +1365,14 @@ if VALUE is not None:
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
if lenv.nogil and lenv.has_with_gil_block: if lenv.nogil and lenv.has_with_gil_block:
# Acquire the GIL for cleanup in 'nogil' functions. The # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
# corresponding release will be taken care of by # the entire function body in try/finally.
# The corresponding release will be taken care of by
# Nodes.FuncDefNode.generate_function_definitions() # Nodes.FuncDefNode.generate_function_definitions()
node.body = Nodes.TryFinallyStatNode( node.body = Nodes.NogilTryFinallyStatNode(
node.body.pos, node.body.pos,
body = node.body, body = node.body,
finally_clause = Nodes.EnsureGILNode(node.body.pos), finally_clause = Nodes.EnsureGILNode(node.body.pos),
preserve_exception = False,
nogil_check = None,
) )
self.env_stack.append(lenv) self.env_stack.append(lenv)
...@@ -1547,6 +1546,7 @@ if VALUE is not None: ...@@ -1547,6 +1546,7 @@ if VALUE is not None:
# --------------------------------------- # ---------------------------------------
return property return property
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
...@@ -1568,6 +1568,7 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1568,6 +1568,7 @@ class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class ExpandInplaceOperators(EnvTransform): class ExpandInplaceOperators(EnvTransform):
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
...@@ -1942,7 +1943,11 @@ class GilCheck(VisitorTransform): ...@@ -1942,7 +1943,11 @@ class GilCheck(VisitorTransform):
Call `node.gil_check(env)` on each node to make sure we hold the Call `node.gil_check(env)` on each node to make sure we hold the
GIL when we need it. Raise an error when on Python operations GIL when we need it. Raise an error when on Python operations
inside a `nogil` environment. inside a `nogil` environment.
Additionally, raise exceptions for closely nested with gil or with nogil
statements. The latter would abort Python.
""" """
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
self.nogil = False self.nogil = False
...@@ -1987,6 +1992,14 @@ class GilCheck(VisitorTransform): ...@@ -1987,6 +1992,14 @@ class GilCheck(VisitorTransform):
error(node.pos, "Trying to release the GIL while it was " error(node.pos, "Trying to release the GIL while it was "
"previously released.") "previously released.")
if isinstance(node.finally_clause, Nodes.StatListNode):
# The finally clause of the GILStatNode is a GILExitNode,
# which is wrapped in a StatListNode. Just unpack that.
node.finally_clause, = node.finally_clause.stats
if node.state == 'gil':
self.seen_with_gil_statement = True
self.visitchildren(node) self.visitchildren(node)
self.nogil = was_nogil self.nogil = was_nogil
return node return node
...@@ -2018,6 +2031,27 @@ class GilCheck(VisitorTransform): ...@@ -2018,6 +2031,27 @@ class GilCheck(VisitorTransform):
node.nogil_check(self.env_stack[-1]) node.nogil_check(self.env_stack[-1])
self.visitchildren(node) self.visitchildren(node)
def visit_TryFinallyStatNode(self, node):
"""
Take care of try/finally statements in nogil code sections. The
'try' must contain a 'with gil:' statement somewhere.
"""
if not self.nogil or isinstance(node, Nodes.GILStatNode):
return self.visit_Node(node)
node.nogil_check = None
node.is_try_finally_in_nogil = True
# First, visit the body and check for errors
self.seen_with_gil_statement = False
self.visitchildren(node.body)
if not self.seen_with_gil_statement:
error(node.pos, "Cannot use try/finally in nogil sections unless "
"it contains a 'with gil' statement.")
self.visitchildren(node.finally_clause)
return node return node
def visit_Node(self, node): def visit_Node(self, node):
......
...@@ -123,6 +123,22 @@ def test_try_finally_and_outer_except(): ...@@ -123,6 +123,22 @@ def test_try_finally_and_outer_except():
print "End of function" print "End of function"
def test_restore_exception():
"""
>>> test_restore_exception()
Traceback (most recent call last):
...
Exception: Override the raised exception
"""
with nogil:
with gil:
try:
with nogil:
with gil:
raise Exception("Override this please")
finally:
raise Exception("Override the raised exception")
def test_declared_variables(): def test_declared_variables():
""" """
>>> test_declared_variables() >>> test_declared_variables()
...@@ -299,3 +315,75 @@ def test_release_gil_call_gil_func(): ...@@ -299,3 +315,75 @@ def test_release_gil_call_gil_func():
with nogil: with nogil:
with gil: with gil:
with_gil_raise() with_gil_raise()
# Test try/finally in nogil blocks
def test_try_finally_in_nogil():
"""
>>> test_try_finally_in_nogil()
Traceback (most recent call last):
...
Exception: Override exception!
"""
with nogil:
try:
with gil:
raise Exception("This will be overridden")
finally:
with gil:
raise Exception("Override exception!")
with gil:
raise Exception("This code should not be executed!")
def test_nogil_try_finally_no_exception():
"""
>>> test_nogil_try_finally_no_exception()
first nogil try
nogil try gil
second nogil try
nogil finally
------
First with gil block
Second with gil block
finally block
"""
with nogil:
try:
puts("first nogil try")
with gil:
print "nogil try gil"
puts("second nogil try")
finally:
puts("nogil finally")
print '------'
with nogil:
try:
with gil:
print "First with gil block"
with gil:
print "Second with gil block"
finally:
puts("finally block")
def test_nogil_try_finally_propagate_exception():
"""
>>> test_nogil_try_finally_propagate_exception()
Execute finally clause
Propagate this!
"""
try:
with nogil:
try:
with gil:
raise Exception("Propagate this!")
with gil:
raise Exception("Don't reach this section!")
finally:
puts("Execute finally clause")
except Exception, e:
print e
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment