Commit f0cfc682 authored by Mark Florisson's avatar Mark Florisson

Clean up memoryview slice temporaries in parallel sections

parent 48188cde
......@@ -7229,10 +7229,10 @@ class ParallelStatNode(StatNode, ParallelNode):
if self.is_parallel:
c = self.privatization_insertion_point
temps = code.funcstate.stop_collecting_temps()
self.temps = temps = code.funcstate.stop_collecting_temps()
privates, firstprivates = [], []
for temp, type in temps:
if type.is_pyobject:
if type.is_pyobject or type.is_memoryviewslice:
firstprivates.append(temp)
else:
privates.append(temp)
......@@ -7250,6 +7250,16 @@ class ParallelStatNode(StatNode, ParallelNode):
c.put(" shared(%s)" % ', '.join(shared_vars))
def cleanup_slice_temps(self, code):
# Now clean up any memoryview slice temporaries
first = True
for temp, type in self.temps:
if type.is_memoryviewslice:
if first:
first = False
code.putln("/* Clean up any temporary slices */")
code.put_xdecref_memoryviewslice(temp, have_gil=False)
def setup_parallel_control_flow_block(self, code):
"""
Sets up a block that surrounds the parallel block to determine
......@@ -7311,6 +7321,8 @@ class ParallelStatNode(StatNode, ParallelNode):
end_code.put_release_ensured_gil()
end_code.putln("#endif /* _OPENMP */")
self.cleanup_slice_temps(code)
def trap_parallel_exit(self, code, should_flush=False):
"""
Trap any kind of return inside a parallel construct. 'should_flush'
......
......@@ -6,7 +6,7 @@ from __future__ import unicode_literals
cimport cython
from cython cimport view
from cython.parallel cimport prange
from cython.parallel cimport prange, parallel
import sys
import re
......@@ -1562,6 +1562,48 @@ def test_memslice_prange(arg):
for k in range(src.shape[2]):
assert src[i, j, k] == dst[i, j, k], (src[i, j, k] == dst[i, j, k])
@testcase
def test_clean_temps_prange(int[:, :] buf):
"""
Try to access a buffer out of bounds in a parallel section, and make sure any
temps used by the slicing processes are correctly counted.
>>> A = IntMockBuffer("A", range(100), (10, 10))
>>> test_clean_temps_prange(A)
acquired A
released A
"""
cdef int i
try:
for i in prange(buf.shape[0], nogil=True):
buf[1:10, 20] = 0
except IndexError:
pass
@testcase
def test_clean_temps_parallel(int[:, :] buf):
"""
Try to access a buffer out of bounds in a parallel section, and make sure any
temps used by the slicing processes are correctly counted.
>>> A = IntMockBuffer("A", range(100), (10, 10))
>>> test_clean_temps_parallel(A)
acquired A
released A
"""
cdef int i
try:
with nogil, parallel():
try:
with gil: pass
for i in prange(buf.shape[0]):
buf[1:10, 20] = 0
finally:
buf[1:10, 20] = 0
except IndexError:
pass
# Test arrays in structs
cdef struct ArrayStruct:
int ints[10]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment