Commit 27ffe2c8 authored by da-woods's avatar da-woods Committed by Stefan Behnel

Run ParallelRangeTransform also recursively on function arguments (GH-3608)

Closes https://github.com/cython/cython/issues/3594
parent 3b1b45de
...@@ -1161,6 +1161,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1161,6 +1161,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
def visit_CallNode(self, node): def visit_CallNode(self, node):
self.visit(node.function) self.visit(node.function)
if not self.parallel_directive: if not self.parallel_directive:
self.visitchildren(node, exclude=('function',))
return node return node
# We are a parallel directive, replace this node with the # We are a parallel directive, replace this node with the
......
...@@ -8,6 +8,9 @@ from libc.stdlib cimport malloc, free ...@@ -8,6 +8,9 @@ from libc.stdlib cimport malloc, free
openmp.omp_set_nested(1) openmp.omp_set_nested(1)
cdef int forward(int x) nogil:
return x
def test_parallel(): def test_parallel():
""" """
>>> test_parallel() >>> test_parallel()
...@@ -20,6 +23,9 @@ def test_parallel(): ...@@ -20,6 +23,9 @@ def test_parallel():
with nogil, cython.parallel.parallel(): with nogil, cython.parallel.parallel():
buf[threadid()] = threadid() buf[threadid()] = threadid()
# Recognise threadid() also when it's used in a function argument.
# See https://github.com/cython/cython/issues/3594
buf[forward(cython.parallel.threadid())] = forward(threadid())
for i in range(maxthreads): for i in range(maxthreads):
assert buf[i] == i assert buf[i] == i
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment