Commit bbce1d98 authored by Mark Florisson's avatar Mark Florisson

Privatize temporaries and make assignment to __pyx_r critical

parent 086067d0
......@@ -33,6 +33,7 @@ cdef class FunctionState:
cdef public dict temps_free
cdef public dict temps_used_type
cdef public size_t temp_counter
cdef public list collect_temps_stack
cdef public object closure_temps
cdef public bint should_declare_error_indicator
......
......@@ -134,6 +134,10 @@ class FunctionState(object):
self.temp_counter = 0
self.closure_temps = None
# This is used to collect temporaries, useful to find out which temps
# need to be privatized in parallel sections
self.collect_temps_stack = []
# This is used for the error indicator, which needs to be local to the
# function. It used to be global, which relies on the GIL being held.
# However, exceptions may need to be propagated through 'nogil'
......@@ -236,6 +240,10 @@ class FunctionState(object):
self.temps_used_type[result] = (type, manage_ref)
if DebugFlags.debug_temp_code_comments:
self.owner.putln("/* %s allocated */" % result)
if self.collect_temps_stack:
self.collect_temps_stack[-1].add((result, type))
return result
def release_temp(self, name):
......@@ -292,6 +300,15 @@ class FunctionState(object):
if manage_ref
for cname in freelist]
def start_collecting_temps(self):
"""
Useful to find out which temps were used in a code block
"""
self.collect_temps_stack.append(cython.set())
def stop_collecting_temps(self):
return self.collect_temps_stack.pop()
def init_closure_temps(self, scope):
self.closure_temps = ClosureTempAllocator(scope)
......
......@@ -4111,6 +4111,9 @@ class ReturnStatNode(StatNode):
child_attrs = ["value"]
is_terminator = True
# Whether we are in a parallel section
in_parallel = False
def analyse_expressions(self, env):
return_type = env.return_type
self.return_type = return_type
......@@ -4147,24 +4150,25 @@ class ReturnStatNode(StatNode):
if self.value:
self.value.generate_evaluation_code(code)
self.value.make_owned_reference(code)
code.putln(
"%s = %s;" % (
Naming.retval_cname,
self.value.result_as(self.return_type)))
self.put_return(code, self.value.result_as(self.return_type))
self.value.generate_post_assignment_code(code)
self.value.free_temps(code)
else:
if self.return_type.is_pyobject:
code.put_init_to_py_none(Naming.retval_cname, self.return_type)
elif self.return_type.is_returncode:
code.putln(
"%s = %s;" % (
Naming.retval_cname,
self.return_type.default_value))
self.put_return(code, self.return_type.default_value)
for cname, type in code.funcstate.temps_holding_reference():
code.put_decref_clear(cname, type)
code.put_goto(code.return_label)
def put_return(self, code, value):
if self.in_parallel:
code.putln_openmp("#pragma omp critical(__pyx_returning)")
code.putln("%s = %s;" % (Naming.retval_cname, value))
def generate_function_definitions(self, env, code):
if self.value is not None:
self.value.generate_function_definitions(env, code)
......@@ -5812,7 +5816,9 @@ class ParallelStatNode(StatNode, ParallelNode):
assignments to variables in this parallel section
parent parent ParallelStatNode or None
is_parallel indicates whether this is a parallel node
is_parallel indicates whether this node is OpenMP parallel
(true for #pragma omp parallel for and
#pragma omp parallel)
is_parallel is true for:
......@@ -6062,13 +6068,30 @@ class ParallelStatNode(StatNode, ParallelNode):
"""
Release any temps used for variables in scope objects. As this is the
outermost parallel block, we don't need to delete the cnames from
self.seen_closure_vars
self.seen_closure_vars.
"""
for entry, original_cname in self.modified_entries:
code.putln("%s = %s;" % (original_cname, entry.cname))
code.funcstate.release_temp(entry.cname)
entry.cname = original_cname
def privatize_temps(self, code, exclude_temps=()):
"""
Make any used temporaries private. Before the relevant code block
code.start_collecting_temps() should have been called.
"""
if self.is_parallel:
private_cnames = cython.set([e.cname for e in self.privates])
temps = []
for cname, type in code.funcstate.stop_collecting_temps():
if not type.is_pyobject:
temps.append(cname)
if temps:
c = self.privatization_insertion_point
c.put(" private(%s)" % ", ".join(temps))
def setup_parallel_control_flow_block(self, code):
"""
Sets up a block that surrounds the parallel block to determine
......@@ -6236,12 +6259,12 @@ class ParallelWithBlockNode(ParallelStatNode):
code.begin_block() # parallel block
self.initialize_privates_to_nan(code)
code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code)
self.trap_parallel_exit(code)
code.end_block() # end parallel block
# After the parallel block all privates are undefined
self.initialize_privates_to_nan(code)
self.privatize_temps(code)
continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label)
......@@ -6500,10 +6523,6 @@ class ParallelRangeNode(ParallelStatNode):
if self.schedule:
code.put(" schedule(%s)" % self.schedule)
if self.parent:
c = self.parent.privatization_insertion_point
c.put(" private(%(nsteps)s)" % fmt_dict)
if self.is_parallel:
self.put_num_threads(code)
else:
......@@ -6521,8 +6540,15 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry)
if self.is_parallel:
code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code)
if self.is_parallel:
self.privatize_temps(code)
self.trap_parallel_exit(code, should_flush=True)
if self.breaking_label_used:
# Put a guard around the loop body in case return, break or
......@@ -7747,7 +7773,7 @@ static float %(PYX_NAN)s;
init="""
/* Initialize NaN. The sign is irrelevant, an exponent with all bits 1 and
a nonzero mantissa means NaN. If the first bit in the mantissa is 1, it is
a signalling NaN. */
a quiet NaN. */
memset(&%(PYX_NAN)s, 0xFF, sizeof(%(PYX_NAN)s));
""" % vars(Naming))
......@@ -119,7 +119,7 @@ class MarkAssignments(CythonTransform):
base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0')))
self._visit_loop_node_children(node)
self.visitchildren(node)
return node
def visit_ForFromStatNode(self, node):
......@@ -130,11 +130,11 @@ class MarkAssignments(CythonTransform):
'+',
node.bound1,
node.step))
self._visit_loop_node_children(node)
self.visitchildren(node)
return node
def visit_WhileStatNode(self, node):
self._visit_loop_node_children(node)
self.visitchildren(node)
return node
def visit_ExceptClauseNode(self, node):
......@@ -208,25 +208,9 @@ class MarkAssignments(CythonTransform):
return node
def _visit_loop_node_children(self, node):
"""
Used for the children of "loop nodes", like ForInStatNode, so subnodes
can establish whether break and continue belong to a parallel node
or something else.
"""
child_attrs = node.child_attrs
node.child_attrs = [attr for attr in child_attrs if attr != 'else_clause']
was_in_loop = self.in_loop
self.in_loop = True
self.visitchildren(node)
self.in_loop = was_in_loop
node.child_attrs = child_attrs
if node.else_clause:
node.else_clause = self.visit(node.else_clause)
def visit_ReturnStatNode(self, node):
node.in_parallel = bool(self.parallel_block_stack)
return node
class MarkOverflowingArithmetic(CythonTransform):
......
......@@ -237,7 +237,7 @@ def test_nan_init():
with nogil, cython.parallel.parallel():
c1 = 16
assert c1 not in (16, 20), c1
assert c1 == 20, c1
cdef void nogil_print(char *s) with gil:
print s.decode('ascii')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment