Commit 468c17af authored by Mark Florisson's avatar Mark Florisson

Use a switch like try/finally for parallel break/continue/return/error

parent f70465bd
...@@ -97,14 +97,6 @@ binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" ...@@ -97,14 +97,6 @@ binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
genexpr_id_ref = 'genexpr' genexpr_id_ref = 'genexpr'
# These labels are needed for break, return and exception paths through
# parallel code sections. These are the variables that remember which path
# should be taken after the parallel section
parallel_error = pyrex_prefix + "parallel_error"
parallel_return = pyrex_prefix + "parallel_return"
parallel_break = pyrex_prefix + "parallel_break"
parallel_continue = pyrex_prefix + "parallel_continue"
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
file_c_macro = "__FILE__" file_c_macro = "__FILE__"
......
...@@ -6069,14 +6069,14 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6069,14 +6069,14 @@ class ParallelStatNode(StatNode, ParallelNode):
code.funcstate.release_temp(entry.cname) code.funcstate.release_temp(entry.cname)
entry.cname = original_cname entry.cname = original_cname
def setup_control_flow_variables_block(self, code): def setup_parallel_control_flow_block(self, code):
""" """
Sets up any needed variables outside the parallel block to determine Sets up a block that surrounds the parallel block to determine
how the parallel block was left. Inside the any kind of return is how the parallel section was exited. Any kind of return is
trapped (break, continue, return, exceptions). This is the idea: trapped (break, continue, return, exceptions). This is the idea:
{ {
int returning = 0; int why = 0;
#pragma omp parallel #pragma omp parallel
{ {
...@@ -6084,87 +6084,119 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6084,87 +6084,119 @@ class ParallelStatNode(StatNode, ParallelNode):
goto end_parallel; goto end_parallel;
new_return_label: new_return_label:
returning = 1; why = 3;
#pragma omp flush(returning)
goto end_parallel; goto end_parallel;
end_parallel:; end_parallel:;
#pragma omp flush(why) # we need to flush for every iteration
} }
if (returning) if (why == 3)
goto old_return_label; goto old_return_label;
} }
""" """
self.need_returning_guard = (self.break_label_used or
self.error_label_used or
self.return_label_used)
self.any_label_used = (self.need_returning_guard or
self.continue_label_used)
if not self.any_label_used:
return
self.old_loop_labels = code.new_loop_labels() self.old_loop_labels = code.new_loop_labels()
self.old_error_label = code.new_error_label() self.old_error_label = code.new_error_label()
self.old_return_label = code.return_label self.old_return_label = code.return_label
code.return_label = code.new_label(name="return") code.return_label = code.new_label(name="return")
self.labels = ( code.begin_block() # parallel control flow block
(Naming.parallel_break, 'break_label', self.break_label_used), self.begin_of_parallel_control_block_point = code.insertion_point()
(Naming.parallel_continue, 'continue_label', self.continue_label_used),
(Naming.parallel_error, 'error_label', self.error_label_used),
(Naming.parallel_return, 'return_label', self.return_label_used),
)
code.begin_block() # control flow variables block
for var, label_name, label_used in self.labels:
if label_used:
code.putln("int %s = 0;" % var)
def trap_parallel_exit(self, code, should_flush=False): def trap_parallel_exit(self, code, should_flush=False):
""" """
Trap any kind of return inside a parallel construct. 'should_flush' Trap any kind of return inside a parallel construct. 'should_flush'
indicates whether the variable should be flushed, which is needed by indicates whether the variable should be flushed, which is needed by
prange to skip the loop. prange to skip the loop. It also indicates whether we need to register
a continue (we need this for parallel blocks, but not for prange
loops, as it is a direct jump there).
It uses the same mechanism as try/finally:
1 continue
2 break
3 return
4 error
""" """
if not self.any_label_used:
return
dont_return_label = code.new_label() dont_return_label = code.new_label()
code.put_goto(dont_return_label) insertion_point = code.insertion_point()
self.any_label_used = False
self.breaking_label_used = False
for i, label in enumerate(code.get_all_labels()):
if code.label_used(label):
self.any_label_used = True
is_continue_label = label == code.continue_label
self.breaking_label_used = (self.breaking_label_used or
not is_continue_label)
for var, label_name, label_used in self.labels:
if label_used:
label = getattr(code, label_name)
code.put_label(label) code.put_label(label)
code.putln("%s = 1;" % var) if not (should_flush and is_continue_label):
if should_flush: code.putln("__pyx_parallel_why = %d;" % (i + 1))
code.putln_openmp("#pragma omp flush(%s)" % var)
code.put_goto(dont_return_label) code.put_goto(dont_return_label)
if self.any_label_used:
insertion_point.funcstate = code.funcstate
insertion_point.put_goto(dont_return_label)
code.put_label(dont_return_label) code.put_label(dont_return_label)
if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(__pyx_parallel_why)")
def restore_labels(self, code): def restore_labels(self, code):
if self.any_label_used: """
Restore all old labels. Call this before the 'else' clause to for
loops and always before ending the parallel control flow block.
"""
code.set_all_labels(self.old_loop_labels + (self.old_return_label, code.set_all_labels(self.old_loop_labels + (self.old_return_label,
self.old_error_label)) self.old_error_label))
def end_control_flow_variables_block(self, code): def end_parallel_control_flow_block(self, code,
if not self.any_label_used: break_=False, continue_=False):
return """
This ends the parallel control flow block and based on how the parallel
section was exited, takes the corresponding action. The break_ and
continue_ parameters indicate whether these should be propagated
outwards:
if self.return_label_used: for i in prange(...):
code.put("if (%s) " % Naming.parallel_return) with cython.parallel.parallel():
code.put_goto(code.return_label) continue
Here break should be trapped in the parallel block, and propagated to
the for loop.
"""
if continue_:
any_label_used = self.any_label_used
else:
any_label_used = self.breaking_label_used
if any_label_used:
# __pyx_parallel_why is used, declare and initialize
c = self.begin_of_parallel_control_block_point
c.putln("int __pyx_parallel_why;")
c.putln("__pyx_parallel_why = 0;")
code.putln("switch (__pyx_parallel_why) {")
if self.error_label_used: if continue_:
code.put("if (%s) " % Naming.parallel_error) code.put(" case 1: ")
code.put_goto(code.continue_label)
if break_:
code.put(" case 2: ")
code.put_goto(code.break_label)
code.put(" case 3: ")
code.put_goto(code.return_label)
code.put(" case 4: ")
code.put_goto(code.error_label) code.put_goto(code.error_label)
code.end_block() # end control flow variables block code.putln("}")
code.end_block() # end parallel control flow block
class ParallelWithBlockNode(ParallelStatNode): class ParallelWithBlockNode(ParallelStatNode):
...@@ -6186,7 +6218,7 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6186,7 +6218,7 @@ class ParallelWithBlockNode(ParallelStatNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.declare_closure_privates(code) self.declare_closure_privates(code)
self.setup_control_flow_variables_block(code) # control vars block self.setup_parallel_control_flow_block(code)
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
code.put("#pragma omp parallel ") code.put("#pragma omp parallel ")
...@@ -6207,17 +6239,12 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6207,17 +6239,12 @@ class ParallelWithBlockNode(ParallelStatNode):
self.trap_parallel_exit(code) self.trap_parallel_exit(code)
code.end_block() # end parallel block code.end_block() # end parallel block
self.restore_labels(code) continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label)
if self.break_label_used:
code.put("if (%s) " % Naming.parallel_break)
code.put_goto(code.break_label)
if self.continue_label_used:
code.put("if (%s) " % Naming.parallel_continue)
code.put_goto(code.continue_label)
self.end_control_flow_variables_block(code) # end control vars block self.restore_labels(code)
self.end_parallel_control_flow_block(code, break_=break_,
continue_=continue_)
self.release_closure_privates(code) self.release_closure_privates(code)
...@@ -6402,12 +6429,9 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6402,12 +6429,9 @@ class ParallelRangeNode(ParallelStatNode):
# 'with gil' block. For now, just abort # 'with gil' block. For now, just abort
code.putln("if (%(step)s == 0) abort();" % fmt_dict) code.putln("if (%(step)s == 0) abort();" % fmt_dict)
self.setup_control_flow_variables_block(code) # control flow vars block self.setup_parallel_control_flow_block(code) # parallel control flow block
if self.need_returning_guard: self.control_flow_var_code_point = code.insertion_point()
self.used_control_flow_vars = "(%s)" % " || ".join([
var for var, label_name, label_used in self.labels
if label_used and label_name != 'continue_label'])
# Note: nsteps is private in an outer scope if present # Note: nsteps is private in an outer scope if present
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict) code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
...@@ -6427,8 +6451,8 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6427,8 +6451,8 @@ class ParallelRangeNode(ParallelStatNode):
self.restore_labels(code) self.restore_labels(code)
if self.else_clause: if self.else_clause:
if self.need_returning_guard: if self.breaking_label_used:
code.put("if (!%s)" % self.used_control_flow_vars) code.put("if (__pyx_parallel_why < 2)" )
code.begin_block() # else block code.begin_block() # else block
code.putln("/* else */") code.putln("/* else */")
...@@ -6436,7 +6460,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6436,7 +6460,7 @@ class ParallelRangeNode(ParallelStatNode):
code.end_block() # end else block code.end_block() # end else block
# ------ cleanup ------ # ------ cleanup ------
self.end_control_flow_variables_block(code) # end control flow vars block self.end_parallel_control_flow_block(code) # end parallel control flow block
# And finally, release our privates and write back any closure # And finally, release our privates and write back any closure
# variables # variables
...@@ -6487,18 +6511,18 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6487,18 +6511,18 @@ class ParallelRangeNode(ParallelStatNode):
code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict) code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict)
code.begin_block() # for loop block code.begin_block() # for loop block
if self.need_returning_guard: guard_around_body_codepoint = code.insertion_point()
code.put("if (!%s) " % self.used_control_flow_vars)
code.begin_block() # return/break/error guard body block
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict) code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry) self.initialize_privates_to_nan(code, exclude=self.target.entry)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if self.need_returning_guard:
code.end_block() # return/break/error guard body block
self.trap_parallel_exit(code, should_flush=True) self.trap_parallel_exit(code, should_flush=True)
if self.breaking_label_used:
# Put a guard around the loop body in case return, break or
# exceptions might be used
guard_around_body_codepoint.put("if (__pyx_parallel_why < 2) {")
code.putln("}")
code.end_block() # end for loop block code.end_block() # end for loop block
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment