Commit a0618964 authored by Stefan Behnel's avatar Stefan Behnel

store GIL state in temp variable to safely keep it across generator yields (fixes ticket 803)

--HG--
extra : rebase_source : 5fd93f1cf2916c987bfc4a4d410bf4a77afb2ac9
parent c70e7c55
...@@ -1817,7 +1817,7 @@ class CCodeWriter(object): ...@@ -1817,7 +1817,7 @@ class CCodeWriter(object):
# GIL methods # GIL methods
def put_ensure_gil(self, declare_gilstate=True): def put_ensure_gil(self, declare_gilstate=True, variable=None):
""" """
Acquire the GIL. The generated code is safe even when no PyThreadState Acquire the GIL. The generated code is safe even when no PyThreadState
has been allocated for this thread (for threads not initialized by has been allocated for this thread (for threads not initialized by
...@@ -1827,32 +1827,42 @@ class CCodeWriter(object): ...@@ -1827,32 +1827,42 @@ class CCodeWriter(object):
self.globalstate.use_utility_code( self.globalstate.use_utility_code(
UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c")) UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c"))
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
if declare_gilstate: if not variable:
self.put("PyGILState_STATE ") variable = '__pyx_gilstate_save'
self.putln("__pyx_gilstate_save = PyGILState_Ensure();") if declare_gilstate:
self.put("PyGILState_STATE ")
self.putln("%s = PyGILState_Ensure();" % variable)
self.putln("#endif") self.putln("#endif")
def put_release_ensured_gil(self): def put_release_ensured_gil(self, variable=None):
""" """
Releases the GIL, corresponds to `put_ensure_gil`. Releases the GIL, corresponds to `put_ensure_gil`.
""" """
if not variable:
variable = '__pyx_gilstate_save'
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_Release(__pyx_gilstate_save);") self.putln("PyGILState_Release(%s);" % variable)
self.putln("#endif") self.putln("#endif")
def put_acquire_gil(self): def put_acquire_gil(self, variable=None):
""" """
Acquire the GIL. The thread's thread state must have been initialized Acquire the GIL. The thread's thread state must have been initialized
by a previous `put_release_gil` by a previous `put_release_gil`
""" """
self.putln("#ifdef WITH_THREAD")
if variable:
self.putln('_save = %s;' % variable)
self.putln("Py_BLOCK_THREADS") self.putln("Py_BLOCK_THREADS")
self.putln("#endif")
def put_release_gil(self): def put_release_gil(self, variable=None):
"Release the GIL, corresponds to `put_acquire_gil`." "Release the GIL, corresponds to `put_acquire_gil`."
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
self.putln("PyThreadState *_save = NULL;") self.putln("PyThreadState *_save = NULL;")
self.putln("#endif")
self.putln("Py_UNBLOCK_THREADS") self.putln("Py_UNBLOCK_THREADS")
if variable:
self.putln('%s = _save;' % variable)
self.putln("#endif")
def declare_gilstate(self): def declare_gilstate(self):
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
......
...@@ -6481,9 +6481,16 @@ class GILStatNode(NogilTryFinallyStatNode): ...@@ -6481,9 +6481,16 @@ class GILStatNode(NogilTryFinallyStatNode):
def __init__(self, pos, state, body): def __init__(self, pos, state, body):
self.state = state self.state = state
if state == 'gil':
temp_type = PyrexTypes.c_gilstate_type
else:
temp_type = PyrexTypes.c_threadstate_ptr_type
import ExprNodes
self.state_temp = ExprNodes.TempNode(pos, temp_type)
TryFinallyStatNode.__init__(self, pos, TryFinallyStatNode.__init__(self, pos,
body = body, body=body,
finally_clause = GILExitNode(pos, state = state)) finally_clause=GILExitNode(
pos, state=state, state_temp=self.state_temp))
def analyse_declarations(self, env): def analyse_declarations(self, env):
env._in_with_gil_block = (self.state == 'gil') env._in_with_gil_block = (self.state == 'gil')
...@@ -6504,13 +6511,16 @@ class GILStatNode(NogilTryFinallyStatNode): ...@@ -6504,13 +6511,16 @@ class GILStatNode(NogilTryFinallyStatNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
code.begin_block() code.begin_block()
self.state_temp.allocate(code)
if self.state == 'gil': if self.state == 'gil':
code.put_ensure_gil() code.put_ensure_gil(variable=self.state_temp.result())
else: else:
code.put_release_gil() code.put_release_gil(variable=self.state_temp.result())
TryFinallyStatNode.generate_execution_code(self, code) TryFinallyStatNode.generate_execution_code(self, code)
self.state_temp.release(code)
code.end_block() code.end_block()
...@@ -6522,15 +6532,21 @@ class GILExitNode(StatNode): ...@@ -6522,15 +6532,21 @@ class GILExitNode(StatNode):
""" """
child_attrs = [] child_attrs = []
state_temp = None
def analyse_expressions(self, env): def analyse_expressions(self, env):
return self return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.state_temp:
variable = self.state_temp.result()
else:
variable = None
if self.state == 'gil': if self.state == 'gil':
code.put_release_ensured_gil() code.put_release_ensured_gil(variable)
else: else:
code.put_acquire_gil() code.put_acquire_gil(variable)
class EnsureGILNode(GILExitNode): class EnsureGILNode(GILExitNode):
......
...@@ -3453,6 +3453,10 @@ c_py_ssize_t_ptr_type = CPtrType(c_py_ssize_t_type) ...@@ -3453,6 +3453,10 @@ c_py_ssize_t_ptr_type = CPtrType(c_py_ssize_t_type)
c_ssize_t_ptr_type = CPtrType(c_ssize_t_type) c_ssize_t_ptr_type = CPtrType(c_ssize_t_type)
c_size_t_ptr_type = CPtrType(c_size_t_type) c_size_t_ptr_type = CPtrType(c_size_t_type)
# GIL state
c_gilstate_type = CEnumType("PyGILState_STATE", "PyGILState_STATE", True)
c_threadstate_type = CStructOrUnionType("PyThreadState", "struct", None, 1, "PyThreadState")
c_threadstate_ptr_type = CPtrType(c_threadstate_type)
# the Py_buffer type is defined in Builtin.py # the Py_buffer type is defined in Builtin.py
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer") c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
......
...@@ -318,3 +318,30 @@ def test_lambda(n): ...@@ -318,3 +318,30 @@ def test_lambda(n):
""" """
for i in range(n): for i in range(n):
yield lambda : i yield lambda : i
def test_with_gil_section():
"""
>>> list(test_with_gil_section())
[0, 1, 2]
"""
cdef int i
with nogil:
for i in range(3):
with gil:
yield i
def test_double_with_gil_section():
"""
>>> list(test_double_with_gil_section())
[0, 1, 2, 3]
"""
cdef int i,j
with nogil:
for i in range(2):
with gil:
with nogil:
for j in range(2):
with gil:
yield i*2+j
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment