Commit a0618964 authored by Stefan Behnel's avatar Stefan Behnel

store GIL state in temp variable to safely keep it across generator yields (fixes ticket 803)

--HG--
extra : rebase_source : 5fd93f1cf2916c987bfc4a4d410bf4a77afb2ac9
parent c70e7c55
......@@ -1817,7 +1817,7 @@ class CCodeWriter(object):
# GIL methods
def put_ensure_gil(self, declare_gilstate=True):
def put_ensure_gil(self, declare_gilstate=True, variable=None):
"""
Acquire the GIL. The generated code is safe even when no PyThreadState
has been allocated for this thread (for threads not initialized by
......@@ -1827,32 +1827,42 @@ class CCodeWriter(object):
self.globalstate.use_utility_code(
UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c"))
self.putln("#ifdef WITH_THREAD")
if not variable:
variable = '__pyx_gilstate_save'
if declare_gilstate:
self.put("PyGILState_STATE ")
self.putln("__pyx_gilstate_save = PyGILState_Ensure();")
self.putln("%s = PyGILState_Ensure();" % variable)
self.putln("#endif")
def put_release_ensured_gil(self):
def put_release_ensured_gil(self, variable=None):
"""
Releases the GIL, corresponds to `put_ensure_gil`.
"""
if not variable:
variable = '__pyx_gilstate_save'
self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_Release(__pyx_gilstate_save);")
self.putln("PyGILState_Release(%s);" % variable)
self.putln("#endif")
def put_acquire_gil(self):
def put_acquire_gil(self, variable=None):
"""
Acquire the GIL. The thread's thread state must have been initialized
by a previous `put_release_gil`
"""
self.putln("#ifdef WITH_THREAD")
if variable:
self.putln('_save = %s;' % variable)
self.putln("Py_BLOCK_THREADS")
self.putln("#endif")
def put_release_gil(self):
def put_release_gil(self, variable=None):
"Release the GIL, corresponds to `put_acquire_gil`."
self.putln("#ifdef WITH_THREAD")
self.putln("PyThreadState *_save = NULL;")
self.putln("#endif")
self.putln("Py_UNBLOCK_THREADS")
if variable:
self.putln('%s = _save;' % variable)
self.putln("#endif")
def declare_gilstate(self):
self.putln("#ifdef WITH_THREAD")
......
......@@ -6481,9 +6481,16 @@ class GILStatNode(NogilTryFinallyStatNode):
def __init__(self, pos, state, body):
self.state = state
if state == 'gil':
temp_type = PyrexTypes.c_gilstate_type
else:
temp_type = PyrexTypes.c_threadstate_ptr_type
import ExprNodes
self.state_temp = ExprNodes.TempNode(pos, temp_type)
TryFinallyStatNode.__init__(self, pos,
body = body,
finally_clause = GILExitNode(pos, state = state))
body=body,
finally_clause=GILExitNode(
pos, state=state, state_temp=self.state_temp))
def analyse_declarations(self, env):
env._in_with_gil_block = (self.state == 'gil')
......@@ -6504,13 +6511,16 @@ class GILStatNode(NogilTryFinallyStatNode):
def generate_execution_code(self, code):
code.mark_pos(self.pos)
code.begin_block()
self.state_temp.allocate(code)
if self.state == 'gil':
code.put_ensure_gil()
code.put_ensure_gil(variable=self.state_temp.result())
else:
code.put_release_gil()
code.put_release_gil(variable=self.state_temp.result())
TryFinallyStatNode.generate_execution_code(self, code)
self.state_temp.release(code)
code.end_block()
......@@ -6522,15 +6532,21 @@ class GILExitNode(StatNode):
"""
child_attrs = []
state_temp = None
def analyse_expressions(self, env):
return self
def generate_execution_code(self, code):
if self.state_temp:
variable = self.state_temp.result()
else:
variable = None
if self.state == 'gil':
code.put_release_ensured_gil()
code.put_release_ensured_gil(variable)
else:
code.put_acquire_gil()
code.put_acquire_gil(variable)
class EnsureGILNode(GILExitNode):
......
......@@ -3453,6 +3453,10 @@ c_py_ssize_t_ptr_type = CPtrType(c_py_ssize_t_type)
c_ssize_t_ptr_type = CPtrType(c_ssize_t_type)
c_size_t_ptr_type = CPtrType(c_size_t_type)
# GIL state
c_gilstate_type = CEnumType("PyGILState_STATE", "PyGILState_STATE", True)
c_threadstate_type = CStructOrUnionType("PyThreadState", "struct", None, 1, "PyThreadState")
c_threadstate_ptr_type = CPtrType(c_threadstate_type)
# the Py_buffer type is defined in Builtin.py
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
......
......@@ -318,3 +318,30 @@ def test_lambda(n):
"""
for i in range(n):
yield lambda : i
def test_with_gil_section():
"""
>>> list(test_with_gil_section())
[0, 1, 2]
"""
cdef int i
with nogil:
for i in range(3):
with gil:
yield i
def test_double_with_gil_section():
"""
>>> list(test_double_with_gil_section())
[0, 1, 2, 3]
"""
cdef int i,j
with nogil:
for i in range(2):
with gil:
with nogil:
for j in range(2):
with gil:
yield i*2+j
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment