Commit d9aceb67 authored by Mark Florisson's avatar Mark Florisson

Backup lastprivates when breaking/returning/propagating exceptions

parent 58e7e5de
...@@ -5868,6 +5868,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5868,6 +5868,8 @@ class ParallelStatNode(StatNode, ParallelNode):
Naming.clineno_cname, Naming.clineno_cname,
) )
critical_section_counter = 0
def __init__(self, pos, **kwargs): def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs) super(ParallelStatNode, self).__init__(pos, **kwargs)
...@@ -5878,7 +5880,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5878,7 +5880,8 @@ class ParallelStatNode(StatNode, ParallelNode):
self.seen_closure_vars = set() self.seen_closure_vars = set()
# Dict of variables that should be declared (first|last|)private or # Dict of variables that should be declared (first|last|)private or
# reduction { Entry: op }. If op is not None, it's a reduction. # reduction { Entry: (op, lastprivate) }.
# If op is not None, it's a reduction.
self.privates = {} self.privates = {}
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -5908,7 +5911,11 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5908,7 +5911,11 @@ class ParallelStatNode(StatNode, ParallelNode):
""" """
for entry, (pos, op) in self.assignments.iteritems(): for entry, (pos, op) in self.assignments.iteritems():
if self.is_private(entry): if self.is_private(entry):
self.propagate_var_privatization(entry, op) # lastprivate = self.is_prange and entry == self.target.entry
# By default all variables should have the same values as if
# executed sequentially
lastprivate = True
self.propagate_var_privatization(entry, op, lastprivate)
def is_private(self, entry): def is_private(self, entry):
""" """
...@@ -5918,7 +5925,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5918,7 +5925,7 @@ class ParallelStatNode(StatNode, ParallelNode):
return (self.is_parallel or return (self.is_parallel or
(self.parent and entry not in self.parent.privates)) (self.parent and entry not in self.parent.privates))
def propagate_var_privatization(self, entry, op): def propagate_var_privatization(self, entry, op, lastprivate):
""" """
Propagate the sharing attributes of a variable. If the privatization is Propagate the sharing attributes of a variable. If the privatization is
determined by a parent scope, done propagate further. determined by a parent scope, done propagate further.
...@@ -5971,7 +5978,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5971,7 +5978,7 @@ class ParallelStatNode(StatNode, ParallelNode):
# sum and j are undefined here # sum and j are undefined here
""" """
self.privates[entry] = op self.privates[entry] = (op, lastprivate)
if self.is_prange: if self.is_prange:
if not self.is_parallel and entry not in self.parent.assignments: if not self.is_parallel and entry not in self.parent.assignments:
# Parent is a parallel with block # Parent is a parallel with block
...@@ -5979,8 +5986,10 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5979,8 +5986,10 @@ class ParallelStatNode(StatNode, ParallelNode):
else: else:
parent = self.parent parent = self.parent
if parent: # We don't need to propagate privates, only reductions and
parent.propagate_var_privatization(entry, op) # lastprivates
if parent and (op or lastprivate):
parent.propagate_var_privatization(entry, op, lastprivate)
def _allocate_closure_temp(self, code, entry): def _allocate_closure_temp(self, code, entry):
""" """
...@@ -6041,7 +6050,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6041,7 +6050,7 @@ class ParallelStatNode(StatNode, ParallelNode):
def initialize_privates_to_nan(self, code, exclude=None): def initialize_privates_to_nan(self, code, exclude=None):
first = True first = True
for entry, op in self.privates.iteritems(): for entry, (op, lastprivate) in self.privates.iteritems():
if not op and (not exclude or entry != exclude): if not op and (not exclude or entry != exclude):
invalid_value = entry.type.invalid_value() invalid_value = entry.type.invalid_value()
...@@ -6070,11 +6079,6 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6070,11 +6079,6 @@ class ParallelStatNode(StatNode, ParallelNode):
def declare_closure_privates(self, code): def declare_closure_privates(self, code):
""" """
Set self.privates to a dict mapping C variable names that are to be
declared (first|last)private or reduction, to the reduction operator.
If the private is not a reduction, the operator is None.
This is used by subclasses.
If a variable is in a scope object, we need to allocate a temp and If a variable is in a scope object, we need to allocate a temp and
assign the value from the temp to the variable in the scope object assign the value from the temp to the variable in the scope object
after the parallel section. This kind of copying should be done only after the parallel section. This kind of copying should be done only
...@@ -6197,6 +6201,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6197,6 +6201,7 @@ class ParallelStatNode(StatNode, ParallelNode):
3 return 3 return
4 error 4 error
""" """
save_lastprivates_label = code.new_label()
dont_return_label = code.new_label() dont_return_label = code.new_label()
insertion_point = code.insertion_point() insertion_point = code.insertion_point()
...@@ -6204,32 +6209,91 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6204,32 +6209,91 @@ class ParallelStatNode(StatNode, ParallelNode):
self.breaking_label_used = False self.breaking_label_used = False
self.error_label_used = False self.error_label_used = False
for i, label in enumerate(code.get_all_labels()): self.parallel_private_temps = []
all_labels = code.get_all_labels()
# Figure this out before starting to generate any code
for label in all_labels:
if code.label_used(label): if code.label_used(label):
self.any_label_used = True
is_continue_label = label == code.continue_label
self.breaking_label_used = (self.breaking_label_used or self.breaking_label_used = (self.breaking_label_used or
not is_continue_label) label != code.continue_label)
self.any_label_used = True
if self.any_label_used:
code.put_goto(dont_return_label)
for i, label in enumerate(all_labels):
if not code.label_used(label):
continue
is_continue_label = label == code.continue_label
code.put_label(label) code.put_label(label)
if not (should_flush and is_continue_label): if not (should_flush and is_continue_label):
if label == code.error_label: if label == code.error_label:
self.error_label_used = True self.error_label_used = True
self.fetch_parallel_exception(code) self.fetch_parallel_exception(code)
code.putln("%s = %d;" % (Naming.parallel_why, i + 1)) code.putln("%s = %d;" % (Naming.parallel_why, i + 1))
if (self.breaking_label_used and self.is_prange and not
is_continue_label):
code.put_goto(save_lastprivates_label)
else:
code.put_goto(dont_return_label) code.put_goto(dont_return_label)
if self.any_label_used: if self.any_label_used:
insertion_point.funcstate = code.funcstate if self.is_prange and self.breaking_label_used:
insertion_point.put_goto(dont_return_label) # Don't rely on lastprivate, save our lastprivates
code.put_label(save_lastprivates_label)
self.save_parallel_vars(code)
code.put_label(dont_return_label) code.put_label(dont_return_label)
if should_flush and self.breaking_label_used: if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(%s)" % Naming.parallel_why) code.putln_openmp("#pragma omp flush(%s)" % Naming.parallel_why)
def save_parallel_vars(self, code):
"""
The following shenanigans are instated when we break, return or
propagate errors from a prange. In this case we cannot rely on
lastprivate() to do its job, as no iterations may have executed yet
in the last thread, leaving the values undefined. It is most likely
that the breaking thread has well-defined values of the lastprivate
variables, so we keep those values.
"""
section_name = ("__pyx_parallel_lastprivates%d" %
self.critical_section_counter)
code.putln_openmp("#pragma omp critical(%s)" % section_name)
ParallelStatNode.critical_section_counter += 1
code.begin_block() # begin critical section
c = self.begin_of_parallel_control_block_point
temp_count = 0
for entry, (op, lastprivate) in self.privates.iteritems():
if not lastprivate or entry.type.is_pyobject:
continue
type_decl = entry.type.declaration_code("")
temp_cname = "__pyx_parallel_temp%d" % temp_count
private_cname = entry.cname
temp_count += 1
# Declare the parallel private in the outer block
c.putln("%s %s;" % (type_decl, temp_cname))
# Initialize before escaping
code.putln("%s = %s;" % (temp_cname, private_cname))
self.parallel_private_temps.append((temp_cname, private_cname))
code.end_block() # end critical section
def fetch_parallel_exception(self, code): def fetch_parallel_exception(self, code):
""" """
As each OpenMP thread may raise an exception, we need to fetch that As each OpenMP thread may raise an exception, we need to fetch that
...@@ -6315,6 +6379,9 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6315,6 +6379,9 @@ class ParallelStatNode(StatNode, ParallelNode):
# Firstly, always prefer errors over returning, continue or break # Firstly, always prefer errors over returning, continue or break
if self.error_label_used: if self.error_label_used:
c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("PyObject *%s = NULL, *%s = NULL, *%s = NULL;" % c.putln("PyObject *%s = NULL, *%s = NULL, *%s = NULL;" %
self.parallel_exc) self.parallel_exc)
...@@ -6333,11 +6400,15 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6333,11 +6400,15 @@ class ParallelStatNode(StatNode, ParallelNode):
if any_label_used: if any_label_used:
# __pyx_parallel_why is used, declare and initialize # __pyx_parallel_why is used, declare and initialize
c.putln("const char *%s; int %s, %s;" % self.parallel_pos_info)
c.putln("int %s;" % Naming.parallel_why) c.putln("int %s;" % Naming.parallel_why)
c.putln("%s = NULL; %s = %s = 0;" % self.parallel_pos_info)
c.putln("%s = 0;" % Naming.parallel_why) c.putln("%s = 0;" % Naming.parallel_why)
code.putln(
"if (%s) {" % Naming.parallel_why)
for temp_cname, private_cname in self.parallel_private_temps:
code.putln("%s = %s;" % (private_cname, temp_cname))
code.putln("switch (%s) {" % Naming.parallel_why) code.putln("switch (%s) {" % Naming.parallel_why)
if continue_: if continue_:
code.put(" case 1: ") code.put(" case 1: ")
...@@ -6355,7 +6426,9 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6355,7 +6426,9 @@ class ParallelStatNode(StatNode, ParallelNode):
self.restore_parallel_exception(code) self.restore_parallel_exception(code)
code.put_goto(code.error_label) code.put_goto(code.error_label)
code.putln("}") code.putln("}") # end switch
code.putln(
"}") # end if
code.end_block() # end parallel control flow block code.end_block() # end parallel control flow block
...@@ -6665,7 +6738,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6665,7 +6738,7 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
code.put("#pragma omp for") code.put("#pragma omp for")
for entry, op in self.privates.iteritems(): for entry, (op, lastprivate) in self.privates.iteritems():
# Don't declare the index variable as a reduction # Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry: if op and op in "+*-&^|" and entry != self.target.entry:
if entry.type.is_pyobject: if entry.type.is_pyobject:
...@@ -6675,9 +6748,16 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6675,9 +6748,16 @@ class ParallelRangeNode(ParallelStatNode):
else: else:
if entry == self.target.entry: if entry == self.target.entry:
code.put(" firstprivate(%s)" % entry.cname) code.put(" firstprivate(%s)" % entry.cname)
code.put(" lastprivate(%s)" % entry.cname)
continue
if not entry.type.is_pyobject: if not entry.type.is_pyobject:
code.put(" lastprivate(%s)" % entry.cname) if lastprivate:
private = 'lastprivate'
else:
private = 'private'
code.put(" %s(%s)" % (private, entry.cname))
if self.schedule: if self.schedule:
code.put(" schedule(%s)" % self.schedule) code.put(" schedule(%s)" % self.schedule)
...@@ -6711,7 +6791,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6711,7 +6791,7 @@ class ParallelRangeNode(ParallelStatNode):
if self.breaking_label_used: if self.breaking_label_used:
# Put a guard around the loop body in case return, break or # Put a guard around the loop body in case return, break or
# exceptions might be used # exceptions might be used
guard_around_body_codepoint.put("if (%s < 2)" % Naming.parallel_why) guard_around_body_codepoint.putln("if (%s < 2)" % Naming.parallel_why)
code.end_block() # end guard around loop body code.end_block() # end guard around loop body
code.end_block() # end for loop block code.end_block() # end for loop block
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment