Commit 468c17af authored by Mark Florisson's avatar Mark Florisson

Use a switch like try/finally for parallel break/continue/return/error

parent f70465bd
......@@ -97,14 +97,6 @@ binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
genexpr_id_ref = 'genexpr'
# These labels are needed for break, return and exception paths through
# parallel code sections. These are the variables that remember which path
# should be taken after the parallel section
parallel_error = pyrex_prefix + "parallel_error"
parallel_return = pyrex_prefix + "parallel_return"
parallel_break = pyrex_prefix + "parallel_break"
parallel_continue = pyrex_prefix + "parallel_continue"
line_c_macro = "__LINE__"
file_c_macro = "__FILE__"
......
......@@ -6069,14 +6069,14 @@ class ParallelStatNode(StatNode, ParallelNode):
code.funcstate.release_temp(entry.cname)
entry.cname = original_cname
def setup_control_flow_variables_block(self, code):
def setup_parallel_control_flow_block(self, code):
"""
Sets up any needed variables outside the parallel block to determine
how the parallel block was left. Inside the any kind of return is
Sets up a block that surrounds the parallel block to determine
how the parallel section was exited. Any kind of return is
trapped (break, continue, return, exceptions). This is the idea:
{
int returning = 0;
int why = 0;
#pragma omp parallel
{
......@@ -6084,87 +6084,119 @@ class ParallelStatNode(StatNode, ParallelNode):
goto end_parallel;
new_return_label:
returning = 1;
#pragma omp flush(returning)
why = 3;
goto end_parallel;
end_parallel:;
#pragma omp flush(why) # we need to flush for every iteration
}
if (returning)
if (why == 3)
goto old_return_label;
}
"""
self.need_returning_guard = (self.break_label_used or
self.error_label_used or
self.return_label_used)
self.any_label_used = (self.need_returning_guard or
self.continue_label_used)
if not self.any_label_used:
return
self.old_loop_labels = code.new_loop_labels()
self.old_error_label = code.new_error_label()
self.old_return_label = code.return_label
code.return_label = code.new_label(name="return")
self.labels = (
(Naming.parallel_break, 'break_label', self.break_label_used),
(Naming.parallel_continue, 'continue_label', self.continue_label_used),
(Naming.parallel_error, 'error_label', self.error_label_used),
(Naming.parallel_return, 'return_label', self.return_label_used),
)
code.begin_block() # control flow variables block
for var, label_name, label_used in self.labels:
if label_used:
code.putln("int %s = 0;" % var)
code.begin_block() # parallel control flow block
self.begin_of_parallel_control_block_point = code.insertion_point()
def trap_parallel_exit(self, code, should_flush=False):
"""
Trap any kind of return inside a parallel construct. 'should_flush'
indicates whether the variable should be flushed, which is needed by
prange to skip the loop.
prange to skip the loop. It also indicates whether we need to register
a continue (we need this for parallel blocks, but not for prange
loops, as it is a direct jump there).
It uses the same mechanism as try/finally:
1 continue
2 break
3 return
4 error
"""
if not self.any_label_used:
return
dont_return_label = code.new_label()
code.put_goto(dont_return_label)
insertion_point = code.insertion_point()
self.any_label_used = False
self.breaking_label_used = False
for i, label in enumerate(code.get_all_labels()):
if code.label_used(label):
self.any_label_used = True
is_continue_label = label == code.continue_label
self.breaking_label_used = (self.breaking_label_used or
not is_continue_label)
for var, label_name, label_used in self.labels:
if label_used:
label = getattr(code, label_name)
code.put_label(label)
code.putln("%s = 1;" % var)
if should_flush:
code.putln_openmp("#pragma omp flush(%s)" % var)
if not (should_flush and is_continue_label):
code.putln("__pyx_parallel_why = %d;" % (i + 1))
code.put_goto(dont_return_label)
if self.any_label_used:
insertion_point.funcstate = code.funcstate
insertion_point.put_goto(dont_return_label)
code.put_label(dont_return_label)
if should_flush and self.breaking_label_used:
code.putln_openmp("#pragma omp flush(__pyx_parallel_why)")
def restore_labels(self, code):
if self.any_label_used:
"""
Restore all old labels. Call this before the 'else' clause to for
loops and always before ending the parallel control flow block.
"""
code.set_all_labels(self.old_loop_labels + (self.old_return_label,
self.old_error_label))
def end_control_flow_variables_block(self, code):
if not self.any_label_used:
return
def end_parallel_control_flow_block(self, code,
break_=False, continue_=False):
"""
This ends the parallel control flow block and based on how the parallel
section was exited, takes the corresponding action. The break_ and
continue_ parameters indicate whether these should be propagated
outwards:
if self.return_label_used:
code.put("if (%s) " % Naming.parallel_return)
code.put_goto(code.return_label)
for i in prange(...):
with cython.parallel.parallel():
continue
Here break should be trapped in the parallel block, and propagated to
the for loop.
"""
if continue_:
any_label_used = self.any_label_used
else:
any_label_used = self.breaking_label_used
if any_label_used:
# __pyx_parallel_why is used, declare and initialize
c = self.begin_of_parallel_control_block_point
c.putln("int __pyx_parallel_why;")
c.putln("__pyx_parallel_why = 0;")
code.putln("switch (__pyx_parallel_why) {")
if self.error_label_used:
code.put("if (%s) " % Naming.parallel_error)
if continue_:
code.put(" case 1: ")
code.put_goto(code.continue_label)
if break_:
code.put(" case 2: ")
code.put_goto(code.break_label)
code.put(" case 3: ")
code.put_goto(code.return_label)
code.put(" case 4: ")
code.put_goto(code.error_label)
code.end_block() # end control flow variables block
code.putln("}")
code.end_block() # end parallel control flow block
class ParallelWithBlockNode(ParallelStatNode):
......@@ -6186,7 +6218,7 @@ class ParallelWithBlockNode(ParallelStatNode):
def generate_execution_code(self, code):
self.declare_closure_privates(code)
self.setup_control_flow_variables_block(code) # control vars block
self.setup_parallel_control_flow_block(code)
code.putln("#ifdef _OPENMP")
code.put("#pragma omp parallel ")
......@@ -6207,17 +6239,12 @@ class ParallelWithBlockNode(ParallelStatNode):
self.trap_parallel_exit(code)
code.end_block() # end parallel block
self.restore_labels(code)
if self.break_label_used:
code.put("if (%s) " % Naming.parallel_break)
code.put_goto(code.break_label)
if self.continue_label_used:
code.put("if (%s) " % Naming.parallel_continue)
code.put_goto(code.continue_label)
continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label)
self.end_control_flow_variables_block(code) # end control vars block
self.restore_labels(code)
self.end_parallel_control_flow_block(code, break_=break_,
continue_=continue_)
self.release_closure_privates(code)
......@@ -6402,12 +6429,9 @@ class ParallelRangeNode(ParallelStatNode):
# 'with gil' block. For now, just abort
code.putln("if (%(step)s == 0) abort();" % fmt_dict)
self.setup_control_flow_variables_block(code) # control flow vars block
self.setup_parallel_control_flow_block(code) # parallel control flow block
if self.need_returning_guard:
self.used_control_flow_vars = "(%s)" % " || ".join([
var for var, label_name, label_used in self.labels
if label_used and label_name != 'continue_label'])
self.control_flow_var_code_point = code.insertion_point()
# Note: nsteps is private in an outer scope if present
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
......@@ -6427,8 +6451,8 @@ class ParallelRangeNode(ParallelStatNode):
self.restore_labels(code)
if self.else_clause:
if self.need_returning_guard:
code.put("if (!%s)" % self.used_control_flow_vars)
if self.breaking_label_used:
code.put("if (__pyx_parallel_why < 2)" )
code.begin_block() # else block
code.putln("/* else */")
......@@ -6436,7 +6460,7 @@ class ParallelRangeNode(ParallelStatNode):
code.end_block() # end else block
# ------ cleanup ------
self.end_control_flow_variables_block(code) # end control flow vars block
self.end_parallel_control_flow_block(code) # end parallel control flow block
# And finally, release our privates and write back any closure
# variables
......@@ -6487,18 +6511,18 @@ class ParallelRangeNode(ParallelStatNode):
code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict)
code.begin_block() # for loop block
if self.need_returning_guard:
code.put("if (!%s) " % self.used_control_flow_vars)
code.begin_block() # return/break/error guard body block
guard_around_body_codepoint = code.insertion_point()
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry)
self.body.generate_execution_code(code)
if self.need_returning_guard:
code.end_block() # return/break/error guard body block
self.trap_parallel_exit(code, should_flush=True)
if self.breaking_label_used:
# Put a guard around the loop body in case return, break or
# exceptions might be used
guard_around_body_codepoint.put("if (__pyx_parallel_why < 2) {")
code.putln("}")
code.end_block() # end for loop block
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment