Commit bbce1d98 authored by Mark Florisson's avatar Mark Florisson

Privatize temporaries and make assignment to __pyx_r critical

parent 086067d0
...@@ -33,6 +33,7 @@ cdef class FunctionState: ...@@ -33,6 +33,7 @@ cdef class FunctionState:
cdef public dict temps_free cdef public dict temps_free
cdef public dict temps_used_type cdef public dict temps_used_type
cdef public size_t temp_counter cdef public size_t temp_counter
cdef public list collect_temps_stack
cdef public object closure_temps cdef public object closure_temps
cdef public bint should_declare_error_indicator cdef public bint should_declare_error_indicator
......
...@@ -134,6 +134,10 @@ class FunctionState(object): ...@@ -134,6 +134,10 @@ class FunctionState(object):
self.temp_counter = 0 self.temp_counter = 0
self.closure_temps = None self.closure_temps = None
# This is used to collect temporaries, useful to find out which temps
# need to be privatized in parallel sections
self.collect_temps_stack = []
# This is used for the error indicator, which needs to be local to the # This is used for the error indicator, which needs to be local to the
# function. It used to be global, which relies on the GIL being held. # function. It used to be global, which relies on the GIL being held.
# However, exceptions may need to be propagated through 'nogil' # However, exceptions may need to be propagated through 'nogil'
...@@ -236,6 +240,10 @@ class FunctionState(object): ...@@ -236,6 +240,10 @@ class FunctionState(object):
self.temps_used_type[result] = (type, manage_ref) self.temps_used_type[result] = (type, manage_ref)
if DebugFlags.debug_temp_code_comments: if DebugFlags.debug_temp_code_comments:
self.owner.putln("/* %s allocated */" % result) self.owner.putln("/* %s allocated */" % result)
if self.collect_temps_stack:
self.collect_temps_stack[-1].add((result, type))
return result return result
def release_temp(self, name): def release_temp(self, name):
...@@ -292,6 +300,15 @@ class FunctionState(object): ...@@ -292,6 +300,15 @@ class FunctionState(object):
if manage_ref if manage_ref
for cname in freelist] for cname in freelist]
def start_collecting_temps(self):
"""
Useful to find out which temps were used in a code block
"""
self.collect_temps_stack.append(cython.set())
def stop_collecting_temps(self):
return self.collect_temps_stack.pop()
def init_closure_temps(self, scope): def init_closure_temps(self, scope):
self.closure_temps = ClosureTempAllocator(scope) self.closure_temps = ClosureTempAllocator(scope)
......
...@@ -4111,6 +4111,9 @@ class ReturnStatNode(StatNode): ...@@ -4111,6 +4111,9 @@ class ReturnStatNode(StatNode):
child_attrs = ["value"] child_attrs = ["value"]
is_terminator = True is_terminator = True
# Whether we are in a parallel section
in_parallel = False
def analyse_expressions(self, env): def analyse_expressions(self, env):
return_type = env.return_type return_type = env.return_type
self.return_type = return_type self.return_type = return_type
...@@ -4147,24 +4150,25 @@ class ReturnStatNode(StatNode): ...@@ -4147,24 +4150,25 @@ class ReturnStatNode(StatNode):
if self.value: if self.value:
self.value.generate_evaluation_code(code) self.value.generate_evaluation_code(code)
self.value.make_owned_reference(code) self.value.make_owned_reference(code)
code.putln( self.put_return(code, self.value.result_as(self.return_type))
"%s = %s;" % (
Naming.retval_cname,
self.value.result_as(self.return_type)))
self.value.generate_post_assignment_code(code) self.value.generate_post_assignment_code(code)
self.value.free_temps(code) self.value.free_temps(code)
else: else:
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
code.put_init_to_py_none(Naming.retval_cname, self.return_type) code.put_init_to_py_none(Naming.retval_cname, self.return_type)
elif self.return_type.is_returncode: elif self.return_type.is_returncode:
code.putln( self.put_return(code, self.return_type.default_value)
"%s = %s;" % (
Naming.retval_cname,
self.return_type.default_value))
for cname, type in code.funcstate.temps_holding_reference(): for cname, type in code.funcstate.temps_holding_reference():
code.put_decref_clear(cname, type) code.put_decref_clear(cname, type)
code.put_goto(code.return_label) code.put_goto(code.return_label)
def put_return(self, code, value):
if self.in_parallel:
code.putln_openmp("#pragma omp critical(__pyx_returning)")
code.putln("%s = %s;" % (Naming.retval_cname, value))
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
if self.value is not None: if self.value is not None:
self.value.generate_function_definitions(env, code) self.value.generate_function_definitions(env, code)
...@@ -5812,7 +5816,9 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5812,7 +5816,9 @@ class ParallelStatNode(StatNode, ParallelNode):
assignments to variables in this parallel section assignments to variables in this parallel section
parent parent ParallelStatNode or None parent parent ParallelStatNode or None
is_parallel indicates whether this is a parallel node is_parallel indicates whether this node is OpenMP parallel
(true for #pragma omp parallel for and
#pragma omp parallel)
is_parallel is true for: is_parallel is true for:
...@@ -6062,13 +6068,30 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6062,13 +6068,30 @@ class ParallelStatNode(StatNode, ParallelNode):
""" """
Release any temps used for variables in scope objects. As this is the Release any temps used for variables in scope objects. As this is the
outermost parallel block, we don't need to delete the cnames from outermost parallel block, we don't need to delete the cnames from
self.seen_closure_vars self.seen_closure_vars.
""" """
for entry, original_cname in self.modified_entries: for entry, original_cname in self.modified_entries:
code.putln("%s = %s;" % (original_cname, entry.cname)) code.putln("%s = %s;" % (original_cname, entry.cname))
code.funcstate.release_temp(entry.cname) code.funcstate.release_temp(entry.cname)
entry.cname = original_cname entry.cname = original_cname
def privatize_temps(self, code, exclude_temps=()):
"""
Make any used temporaries private. Before the relevant code block
code.start_collecting_temps() should have been called.
"""
if self.is_parallel:
private_cnames = cython.set([e.cname for e in self.privates])
temps = []
for cname, type in code.funcstate.stop_collecting_temps():
if not type.is_pyobject:
temps.append(cname)
if temps:
c = self.privatization_insertion_point
c.put(" private(%s)" % ", ".join(temps))
def setup_parallel_control_flow_block(self, code): def setup_parallel_control_flow_block(self, code):
""" """
Sets up a block that surrounds the parallel block to determine Sets up a block that surrounds the parallel block to determine
...@@ -6236,12 +6259,12 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6236,12 +6259,12 @@ class ParallelWithBlockNode(ParallelStatNode):
code.begin_block() # parallel block code.begin_block() # parallel block
self.initialize_privates_to_nan(code) self.initialize_privates_to_nan(code)
code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
self.trap_parallel_exit(code) self.trap_parallel_exit(code)
code.end_block() # end parallel block code.end_block() # end parallel block
# After the parallel block all privates are undefined self.privatize_temps(code)
self.initialize_privates_to_nan(code)
continue_ = code.label_used(code.continue_label) continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label) break_ = code.label_used(code.break_label)
...@@ -6500,10 +6523,6 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6500,10 +6523,6 @@ class ParallelRangeNode(ParallelStatNode):
if self.schedule: if self.schedule:
code.put(" schedule(%s)" % self.schedule) code.put(" schedule(%s)" % self.schedule)
if self.parent:
c = self.parent.privatization_insertion_point
c.put(" private(%(nsteps)s)" % fmt_dict)
if self.is_parallel: if self.is_parallel:
self.put_num_threads(code) self.put_num_threads(code)
else: else:
...@@ -6521,8 +6540,15 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6521,8 +6540,15 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict) code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.initialize_privates_to_nan(code, exclude=self.target.entry) self.initialize_privates_to_nan(code, exclude=self.target.entry)
if self.is_parallel:
code.funcstate.start_collecting_temps()
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if self.is_parallel:
self.privatize_temps(code)
self.trap_parallel_exit(code, should_flush=True) self.trap_parallel_exit(code, should_flush=True)
if self.breaking_label_used: if self.breaking_label_used:
# Put a guard around the loop body in case return, break or # Put a guard around the loop body in case return, break or
...@@ -7747,7 +7773,7 @@ static float %(PYX_NAN)s; ...@@ -7747,7 +7773,7 @@ static float %(PYX_NAN)s;
init=""" init="""
/* Initialize NaN. The sign is irrelevant, an exponent with all bits 1 and /* Initialize NaN. The sign is irrelevant, an exponent with all bits 1 and
a nonzero mantissa means NaN. If the first bit in the mantissa is 1, it is a nonzero mantissa means NaN. If the first bit in the mantissa is 1, it is
a signalling NaN. */ a quiet NaN. */
memset(&%(PYX_NAN)s, 0xFF, sizeof(%(PYX_NAN)s)); memset(&%(PYX_NAN)s, 0xFF, sizeof(%(PYX_NAN)s));
""" % vars(Naming)) """ % vars(Naming))
...@@ -119,7 +119,7 @@ class MarkAssignments(CythonTransform): ...@@ -119,7 +119,7 @@ class MarkAssignments(CythonTransform):
base = sequence, base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0'))) index = ExprNodes.IntNode(node.pos, value = '0')))
self._visit_loop_node_children(node) self.visitchildren(node)
return node return node
def visit_ForFromStatNode(self, node): def visit_ForFromStatNode(self, node):
...@@ -130,11 +130,11 @@ class MarkAssignments(CythonTransform): ...@@ -130,11 +130,11 @@ class MarkAssignments(CythonTransform):
'+', '+',
node.bound1, node.bound1,
node.step)) node.step))
self._visit_loop_node_children(node) self.visitchildren(node)
return node return node
def visit_WhileStatNode(self, node): def visit_WhileStatNode(self, node):
self._visit_loop_node_children(node) self.visitchildren(node)
return node return node
def visit_ExceptClauseNode(self, node): def visit_ExceptClauseNode(self, node):
...@@ -208,25 +208,9 @@ class MarkAssignments(CythonTransform): ...@@ -208,25 +208,9 @@ class MarkAssignments(CythonTransform):
return node return node
def _visit_loop_node_children(self, node): def visit_ReturnStatNode(self, node):
""" node.in_parallel = bool(self.parallel_block_stack)
Used for the children of "loop nodes", like ForInStatNode, so subnodes return node
can establish whether break and continue belong to a parallel node
or something else.
"""
child_attrs = node.child_attrs
node.child_attrs = [attr for attr in child_attrs if attr != 'else_clause']
was_in_loop = self.in_loop
self.in_loop = True
self.visitchildren(node)
self.in_loop = was_in_loop
node.child_attrs = child_attrs
if node.else_clause:
node.else_clause = self.visit(node.else_clause)
class MarkOverflowingArithmetic(CythonTransform): class MarkOverflowingArithmetic(CythonTransform):
......
...@@ -237,7 +237,7 @@ def test_nan_init(): ...@@ -237,7 +237,7 @@ def test_nan_init():
with nogil, cython.parallel.parallel(): with nogil, cython.parallel.parallel():
c1 = 16 c1 = 16
assert c1 not in (16, 20), c1 assert c1 == 20, c1
cdef void nogil_print(char *s) with gil: cdef void nogil_print(char *s) with gil:
print s.decode('ascii') print s.decode('ascii')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment