Commit 8e05843a authored by Mark Florisson's avatar Mark Florisson

Support break/continue/return in parallel and prange, ready for exceptions

parent 53e4c102
...@@ -97,6 +97,14 @@ binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" ...@@ -97,6 +97,14 @@ binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
genexpr_id_ref = 'genexpr' genexpr_id_ref = 'genexpr'
# These labels are needed for break, return and exception paths through
# parallel code sections. These are the variables that remember which path
# should be taken after the parallel section
parallel_error = pyrex_prefix + "parallel_error"
parallel_return = pyrex_prefix + "parallel_return"
parallel_break = pyrex_prefix + "parallel_break"
parallel_continue = pyrex_prefix + "parallel_continue"
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
file_c_macro = "__FILE__" file_c_macro = "__FILE__"
......
...@@ -5839,6 +5839,13 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5839,6 +5839,13 @@ class ParallelStatNode(StatNode, ParallelNode):
is_prange = False is_prange = False
# Labels to break out of parallel constructs
break_label_used = False
continue_label_used = False
error_label_used = False
return_label_used = False
def __init__(self, pos, **kwargs): def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs) super(ParallelStatNode, self).__init__(pos, **kwargs)
...@@ -6062,6 +6069,103 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6062,6 +6069,103 @@ class ParallelStatNode(StatNode, ParallelNode):
code.funcstate.release_temp(entry.cname) code.funcstate.release_temp(entry.cname)
entry.cname = original_cname entry.cname = original_cname
def setup_control_flow_variables_block(self, code):
"""
Sets up any needed variables outside the parallel block to determine
how the parallel block was left. Inside the any kind of return is
trapped (break, continue, return, exceptions). This is the idea:
{
int returning = 0;
#pragma omp parallel
{
return # -> goto new_return_label;
goto end_parallel;
new_return_label:
returning = 1;
#pragma omp flush(returning)
goto end_parallel;
end_parallel:;
}
if (returning)
goto old_return_label;
}
"""
self.need_returning_guard = (self.break_label_used or
self.error_label_used or
self.return_label_used)
self.any_label_used = (self.need_returning_guard or
self.continue_label_used)
if not self.any_label_used:
return
self.old_loop_labels = code.new_loop_labels()
self.old_error_label = code.new_error_label()
self.old_return_label = code.return_label
code.return_label = code.new_label(name="return")
self.labels = (
(Naming.parallel_break, 'break_label', self.break_label_used),
(Naming.parallel_continue, 'continue_label', self.continue_label_used),
(Naming.parallel_error, 'error_label', self.error_label_used),
(Naming.parallel_return, 'return_label', self.return_label_used),
)
code.begin_block() # control flow variables block
for var, label_name, label_used in self.labels:
if label_used:
code.putln("int %s = 0;" % var)
def trap_parallel_exit(self, code, should_flush=False):
"""
Trap any kind of return inside a parallel construct. 'should_flush'
indicates whether the variable should be flushed, which is needed by
prange to skip the loop.
"""
if not self.any_label_used:
return
dont_return_label = code.new_label()
code.put_goto(dont_return_label)
for var, label_name, label_used in self.labels:
if label_used:
label = getattr(code, label_name)
code.put_label(label)
code.putln("%s = 1;" % var)
if should_flush:
code.putln_openmp("#pragma omp flush(%s)" % var)
code.put_goto(dont_return_label)
code.put_label(dont_return_label)
def restore_labels(self, code):
if self.any_label_used:
code.set_all_labels(self.old_loop_labels + (self.old_return_label,
self.old_error_label))
def end_control_flow_variables_block(self, code):
if not self.any_label_used:
return
if self.return_label_used:
code.put("if (%s) " % Naming.parallel_return)
code.put_goto(code.return_label)
if self.error_label_used:
code.put("if (%s) " % Naming.parallel_error)
code.put_goto(code.error_label)
code.end_block() # end control flow variables block
class ParallelWithBlockNode(ParallelStatNode): class ParallelWithBlockNode(ParallelStatNode):
""" """
...@@ -6082,13 +6186,14 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6082,13 +6186,14 @@ class ParallelWithBlockNode(ParallelStatNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.declare_closure_privates(code) self.declare_closure_privates(code)
self.setup_control_flow_variables_block(code) # control vars block
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
code.put("#pragma omp parallel ") code.put("#pragma omp parallel ")
if self.privates: if self.privates:
code.put( privates = [e.cname for e in self.privates]
'private(%s)' % ', '.join([e.cname for e in self.privates])) code.put('private(%s)' % ', '.join(privates))
self.privatization_insertion_point = code.insertion_point() self.privatization_insertion_point = code.insertion_point()
self.put_num_threads(code) self.put_num_threads(code)
...@@ -6096,11 +6201,23 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6096,11 +6201,23 @@ class ParallelWithBlockNode(ParallelStatNode):
code.putln("#endif /* _OPENMP */") code.putln("#endif /* _OPENMP */")
code.begin_block() code.begin_block() # parallel block
self.initialize_privates_to_nan(code) self.initialize_privates_to_nan(code)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.end_block() self.trap_parallel_exit(code)
code.end_block() # end parallel block
self.restore_labels(code)
if self.break_label_used:
code.put("if (%s) " % Naming.parallel_break)
code.put_goto(code.break_label)
if self.continue_label_used:
code.put("if (%s) " % Naming.parallel_continue)
code.put_goto(code.continue_label)
self.end_control_flow_variables_block(code) # end control vars block
self.release_closure_privates(code) self.release_closure_privates(code)
...@@ -6285,6 +6402,13 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6285,6 +6402,13 @@ class ParallelRangeNode(ParallelStatNode):
# 'with gil' block. For now, just abort # 'with gil' block. For now, just abort
code.putln("if (%(step)s == 0) abort();" % fmt_dict) code.putln("if (%(step)s == 0) abort();" % fmt_dict)
self.setup_control_flow_variables_block(code) # control flow vars block
if self.need_returning_guard:
self.used_control_flow_vars = "(%s)" % " || ".join([
var for var, label_name, label_used in self.labels
if label_used and label_name != 'continue_label'])
# Note: nsteps is private in an outer scope if present # Note: nsteps is private in an outer scope if present
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict) code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
...@@ -6295,12 +6419,24 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6295,12 +6419,24 @@ class ParallelRangeNode(ParallelStatNode):
# erroneously believes that nsteps may be <= 0, leaving the private # erroneously believes that nsteps may be <= 0, leaving the private
# target index uninitialized # target index uninitialized
code.putln("if (%(nsteps)s > 0)" % fmt_dict) code.putln("if (%(nsteps)s > 0)" % fmt_dict)
code.begin_block() code.begin_block() # if block
code.putln("%(target)s = 0;" % fmt_dict) code.putln("%(target)s = 0;" % fmt_dict)
self.generate_loop(code, fmt_dict) self.generate_loop(code, fmt_dict)
code.end_block() # end if block
code.end_block() self.restore_labels(code)
if self.else_clause:
if self.need_returning_guard:
code.put("if (!%s)" % self.used_control_flow_vars)
code.begin_block() # else block
code.putln("/* else */")
self.else_clause.generate_execution_code(code)
code.end_block() # end else block
# ------ cleanup ------
self.end_control_flow_variables_block(code) # end control flow vars block
# And finally, release our privates and write back any closure # And finally, release our privates and write back any closure
# variables # variables
...@@ -6338,7 +6474,10 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6338,7 +6474,10 @@ class ParallelRangeNode(ParallelStatNode):
c = self.parent.privatization_insertion_point c = self.parent.privatization_insertion_point
c.put(" private(%(nsteps)s)" % fmt_dict) c.put(" private(%(nsteps)s)" % fmt_dict)
if self.is_parallel:
self.put_num_threads(code) self.put_num_threads(code)
else:
self.put_num_threads(self.parent.privatization_insertion_point)
self.privatization_insertion_point = code.insertion_point() self.privatization_insertion_point = code.insertion_point()
...@@ -6346,13 +6485,22 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6346,13 +6485,22 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#endif /* _OPENMP */") code.putln("#endif /* _OPENMP */")
code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict) code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict)
code.begin_block() code.begin_block() # for loop block
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry) if self.need_returning_guard:
code.put("if (!%s) " % self.used_control_flow_vars)
code.begin_block() # return/break/error guard body block
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.end_block()
if self.need_returning_guard:
code.end_block() # return/break/error guard body block
self.trap_parallel_exit(code, should_flush=True)
code.end_block() # end for loop block
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
......
...@@ -1005,8 +1005,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1005,8 +1005,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
Node class. Node class.
E.g. for a cython.parallel.prange() call we return ParallelRangeNode E.g. for a cython.parallel.prange() call we return ParallelRangeNode
Also disallow break, continue and return in a prange section
""" """
if self.namenode_is_cython_module: if self.namenode_is_cython_module:
directive = '.'.join(self.parallel_directive) directive = '.'.join(self.parallel_directive)
...@@ -1032,7 +1030,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1032,7 +1030,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
""" """
if node.parallel_directives: if node.parallel_directives:
self.parallel_directives = node.parallel_directives self.parallel_directives = node.parallel_directives
self.assignment_stack = []
return self.visit_Node(node) return self.visit_Node(node)
# No parallel directives were imported, so they can't be used :) # No parallel directives were imported, so they can't be used :)
...@@ -1136,23 +1133,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1136,23 +1133,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
self.visit(node.else_clause) self.visit(node.else_clause)
return node return node
def ensure_not_in_prange(name):
"Creates error checking functions for break, continue and return"
def visit_method(self, node):
if self.in_prange:
error(node.pos,
name + " not allowed in a parallel range section")
# Do a visit for 'return'
self.visitchildren(node)
return node
return visit_method
visit_BreakStatNode = ensure_not_in_prange("break")
visit_ContinueStatNode = ensure_not_in_prange("continue")
visit_ReturnStatNode = ensure_not_in_prange("return")
def visit(self, node): def visit(self, node):
"Visit a node that may be None" "Visit a node that may be None"
if node is not None: if node is not None:
......
...@@ -178,8 +178,51 @@ class MarkAssignments(CythonTransform): ...@@ -178,8 +178,51 @@ class MarkAssignments(CythonTransform):
node.is_parallel = True node.is_parallel = True
self.parallel_block_stack.append(node) self.parallel_block_stack.append(node)
if node.is_prange:
child_attrs = node.child_attrs
node.child_attrs = ['body', 'target', 'args']
self.visitchildren(node)
node.child_attrs = child_attrs
self.parallel_block_stack.pop()
if node.else_clause:
node.else_clause = self.visit(node.else_clause)
else:
self.visitchildren(node) self.visitchildren(node)
self.parallel_block_stack.pop() self.parallel_block_stack.pop()
return node
def visit_BreakStatNode(self, node):
parnode = self.parallel_block_stack[-1]
parnode.break_label_used = True
if not parnode.is_prange and parnode.parent:
parnode.parent.break_label_used = True
return node
def visit_ContinueStatNode(self, node):
parnode = self.parallel_block_stack[-1]
parnode.continue_label_used = True
if not parnode.is_prange and parnode.parent:
parnode.parent.continue_label_used = True
return node
def visit_ReturnStatNode(self, node):
for parnode in self.parallel_block_stack:
parnode.return_label_used = True
return node
def visit_GilStatNode(self, node):
if node.state == 'gil':
for parnode in self.parallel_block_stack:
parnode.error_label_used = True
return node return node
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
cimport cython.parallel cimport cython.parallel
from cython.parallel import prange, threadid from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free, abort from libc.stdlib cimport malloc, calloc, free, abort
from libc.stdio cimport puts from libc.stdio cimport puts
import sys import sys
...@@ -232,3 +232,114 @@ def test_nan_init(): ...@@ -232,3 +232,114 @@ def test_nan_init():
if err: if err:
raise Exception("One of the values was not initialized to a maximum " raise Exception("One of the values was not initialized to a maximum "
"or NaN value") "or NaN value")
cdef void nogil_print(char *s) with gil:
print s
def test_else_clause():
"""
>>> test_else_clause()
else clause executed
"""
cdef int i
for i in prange(5, nogil=True):
pass
else:
nogil_print('else clause executed')
def test_prange_break():
"""
>>> test_prange_break()
"""
cdef int i
for i in prange(10, nogil=True):
if i == 8:
break
else:
nogil_print('else clause executed')
def test_prange_continue():
"""
>>> test_prange_continue()
else clause executed
0 0
1 0
2 2
3 0
4 4
5 0
6 6
7 0
8 8
9 0
"""
cdef int i
cdef int *p = <int *> calloc(10, sizeof(int))
if p == NULL:
raise MemoryError
for i in prange(10, nogil=True):
if i % 2 != 0:
continue
p[i] = i
else:
nogil_print('else clause executed')
for i in range(10):
print i, p[i]
free(p)
def test_nested_break_continue():
"""
>>> test_nested_break_continue()
6 7 6 7
8
"""
cdef int i, j, result1 = 0, result2 = 0
for i in prange(10, nogil=True, num_threads=2, schedule='static'):
for j in prange(10, num_threads=2, schedule='static'):
if i == 6 and j == 7:
result1 = i
result2 = j
break
else:
continue
break
print i, j, result1, result2
with nogil, cython.parallel.parallel():
for i in prange(10, num_threads=2, schedule='static'):
with cython.parallel.parallel():
if i == 8:
break
else:
continue
print i
cdef int parallel_return() nogil:
cdef int i
for i in prange(10):
if i == 8:
return i
else:
return 1
return 2
def test_return():
"""
>>> test_return()
8
"""
print parallel_return()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment