Commit 40417d36 authored by Mark Florisson

Allocate thread-state for each OpenMP thread

Save exc_info in parallel section and re-raise outside in master thread
parent 04a8c7fe
...@@ -1400,7 +1400,7 @@ class CCodeWriter(object): ...@@ -1400,7 +1400,7 @@ class CCodeWriter(object):
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
if declare_gilstate: if declare_gilstate:
self.put("PyGILState_STATE ") self.put("PyGILState_STATE ")
self.putln("_save = PyGILState_Ensure();") self.putln("__pyx_gilstate_save = PyGILState_Ensure();")
self.putln("#endif") self.putln("#endif")
def put_release_ensured_gil(self): def put_release_ensured_gil(self):
...@@ -1408,7 +1408,7 @@ class CCodeWriter(object): ...@@ -1408,7 +1408,7 @@ class CCodeWriter(object):
Releases the GIL, corresponds to `put_ensure_gil`. Releases the GIL, corresponds to `put_ensure_gil`.
""" """
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_Release(_save);") self.putln("PyGILState_Release(__pyx_gilstate_save);")
self.putln("#endif") self.putln("#endif")
def put_acquire_gil(self): def put_acquire_gil(self):
...@@ -1427,7 +1427,7 @@ class CCodeWriter(object): ...@@ -1427,7 +1427,7 @@ class CCodeWriter(object):
def declare_gilstate(self): def declare_gilstate(self):
self.putln("#ifdef WITH_THREAD") self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_STATE _save;") self.putln("PyGILState_STATE __pyx_gilstate_save;")
self.putln("#endif") self.putln("#endif")
# error handling # error handling
......
...@@ -108,6 +108,14 @@ exc_value_name = pyrex_prefix + "exc_value" ...@@ -108,6 +108,14 @@ exc_value_name = pyrex_prefix + "exc_value"
exc_tb_name = pyrex_prefix + "exc_tb" exc_tb_name = pyrex_prefix + "exc_tb"
exc_lineno_name = pyrex_prefix + "exc_lineno" exc_lineno_name = pyrex_prefix + "exc_lineno"
parallel_exc_type = pyrex_prefix + "parallel_exc_type"
parallel_exc_value = pyrex_prefix + "parallel_exc_value"
parallel_exc_tb = pyrex_prefix + "parallel_exc_tb"
parallel_filename = pyrex_prefix + "parallel_filename"
parallel_lineno = pyrex_prefix + "parallel_lineno"
parallel_clineno = pyrex_prefix + "parallel_clineno"
parallel_why = pyrex_prefix + "parallel_why"
exc_vars = (exc_type_name, exc_value_name, exc_tb_name) exc_vars = (exc_type_name, exc_value_name, exc_tb_name)
api_name = pyrex_prefix + "capi__" api_name = pyrex_prefix + "capi__"
......
...@@ -26,6 +26,7 @@ from Code import UtilityCode, ClosureTempAllocator ...@@ -26,6 +26,7 @@ from Code import UtilityCode, ClosureTempAllocator
from StringEncoding import EncodedString, escape_byte_string, split_string_literal from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options import Options
import DebugFlags import DebugFlags
from itertools import chain
absolute_path_length = 0 absolute_path_length = 0
...@@ -5845,12 +5846,25 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5845,12 +5846,25 @@ class ParallelStatNode(StatNode, ParallelNode):
is_prange = False is_prange = False
# Labels to break out of parallel constructs
break_label_used = False
continue_label_used = False
error_label_used = False error_label_used = False
return_label_used = False
parallel_exc = (
Naming.parallel_exc_type,
Naming.parallel_exc_value,
Naming.parallel_exc_tb,
)
parallel_pos_info = (
Naming.parallel_filename,
Naming.parallel_lineno,
Naming.parallel_clineno,
)
pos_info = (
Naming.filename_cname,
Naming.lineno_cname,
Naming.clineno_cname,
)
def __init__(self, pos, **kwargs): def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs) super(ParallelStatNode, self).__init__(pos, **kwargs)
...@@ -6094,11 +6108,17 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6094,11 +6108,17 @@ class ParallelStatNode(StatNode, ParallelNode):
if not type.is_pyobject: if not type.is_pyobject:
temps.append(cname) temps.append(cname)
c = self.privatization_insertion_point
if temps: if temps:
c = self.privatization_insertion_point
c.put(" private(%s)" % ", ".join(temps)) c.put(" private(%s)" % ", ".join(temps))
if self.breaking_label_used:
c.put(" shared(__pyx_parallel_why)") if self.breaking_label_used:
shared_vars = [Naming.parallel_why]
if self.error_label_used:
shared_vars.extend(self.parallel_exc)
c.put(" private(%s, %s, %s)" % self.pos_info)
c.put(" shared(%s)" % ', '.join(shared_vars))
def setup_parallel_control_flow_block(self, code): def setup_parallel_control_flow_block(self, code):
""" """
...@@ -6134,6 +6154,29 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6134,6 +6154,29 @@ class ParallelStatNode(StatNode, ParallelNode):
code.begin_block() # parallel control flow block code.begin_block() # parallel control flow block
self.begin_of_parallel_control_block_point = code.insertion_point() self.begin_of_parallel_control_block_point = code.insertion_point()
def begin_parallel_block(self, code):
    """
    Record the position in the generated code where the parallel block
    starts.

    Every OpenMP thread running a parallel section that contains a
    'with gil' block needs an initialized thread-state, and the matching
    PyGILState_Release() deallocates it again. Without a thread-state that
    spans the whole section, each 'with gil' block would allocate and
    deallocate its own, and the exception information would be lost before
    it could be saved on the way out of the parallel section.

    Only the insertion point is stored here; end_parallel_block() later
    writes the actual GIL handling code into it when it is needed.
    """
    block_start = code.insertion_point()
    self.begin_of_parallel_block = block_start
def end_parallel_block(self, code):
    """
    Emit the per-thread GIL handling that brackets a parallel block.

    Nothing is generated unless the error label was used inside the block.
    Otherwise the insertion point saved by begin_parallel_block() receives
    the code that acquires the GIL (allocating the thread-state) and then
    immediately allows other threads to run, while `code` (the end of the
    block) receives the code that re-acquires the GIL and releases it,
    deallocating the thread-state.
    """
    if not self.error_label_used:
        return

    block_start = self.begin_of_parallel_block

    # Written at the start of the parallel block.
    block_start.put_ensure_gil(declare_gilstate=True)
    block_start.putln("Py_BEGIN_ALLOW_THREADS")

    # Written at the current position, i.e. the end of the parallel block.
    code.putln("Py_END_ALLOW_THREADS")
    code.put_release_ensured_gil()
def trap_parallel_exit(self, code, should_flush=False): def trap_parallel_exit(self, code, should_flush=False):
""" """
Trap any kind of return inside a parallel construct. 'should_flush' Trap any kind of return inside a parallel construct. 'should_flush'
...@@ -6153,6 +6196,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6153,6 +6196,7 @@ class ParallelStatNode(StatNode, ParallelNode):
self.any_label_used = False self.any_label_used = False
self.breaking_label_used = False self.breaking_label_used = False
self.error_label_used = False
for i, label in enumerate(code.get_all_labels()): for i, label in enumerate(code.get_all_labels()):
if code.label_used(label): if code.label_used(label):
...@@ -6164,7 +6208,11 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6164,7 +6208,11 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put_label(label) code.put_label(label)
if not (should_flush and is_continue_label): if not (should_flush and is_continue_label):
code.putln("__pyx_parallel_why = %d;" % (i + 1)) if label == code.error_label:
self.error_label_used = True
self.fetch_parallel_exception(code)
code.putln("%s = %d;" % (Naming.parallel_why, i + 1))
code.put_goto(dont_return_label) code.put_goto(dont_return_label)
...@@ -6174,7 +6222,65 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6174,7 +6222,65 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put_label(dont_return_label) code.put_label(dont_return_label)
if should_flush and self.breaking_label_used: if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(__pyx_parallel_why)") code.putln_openmp("#pragma omp flush(%s)" % Naming.parallel_why)
def fetch_parallel_exception(self, code):
    """
    Generate code that saves the exception raised by an OpenMP thread.

    Any thread of the parallel section may raise, so the exception is
    fetched from that thread's threadstate and stored in the shared
    parallel_exc variables; it is re-raised in the master thread once the
    parallel section has been left. Only the first exception is kept: the
    generated code fetches only while the saved exception type is unset.

    The position variables (__pyx_filename, __pyx_lineno, __pyx_clineno)
    are copied alongside the exception. They are only assigned under
    exception conditions, i.e. while holding the GIL, which suggests they
    could be shared safely. They cannot: they suffer from the same race
    they had when they were globals and functions could release the GIL:

        thread A                    thread B
          acquire
          set lineno
          release
                                      acquire
                                      set lineno
                                      release
          acquire
          fetch exception
          release
                                      skip the fetch

          deallocate threadstate      deallocate threadstate
    """
    saved_type = Naming.parallel_exc_type

    code.begin_block()
    code.put_ensure_gil(declare_gilstate=True)

    code.putln_openmp("#pragma omp flush(%s)" % saved_type)
    code.putln("if (!%s) {" % saved_type)

    code.putln("__Pyx_ErrFetch(&%s, &%s, &%s);" % self.parallel_exc)
    # Interleave as (saved, live) pairs: each saved position variable is
    # assigned from its live counterpart.
    saved_and_live = zip(self.parallel_pos_info, self.pos_info)
    assignment_args = tuple(chain.from_iterable(saved_and_live))
    code.putln("%s = %s; %s = %s; %s = %s;" % assignment_args)
    code.putln('__Pyx_GOTREF(%s);' % saved_type)

    code.putln("}")

    code.put_release_ensured_gil()
    code.end_block()
def restore_parallel_exception(self, code):
    """
    Generate code that re-raises a saved parallel exception.

    The exception stored in the parallel_exc variables is restored as the
    current exception while holding the GIL, and the live position
    variables (filename, lineno, clineno) are reset from their saved
    parallel counterparts.
    """
    code.begin_block()
    code.put_ensure_gil(declare_gilstate=True)

    # __Pyx_ErrRestore takes over the references it is passed, so the
    # give-ref has to be announced before the hand-over, not after it.
    code.putln("__Pyx_GIVEREF(%s);" % Naming.parallel_exc_type)
    code.putln("__Pyx_ErrRestore(%s, %s, %s);" % self.parallel_exc)

    # Interleave as (live, saved) pairs: each live position variable is
    # assigned from the value saved inside the parallel section.
    pos_info = chain(*zip(self.pos_info, self.parallel_pos_info))
    code.putln("%s = %s; %s = %s; %s = %s;" % tuple(pos_info))

    code.put_release_ensured_gil()
    code.end_block()
def restore_labels(self, code): def restore_labels(self, code):
""" """
...@@ -6199,6 +6305,21 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6199,6 +6305,21 @@ class ParallelStatNode(StatNode, ParallelNode):
Here break should be trapped in the parallel block, and propagated to Here break should be trapped in the parallel block, and propagated to
the for loop. the for loop.
""" """
c = self.begin_of_parallel_control_block_point
# Firstly, always prefer errors over returning, continue or break
if self.error_label_used:
c.putln("PyObject *%s = NULL, *%s = NULL, *%s = NULL;" %
self.parallel_exc)
code.putln(
"if (%s) {" % Naming.parallel_exc_type)
code.putln("/* This may have been overridden by a continue, "
"break or return in another thread. Prefer the error. */")
code.putln("%s = 4;" % Naming.parallel_why)
code.putln(
"}")
if continue_: if continue_:
any_label_used = self.any_label_used any_label_used = self.any_label_used
else: else:
...@@ -6206,12 +6327,12 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6206,12 +6327,12 @@ class ParallelStatNode(StatNode, ParallelNode):
if any_label_used: if any_label_used:
# __pyx_parallel_why is used, declare and initialize # __pyx_parallel_why is used, declare and initialize
c = self.begin_of_parallel_control_block_point c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("int __pyx_parallel_why;") c.putln("int %s;" % Naming.parallel_why)
c.putln("__pyx_parallel_why = 0;") c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("%s = 0;" % Naming.parallel_why)
code.putln("switch (__pyx_parallel_why) {")
code.putln("switch (%s) {" % Naming.parallel_why)
if continue_: if continue_:
code.put(" case 1: ") code.put(" case 1: ")
code.put_goto(code.continue_label) code.put_goto(code.continue_label)
...@@ -6222,8 +6343,11 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6222,8 +6343,11 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put(" case 3: ") code.put(" case 3: ")
code.put_goto(code.return_label) code.put_goto(code.return_label)
code.put(" case 4: ")
code.put_goto(code.error_label) if self.error_label_used:
code.putln(" case 4:")
self.restore_parallel_exception(code)
code.put_goto(code.error_label)
code.putln("}") code.putln("}")
...@@ -6266,13 +6390,14 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6266,13 +6390,14 @@ class ParallelWithBlockNode(ParallelStatNode):
code.putln("#endif /* _OPENMP */") code.putln("#endif /* _OPENMP */")
code.begin_block() # parallel block code.begin_block() # parallel block
self.begin_parallel_block(code)
self.initialize_privates_to_nan(code) self.initialize_privates_to_nan(code)
code.funcstate.start_collecting_temps() code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
self.trap_parallel_exit(code) self.trap_parallel_exit(code)
code.end_block() # end parallel block
self.privatize_temps(code) self.privatize_temps(code)
self.end_parallel_block(code)
code.end_block() # end parallel block
continue_ = code.label_used(code.continue_label) continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label) break_ = code.label_used(code.break_label)
...@@ -6492,7 +6617,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6492,7 +6617,7 @@ class ParallelRangeNode(ParallelStatNode):
if self.else_clause: if self.else_clause:
if self.breaking_label_used: if self.breaking_label_used:
code.put("if (__pyx_parallel_why < 2)" ) code.put("if (%s < 2)" % Naming.parallel_why)
code.begin_block() # else block code.begin_block() # else block
code.putln("/* else */") code.putln("/* else */")
...@@ -6520,7 +6645,19 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6520,7 +6645,19 @@ class ParallelRangeNode(ParallelStatNode):
if not self.is_parallel: if not self.is_parallel:
code.put("#pragma omp for") code.put("#pragma omp for")
else: else:
code.put("#pragma omp parallel for") code.put("#pragma omp parallel")
self.put_num_threads(code)
self.privatization_insertion_point = code.insertion_point()
code.putln("")
code.putln("#endif /* _OPENMP */")
code.begin_block() # pragma omp parallel begin block
# Initialize the GIL if needed for this thread
self.begin_parallel_block(code)
code.putln("#ifdef _OPENMP")
code.put("#pragma omp for")
for entry, op in self.privates.iteritems(): for entry, op in self.privates.iteritems():
# Don't declare the index variable as a reduction # Don't declare the index variable as a reduction
...@@ -6539,12 +6676,9 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6539,12 +6676,9 @@ class ParallelRangeNode(ParallelStatNode):
if self.schedule: if self.schedule:
code.put(" schedule(%s)" % self.schedule) code.put(" schedule(%s)" % self.schedule)
if self.is_parallel: if not self.is_parallel:
self.put_num_threads(code)
else:
self.put_num_threads(self.parent.privatization_insertion_point) self.put_num_threads(self.parent.privatization_insertion_point)
self.privatization_insertion_point = code.insertion_point()
self.privatization_insertion_point = code.insertion_point()
code.putln("") code.putln("")
code.putln("#endif /* _OPENMP */") code.putln("#endif /* _OPENMP */")
...@@ -6554,6 +6688,10 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6554,6 +6688,10 @@ class ParallelRangeNode(ParallelStatNode):
guard_around_body_codepoint = code.insertion_point() guard_around_body_codepoint = code.insertion_point()
# Start if guard block around the body. This may be unnecessary, but
# at least it doesn't spoil indentation
code.begin_block()
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict) code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry) self.initialize_privates_to_nan(code, exclude=self.target.entry)
...@@ -6567,11 +6705,16 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6567,11 +6705,16 @@ class ParallelRangeNode(ParallelStatNode):
if self.breaking_label_used: if self.breaking_label_used:
# Put a guard around the loop body in case return, break or # Put a guard around the loop body in case return, break or
# exceptions might be used # exceptions might be used
guard_around_body_codepoint.put("if (__pyx_parallel_why < 2) {") guard_around_body_codepoint.put("if (%s < 2)" % Naming.parallel_why)
code.putln("}")
code.end_block() # end guard around loop body
code.end_block() # end for loop block code.end_block() # end for loop block
if self.is_parallel:
# Release the GIL and deallocate the thread state
self.end_parallel_block(code)
code.end_block() # pragma omp parallel end block
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
......
...@@ -6,6 +6,8 @@ from cython.parallel import prange, threadid ...@@ -6,6 +6,8 @@ from cython.parallel import prange, threadid
cimport openmp cimport openmp
from libc.stdlib cimport malloc, free from libc.stdlib cimport malloc, free
openmp.omp_set_nested(1)
def test_parallel(): def test_parallel():
""" """
>>> test_parallel() >>> test_parallel()
......
...@@ -402,6 +402,25 @@ def test_parallel_exceptions(): ...@@ -402,6 +402,25 @@ def test_parallel_exceptions():
print mylist[0] print mylist[0]
print e.args, sum print e.args, sum
# Checks that an exception raised under `with gil` inside prange loops started
# with nogil=True reaches the caller; the doctest below expects
# "Exception: propagate me".
# NOTE(review): indentation was lost in this view. The three prange loops are
# presumably nested, with break / continue / return each sitting at a different
# loop level -- confirm against the original file before editing.
def test_parallel_exceptions2():
"""
>>> test_parallel_exceptions2()
Traceback (most recent call last):
...
Exception: propagate me
"""
cdef int i, j, k
for i in prange(10, nogil=True):
for j in prange(10):
for k in prange(10):
if i + j + k > 20:
# Raising requires the GIL, hence the `with gil` block.
with gil:
raise Exception("propagate me")
# The statements below are other ways of leaving a loop body.
break
continue
return
def test_parallel_with_gil_return(): def test_parallel_with_gil_return():
""" """
>>> test_parallel_with_gil_return() >>> test_parallel_with_gil_return()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment