Commit cf2290aa authored by Mark Florisson's avatar Mark Florisson

Mandate calling of cython.parallel.parallel

parent 9e3798fd
...@@ -5736,7 +5736,7 @@ class ParallelNode(Node): ...@@ -5736,7 +5736,7 @@ class ParallelNode(Node):
class ParallelStatNode(StatNode, ParallelNode): class ParallelStatNode(StatNode, ParallelNode):
""" """
Base class for 'with cython.parallel.parallel:' and 'for i in prange():'. Base class for 'with cython.parallel.parallel():' and 'for i in prange():'.
assignments { Entry(var) : (var.pos, inplace_operator_or_None) } assignments { Entry(var) : (var.pos, inplace_operator_or_None) }
assignments to variables in this parallel section assignments to variables in this parallel section
...@@ -5920,7 +5920,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5920,7 +5920,7 @@ class ParallelStatNode(StatNode, ParallelNode):
class ParallelWithBlockNode(ParallelStatNode): class ParallelWithBlockNode(ParallelStatNode):
""" """
This node represents a 'with cython.parallel.parallel:' block This node represents a 'with cython.parallel.parallel():' block
""" """
nogil_check = None nogil_check = None
......
...@@ -963,7 +963,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -963,7 +963,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
module node, set there by InterpretCompilerDirectives. module node, set there by InterpretCompilerDirectives.
x = cython.parallel.threadavailable() -> ParallelThreadAvailableNode x = cython.parallel.threadavailable() -> ParallelThreadAvailableNode
with nogil, cython.parallel.parallel: -> ParallelWithBlockNode with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
print cython.parallel.threadid() -> ParallelThreadIdNode print cython.parallel.threadid() -> ParallelThreadIdNode
for i in cython.parallel.prange(...): -> ParallelRangeNode for i in cython.parallel.prange(...): -> ParallelRangeNode
... ...
...@@ -1064,36 +1064,40 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1064,36 +1064,40 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
parallel_directive_class = self.get_directive_class_node(node) parallel_directive_class = self.get_directive_class_node(node)
if parallel_directive_class: if parallel_directive_class:
# Note: in case of a parallel() the body is set by
# visit_WithStatNode
node = parallel_directive_class(node.pos, args=args, kwargs=kwargs) node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)
return node return node
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
"Rewrite with cython.parallel() blocks" "Rewrite with cython.parallel.parallel() blocks"
self.visit(node.manager) newnode = self.visit(node.manager)
if self.parallel_directive:
parallel_directive_class = self.get_directive_class_node(node)
if not parallel_directive_class:
# There was an error, stop here and now
return None
if isinstance(newnode, Nodes.ParallelWithBlockNode):
if self.state == 'parallel with': if self.state == 'parallel with':
error(node.manager.pos, error(node.manager.pos,
"Closely nested 'with parallel:' blocks are disallowed") "Closely nested 'with parallel:' blocks are disallowed")
self.state = 'parallel with' self.state = 'parallel with'
self.visit(node.body) body = self.visit(node.body)
self.state = None self.state = None
newnode = Nodes.ParallelWithBlockNode(node.pos, body=node.body) newnode.body = body
return newnode
elif self.parallel_directive:
parallel_directive_class = self.get_directive_class_node(node)
else: if not parallel_directive_class:
newnode = node # There was an error, stop here and now
return None
self.visit(node.body) if parallel_directive_class is Nodes.ParallelWithBlockNode:
error(node.pos, "The parallel directive must be called")
return None
return newnode node.body = self.visit(node.body)
return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
"Rewrite 'for i in cython.parallel.prange(...):'" "Rewrite 'for i in cython.parallel.prange(...):'"
...@@ -1149,7 +1153,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1149,7 +1153,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
def visit(self, node): def visit(self, node):
"Visit a node that may be None" "Visit a node that may be None"
if node is not None: if node is not None:
super(ParallelRangeTransform, self).visit(node) return super(ParallelRangeTransform, self).visit(node)
class WithTransform(CythonTransform, SkipDeclarations): class WithTransform(CythonTransform, SkipDeclarations):
......
...@@ -14,16 +14,16 @@ prange(1, 2, 3, schedule='dynamic') ...@@ -14,16 +14,16 @@ prange(1, 2, 3, schedule='dynamic')
cdef int i cdef int i
with nogil, cython.parallel.parallel: with nogil, cython.parallel.parallel():
for i in prange(10, schedule='invalid_schedule'): for i in prange(10, schedule='invalid_schedule'):
pass pass
with cython.parallel.parallel: with cython.parallel.parallel():
print "hello world!" print "hello world!"
cdef int *x = NULL cdef int *x = NULL
with nogil, cython.parallel.parallel: with nogil, cython.parallel.parallel():
for j in prange(10): for j in prange(10):
pass pass
...@@ -33,9 +33,12 @@ with nogil, cython.parallel.parallel: ...@@ -33,9 +33,12 @@ with nogil, cython.parallel.parallel:
for x in prange(10): for x in prange(10):
pass pass
with cython.parallel.parallel: with cython.parallel.parallel():
pass pass
with nogil, cython.parallel.parallel:
pass
_ERRORS = u""" _ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
...@@ -44,9 +47,10 @@ e_cython_parallel.pyx:7:0: No such directive: cython.parallel.something ...@@ -44,9 +47,10 @@ e_cython_parallel.pyx:7:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:13:6: prange() can only be used as part of a for loop e_cython_parallel.pyx:13:6: prange() can only be used as part of a for loop
e_cython_parallel.pyx:13:6: prange() can only be used without the GIL e_cython_parallel.pyx:13:6: prange() can only be used without the GIL
e_cython_parallel.pyx:18:19: Invalid schedule argument to prange: invalid_schedule e_cython_parallel.pyx:18:19: Invalid schedule argument to prange: invalid_schedule
e_cython_parallel.pyx:21:5: The parallel section may only be used without the GIL c_cython_parallel.pyx:21:29: The parallel section may only be used without the GIL
e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL
e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable
e_cython_parallel.pyx:33:10: Must be of numeric type, not int * e_cython_parallel.pyx:33:10: Must be of numeric type, not int *
e_cython_parallel.pyx:36:24: Closely nested 'with parallel:' blocks are disallowed e_cython_parallel.pyx:36:33: Closely nested 'with parallel:' blocks are disallowed
e_cython_parallel.pyx:39:12: The parallel directive must be called
""" """
...@@ -16,7 +16,7 @@ def test_parallel(): ...@@ -16,7 +16,7 @@ def test_parallel():
if buf == NULL: if buf == NULL:
raise MemoryError raise MemoryError
with nogil, cython.parallel.parallel: with nogil, cython.parallel.parallel():
buf[threadid()] = threadid() buf[threadid()] = threadid()
for i in range(maxthreads): for i in range(maxthreads):
...@@ -24,4 +24,4 @@ def test_parallel(): ...@@ -24,4 +24,4 @@ def test_parallel():
free(buf) free(buf)
include "sequential_parallel.pyx" #include "sequential_parallel.pyx"
...@@ -23,7 +23,7 @@ def test_prange(): ...@@ -23,7 +23,7 @@ def test_prange():
""" """
cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0 cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0
with nogil, cython.parallel.parallel: with nogil, cython.parallel.parallel():
for i in prange(10, schedule='dynamic'): for i in prange(10, schedule='dynamic'):
sum1 += i sum1 += i
...@@ -57,9 +57,9 @@ def test_propagation(): ...@@ -57,9 +57,9 @@ def test_propagation():
for j in prange(10): for j in prange(10):
sum1 += i sum1 += i
with nogil, cython.parallel.parallel: with nogil, cython.parallel.parallel():
for x in prange(10): for x in prange(10):
with cython.parallel.parallel: with cython.parallel.parallel():
for y in prange(10): for y in prange(10):
sum2 += y sum2 += y
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment