Commit 40417d36 authored by Mark Florisson

Allocate thread-state for each OpenMP thread

Save exc_info in parallel section and re-raise outside in master thread
parent 04a8c7fe
......@@ -1400,7 +1400,7 @@ class CCodeWriter(object):
self.putln("#ifdef WITH_THREAD")
if declare_gilstate:
self.put("PyGILState_STATE ")
self.putln("_save = PyGILState_Ensure();")
self.putln("__pyx_gilstate_save = PyGILState_Ensure();")
self.putln("#endif")
def put_release_ensured_gil(self):
......@@ -1408,7 +1408,7 @@ class CCodeWriter(object):
Releases the GIL, corresponds to `put_ensure_gil`.
"""
self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_Release(_save);")
self.putln("PyGILState_Release(__pyx_gilstate_save);")
self.putln("#endif")
def put_acquire_gil(self):
......@@ -1427,7 +1427,7 @@ class CCodeWriter(object):
def declare_gilstate(self):
self.putln("#ifdef WITH_THREAD")
self.putln("PyGILState_STATE _save;")
self.putln("PyGILState_STATE __pyx_gilstate_save;")
self.putln("#endif")
# error handling
......
......@@ -108,6 +108,14 @@ exc_value_name = pyrex_prefix + "exc_value"
exc_tb_name = pyrex_prefix + "exc_tb"
exc_lineno_name = pyrex_prefix + "exc_lineno"
parallel_exc_type = pyrex_prefix + "parallel_exc_type"
parallel_exc_value = pyrex_prefix + "parallel_exc_value"
parallel_exc_tb = pyrex_prefix + "parallel_exc_tb"
parallel_filename = pyrex_prefix + "parallel_filename"
parallel_lineno = pyrex_prefix + "parallel_lineno"
parallel_clineno = pyrex_prefix + "parallel_clineno"
parallel_why = pyrex_prefix + "parallel_why"
exc_vars = (exc_type_name, exc_value_name, exc_tb_name)
api_name = pyrex_prefix + "capi__"
......
......@@ -26,6 +26,7 @@ from Code import UtilityCode, ClosureTempAllocator
from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options
import DebugFlags
from itertools import chain
absolute_path_length = 0
......@@ -5845,12 +5846,25 @@ class ParallelStatNode(StatNode, ParallelNode):
is_prange = False
# Labels to break out of parallel constructs
break_label_used = False
continue_label_used = False
error_label_used = False
return_label_used = False
parallel_exc = (
Naming.parallel_exc_type,
Naming.parallel_exc_value,
Naming.parallel_exc_tb,
)
parallel_pos_info = (
Naming.parallel_filename,
Naming.parallel_lineno,
Naming.parallel_clineno,
)
pos_info = (
Naming.filename_cname,
Naming.lineno_cname,
Naming.clineno_cname,
)
def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs)
......@@ -6094,11 +6108,17 @@ class ParallelStatNode(StatNode, ParallelNode):
if not type.is_pyobject:
temps.append(cname)
c = self.privatization_insertion_point
if temps:
c = self.privatization_insertion_point
c.put(" private(%s)" % ", ".join(temps))
if self.breaking_label_used:
c.put(" shared(__pyx_parallel_why)")
if self.breaking_label_used:
shared_vars = [Naming.parallel_why]
if self.error_label_used:
shared_vars.extend(self.parallel_exc)
c.put(" private(%s, %s, %s)" % self.pos_info)
c.put(" shared(%s)" % ', '.join(shared_vars))
def setup_parallel_control_flow_block(self, code):
"""
......@@ -6134,6 +6154,29 @@ class ParallelStatNode(StatNode, ParallelNode):
code.begin_block() # parallel control flow block
self.begin_of_parallel_control_block_point = code.insertion_point()
def begin_parallel_block(self, code):
    """
    Remember where the parallel block starts in the generated code.

    Every OpenMP thread in a parallel section that contains a 'with gil'
    block needs an initialized thread-state, which the matching call to
    PyGILState_Release() deallocates again. Without a thread-state that
    lives for the whole parallel section, each 'with gil' block would
    allocate and deallocate its own, and the exception information would
    be lost before it could be saved on leaving the parallel section.

    Only the insertion point is recorded here; end_parallel_block() writes
    the actual code into it when the error label turns out to be used.
    """
    block_start = code.insertion_point()
    self.begin_of_parallel_block = block_start
def end_parallel_block(self, code):
    """
    Close the thread-state bracket around the parallel block.

    Nothing is emitted unless the error label was used. When it was, the
    insertion point saved by begin_parallel_block() receives the code that
    ensures the GIL (declaring the gilstate variable) followed by
    Py_BEGIN_ALLOW_THREADS, and 'code' receives the matching
    Py_END_ALLOW_THREADS followed by the release of the ensured GIL.
    """
    if not self.error_label_used:
        return

    # Written retroactively at the start of the parallel block.
    prologue = self.begin_of_parallel_block
    prologue.put_ensure_gil(declare_gilstate=True)
    prologue.putln("Py_BEGIN_ALLOW_THREADS")

    # Written at the current position, i.e. the end of the parallel block.
    code.putln("Py_END_ALLOW_THREADS")
    code.put_release_ensured_gil()
def trap_parallel_exit(self, code, should_flush=False):
"""
Trap any kind of return inside a parallel construct. 'should_flush'
......@@ -6153,6 +6196,7 @@ class ParallelStatNode(StatNode, ParallelNode):
self.any_label_used = False
self.breaking_label_used = False
self.error_label_used = False
for i, label in enumerate(code.get_all_labels()):
if code.label_used(label):
......@@ -6164,7 +6208,11 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put_label(label)
if not (should_flush and is_continue_label):
code.putln("__pyx_parallel_why = %d;" % (i + 1))
if label == code.error_label:
self.error_label_used = True
self.fetch_parallel_exception(code)
code.putln("%s = %d;" % (Naming.parallel_why, i + 1))
code.put_goto(dont_return_label)
......@@ -6174,7 +6222,65 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put_label(dont_return_label)
if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(__pyx_parallel_why)")
code.putln_openmp("#pragma omp flush(%s)" % Naming.parallel_why)
def fetch_parallel_exception(self, code):
    """
    Emit code that saves the exception raised in the current OpenMP thread.

    Each OpenMP thread may raise an exception, so it has to be fetched from
    the thread-state and stored for after the parallel section, where it
    can be re-raised in the master thread. Only the first exception is
    kept: the fetch is skipped when an exception type is already stored.

    It may seem that __pyx_filename, __pyx_lineno and __pyx_clineno are
    only assigned under exception conditions (i.e. while holding the GIL)
    and could therefore be shared without a race. They are in fact subject
    to the same race they had as global variables when functions were
    allowed to release the GIL:

        thread A                thread B
            acquire
            set lineno
            release
                                    acquire
                                    set lineno
                                    release
            acquire
            fetch exception
            release
                                    skip the fetch

            deallocate threadstate  deallocate threadstate
    """
    exc_type = Naming.parallel_exc_type

    # Flatten the (saved, current) pairs into
    # saved_filename, filename, saved_lineno, lineno, ... so that they fill
    # the three "saved = current" assignments below.
    save_position = tuple(
        cname
        for pair in zip(self.parallel_pos_info, self.pos_info)
        for cname in pair)

    code.begin_block()
    code.put_ensure_gil(declare_gilstate=True)

    code.putln_openmp("#pragma omp flush(%s)" % exc_type)
    code.putln("if (!%s) {" % exc_type)
    code.putln("__Pyx_ErrFetch(&%s, &%s, &%s);" % self.parallel_exc)
    code.putln("%s = %s; %s = %s; %s = %s;" % save_position)
    code.putln('__Pyx_GOTREF(%s);' % exc_type)
    code.putln("}")

    code.put_release_ensured_gil()
    code.end_block()
def restore_parallel_exception(self, code):
    """
    Emit code that re-raises a saved parallel exception.

    The generated block ensures the GIL, hands the saved exception triple
    back to the thread-state with __Pyx_ErrRestore, copies the saved
    filename/lineno/clineno back into the regular position variables, and
    releases the GIL again.
    """
    code.begin_block()
    code.put_ensure_gil(declare_gilstate=True)

    # __Pyx_ErrRestore takes over the references (as PyErr_Restore does),
    # so the hand-off is recorded before the call, while the reference is
    # still ours, rather than after it.
    code.putln("__Pyx_GIVEREF(%s);" % Naming.parallel_exc_type)
    code.putln("__Pyx_ErrRestore(%s, %s, %s);" % self.parallel_exc)

    # Interleave as current, saved, current, saved, ... to fill the three
    # "current = saved" assignments.
    pos_info = chain(*zip(self.pos_info, self.parallel_pos_info))
    code.putln("%s = %s; %s = %s; %s = %s;" % tuple(pos_info))

    code.put_release_ensured_gil()
    code.end_block()
def restore_labels(self, code):
"""
......@@ -6199,6 +6305,21 @@ class ParallelStatNode(StatNode, ParallelNode):
Here break should be trapped in the parallel block, and propagated to
the for loop.
"""
c = self.begin_of_parallel_control_block_point
# Firstly, always prefer errors over returning, continue or break
if self.error_label_used:
c.putln("PyObject *%s = NULL, *%s = NULL, *%s = NULL;" %
self.parallel_exc)
code.putln(
"if (%s) {" % Naming.parallel_exc_type)
code.putln("/* This may have been overridden by a continue, "
"break or return in another thread. Prefer the error. */")
code.putln("%s = 4;" % Naming.parallel_why)
code.putln(
"}")
if continue_:
any_label_used = self.any_label_used
else:
......@@ -6206,12 +6327,12 @@ class ParallelStatNode(StatNode, ParallelNode):
if any_label_used:
# __pyx_parallel_why is used, declare and initialize
c = self.begin_of_parallel_control_block_point
c.putln("int __pyx_parallel_why;")
c.putln("__pyx_parallel_why = 0;")
code.putln("switch (__pyx_parallel_why) {")
c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("int %s;" % Naming.parallel_why)
c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("%s = 0;" % Naming.parallel_why)
code.putln("switch (%s) {" % Naming.parallel_why)
if continue_:
code.put(" case 1: ")
code.put_goto(code.continue_label)
......@@ -6222,8 +6343,11 @@ class ParallelStatNode(StatNode, ParallelNode):
code.put(" case 3: ")
code.put_goto(code.return_label)
code.put(" case 4: ")
code.put_goto(code.error_label)
if self.error_label_used:
code.putln(" case 4:")
self.restore_parallel_exception(code)
code.put_goto(code.error_label)
code.putln("}")
......@@ -6266,13 +6390,14 @@ class ParallelWithBlockNode(ParallelStatNode):
code.putln("#endif /* _OPENMP */")
code.begin_block() # parallel block
self.begin_parallel_block(code)
self.initialize_privates_to_nan(code)
code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code)
self.trap_parallel_exit(code)
code.end_block() # end parallel block
self.privatize_temps(code)
self.end_parallel_block(code)
code.end_block() # end parallel block
continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label)
......@@ -6492,7 +6617,7 @@ class ParallelRangeNode(ParallelStatNode):
if self.else_clause:
if self.breaking_label_used:
code.put("if (__pyx_parallel_why < 2)" )
code.put("if (%s < 2)" % Naming.parallel_why)
code.begin_block() # else block
code.putln("/* else */")
......@@ -6520,7 +6645,19 @@ class ParallelRangeNode(ParallelStatNode):
if not self.is_parallel:
code.put("#pragma omp for")
else:
code.put("#pragma omp parallel for")
code.put("#pragma omp parallel")
self.put_num_threads(code)
self.privatization_insertion_point = code.insertion_point()
code.putln("")
code.putln("#endif /* _OPENMP */")
code.begin_block() # pragma omp parallel begin block
# Initialize the GIL if needed for this thread
self.begin_parallel_block(code)
code.putln("#ifdef _OPENMP")
code.put("#pragma omp for")
for entry, op in self.privates.iteritems():
# Don't declare the index variable as a reduction
......@@ -6539,12 +6676,9 @@ class ParallelRangeNode(ParallelStatNode):
if self.schedule:
code.put(" schedule(%s)" % self.schedule)
if self.is_parallel:
self.put_num_threads(code)
else:
if not self.is_parallel:
self.put_num_threads(self.parent.privatization_insertion_point)
self.privatization_insertion_point = code.insertion_point()
self.privatization_insertion_point = code.insertion_point()
code.putln("")
code.putln("#endif /* _OPENMP */")
......@@ -6554,6 +6688,10 @@ class ParallelRangeNode(ParallelStatNode):
guard_around_body_codepoint = code.insertion_point()
# Start if guard block around the body. This may be unnecessary, but
# at least it doesn't spoil indentation
code.begin_block()
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry)
......@@ -6567,11 +6705,16 @@ class ParallelRangeNode(ParallelStatNode):
if self.breaking_label_used:
# Put a guard around the loop body in case return, break or
# exceptions might be used
guard_around_body_codepoint.put("if (__pyx_parallel_why < 2) {")
code.putln("}")
guard_around_body_codepoint.put("if (%s < 2)" % Naming.parallel_why)
code.end_block() # end guard around loop body
code.end_block() # end for loop block
if self.is_parallel:
# Release the GIL and deallocate the thread state
self.end_parallel_block(code)
code.end_block() # pragma omp parallel end block
#------------------------------------------------------------------------------------
#
......
......@@ -6,6 +6,8 @@ from cython.parallel import prange, threadid
cimport openmp
from libc.stdlib cimport malloc, free
openmp.omp_set_nested(1)
def test_parallel():
"""
>>> test_parallel()
......
......@@ -402,6 +402,25 @@ def test_parallel_exceptions():
print mylist[0]
print e.args, sum
def test_parallel_exceptions2():
"""
Check that an exception raised inside nested prange loops propagates
to the caller.

>>> test_parallel_exceptions2()
Traceback (most recent call last):
...
Exception: propagate me
"""
cdef int i, j, k
# Three nested prange loops; only the outermost one passes nogil=True.
for i in prange(10, nogil=True):
for j in prange(10):
for k in prange(10):
if i + j + k > 20:
# The GIL is re-acquired just for raising the exception.
with gil:
raise Exception("propagate me")
# NOTE(review): indentation was lost in this view, so which loop each of
# the following break / continue / return statements belongs to cannot be
# determined here -- confirm against the original test file.
break
continue
return
def test_parallel_with_gil_return():
"""
>>> test_parallel_with_gil_return()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment