Commit d9aceb67 authored by Mark Florisson

Backup lastprivates when breaking/returning/propagating exceptions

parent 58e7e5de
......@@ -5868,6 +5868,8 @@ class ParallelStatNode(StatNode, ParallelNode):
Naming.clineno_cname,
)
critical_section_counter = 0
def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs)
......@@ -5878,7 +5880,8 @@ class ParallelStatNode(StatNode, ParallelNode):
self.seen_closure_vars = set()
# Dict of variables that should be declared (first|last|)private or
# reduction { Entry: op }. If op is not None, it's a reduction.
# reduction { Entry: (op, lastprivate) }.
# If op is not None, it's a reduction.
self.privates = {}
def analyse_declarations(self, env):
......@@ -5908,7 +5911,11 @@ class ParallelStatNode(StatNode, ParallelNode):
"""
for entry, (pos, op) in self.assignments.iteritems():
if self.is_private(entry):
self.propagate_var_privatization(entry, op)
# lastprivate = self.is_prange and entry == self.target.entry
# By default all variables should have the same values as if
# executed sequentially
lastprivate = True
self.propagate_var_privatization(entry, op, lastprivate)
def is_private(self, entry):
"""
......@@ -5918,7 +5925,7 @@ class ParallelStatNode(StatNode, ParallelNode):
return (self.is_parallel or
(self.parent and entry not in self.parent.privates))
def propagate_var_privatization(self, entry, op):
def propagate_var_privatization(self, entry, op, lastprivate):
"""
Propagate the sharing attributes of a variable. If the privatization is
determined by a parent scope, don't propagate further.
......@@ -5971,7 +5978,7 @@ class ParallelStatNode(StatNode, ParallelNode):
# sum and j are undefined here
"""
self.privates[entry] = op
self.privates[entry] = (op, lastprivate)
if self.is_prange:
if not self.is_parallel and entry not in self.parent.assignments:
# Parent is a parallel with block
......@@ -5979,8 +5986,10 @@ class ParallelStatNode(StatNode, ParallelNode):
else:
parent = self.parent
if parent:
parent.propagate_var_privatization(entry, op)
# We don't need to propagate privates, only reductions and
# lastprivates
if parent and (op or lastprivate):
parent.propagate_var_privatization(entry, op, lastprivate)
def _allocate_closure_temp(self, code, entry):
"""
......@@ -6041,7 +6050,7 @@ class ParallelStatNode(StatNode, ParallelNode):
def initialize_privates_to_nan(self, code, exclude=None):
first = True
for entry, op in self.privates.iteritems():
for entry, (op, lastprivate) in self.privates.iteritems():
if not op and (not exclude or entry != exclude):
invalid_value = entry.type.invalid_value()
......@@ -6070,11 +6079,6 @@ class ParallelStatNode(StatNode, ParallelNode):
def declare_closure_privates(self, code):
"""
Set self.privates to a dict mapping C variable names that are to be
declared (first|last)private or reduction, to the reduction operator.
If the private is not a reduction, the operator is None.
This is used by subclasses.
If a variable is in a scope object, we need to allocate a temp and
assign the value from the temp to the variable in the scope object
after the parallel section. This kind of copying should be done only
......@@ -6197,6 +6201,7 @@ class ParallelStatNode(StatNode, ParallelNode):
3 return
4 error
"""
save_lastprivates_label = code.new_label()
dont_return_label = code.new_label()
insertion_point = code.insertion_point()
......@@ -6204,12 +6209,25 @@ class ParallelStatNode(StatNode, ParallelNode):
self.breaking_label_used = False
self.error_label_used = False
for i, label in enumerate(code.get_all_labels()):
self.parallel_private_temps = []
all_labels = code.get_all_labels()
# Figure this out before starting to generate any code
for label in all_labels:
if code.label_used(label):
self.breaking_label_used = (self.breaking_label_used or
label != code.continue_label)
self.any_label_used = True
if self.any_label_used:
code.put_goto(dont_return_label)
for i, label in enumerate(all_labels):
if not code.label_used(label):
continue
is_continue_label = label == code.continue_label
self.breaking_label_used = (self.breaking_label_used or
not is_continue_label)
code.put_label(label)
......@@ -6220,16 +6238,62 @@ class ParallelStatNode(StatNode, ParallelNode):
code.putln("%s = %d;" % (Naming.parallel_why, i + 1))
if (self.breaking_label_used and self.is_prange and not
is_continue_label):
code.put_goto(save_lastprivates_label)
else:
code.put_goto(dont_return_label)
if self.any_label_used:
insertion_point.funcstate = code.funcstate
insertion_point.put_goto(dont_return_label)
if self.is_prange and self.breaking_label_used:
# Don't rely on lastprivate, save our lastprivates
code.put_label(save_lastprivates_label)
self.save_parallel_vars(code)
code.put_label(dont_return_label)
if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(%s)" % Naming.parallel_why)
def save_parallel_vars(self, code):
    """
    Emit C code that backs up the lastprivate variables of a prange when
    the loop is left through break, return or a propagating exception.

    On those paths lastprivate() cannot be relied upon: the last thread
    may not have executed any iteration yet, which leaves its values
    undefined. The breaking thread most likely holds well-defined values,
    so those are copied into temporaries inside a named OpenMP critical
    section.

    Python object variables and entries not marked lastprivate are
    skipped. Every (temp_cname, private_cname) pair that is written is
    appended to self.parallel_private_temps.
    """
    critical_name = ("__pyx_parallel_lastprivates%d" %
                     self.critical_section_counter)
    code.putln_openmp("#pragma omp critical(%s)" % critical_name)
    # The counter lives on the class, so each emitted section gets a
    # distinct name.
    ParallelStatNode.critical_section_counter += 1

    code.begin_block() # begin critical section

    outer_block = self.begin_of_parallel_control_block_point

    # Only non-object lastprivates are backed up.
    savable_entries = (
        entry
        for entry, (op, lastprivate) in self.privates.iteritems()
        if lastprivate and not entry.type.is_pyobject)

    for index, entry in enumerate(savable_entries):
        type_decl = entry.type.declaration_code("")
        temp_cname = "__pyx_parallel_temp%d" % index
        private_cname = entry.cname

        # The temporary is declared in the outer block ...
        outer_block.putln("%s %s;" % (type_decl, temp_cname))

        # ... and assigned here, before control escapes.
        code.putln("%s = %s;" % (temp_cname, private_cname))

        self.parallel_private_temps.append((temp_cname, private_cname))

    code.end_block() # end critical section
def fetch_parallel_exception(self, code):
"""
As each OpenMP thread may raise an exception, we need to fetch that
......@@ -6315,6 +6379,9 @@ class ParallelStatNode(StatNode, ParallelNode):
# Firstly, always prefer errors over returning, continue or break
if self.error_label_used:
c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("PyObject *%s = NULL, *%s = NULL, *%s = NULL;" %
self.parallel_exc)
......@@ -6333,11 +6400,15 @@ class ParallelStatNode(StatNode, ParallelNode):
if any_label_used:
# __pyx_parallel_why is used, declare and initialize
c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("int %s;" % Naming.parallel_why)
c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("%s = 0;" % Naming.parallel_why)
code.putln(
"if (%s) {" % Naming.parallel_why)
for temp_cname, private_cname in self.parallel_private_temps:
code.putln("%s = %s;" % (private_cname, temp_cname))
code.putln("switch (%s) {" % Naming.parallel_why)
if continue_:
code.put(" case 1: ")
......@@ -6355,7 +6426,9 @@ class ParallelStatNode(StatNode, ParallelNode):
self.restore_parallel_exception(code)
code.put_goto(code.error_label)
code.putln("}")
code.putln("}") # end switch
code.putln(
"}") # end if
code.end_block() # end parallel control flow block
......@@ -6665,7 +6738,7 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#ifdef _OPENMP")
code.put("#pragma omp for")
for entry, op in self.privates.iteritems():
for entry, (op, lastprivate) in self.privates.iteritems():
# Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry:
if entry.type.is_pyobject:
......@@ -6675,9 +6748,16 @@ class ParallelRangeNode(ParallelStatNode):
else:
if entry == self.target.entry:
code.put(" firstprivate(%s)" % entry.cname)
code.put(" lastprivate(%s)" % entry.cname)
continue
if not entry.type.is_pyobject:
code.put(" lastprivate(%s)" % entry.cname)
if lastprivate:
private = 'lastprivate'
else:
private = 'private'
code.put(" %s(%s)" % (private, entry.cname))
if self.schedule:
code.put(" schedule(%s)" % self.schedule)
......@@ -6711,7 +6791,7 @@ class ParallelRangeNode(ParallelStatNode):
if self.breaking_label_used:
# Put a guard around the loop body in case return, break or
# exceptions might be used
guard_around_body_codepoint.put("if (%s < 2)" % Naming.parallel_why)
guard_around_body_codepoint.putln("if (%s < 2)" % Naming.parallel_why)
code.end_block() # end guard around loop body
code.end_block() # end for loop block
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment