Commit b4f0ba7c authored by Mark Florisson's avatar Mark Florisson

Fix num_thread for prange() without parallel() + more error checks

parent 37e1e52b
...@@ -5978,7 +5978,11 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5978,7 +5978,11 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
if self.kwargs: if self.kwargs:
self.kwargs = self.kwargs.compile_time_value(env) try:
self.kwargs = self.kwargs.compile_time_value(env)
except Exception, e:
error(self.kwargs.pos, "Only compile-time values may be "
"supplied as keyword arguments")
else: else:
self.kwargs = {} self.kwargs = {}
...@@ -5992,6 +5996,17 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5992,6 +5996,17 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env) self.analyse_sharing_attributes(env)
if self.num_threads is not None:
if self.parent and self.parent.num_threads is not None:
error(self.pos,
"num_threads already declared in outer section")
elif not isinstance(self.num_threads, (int, long)):
error(self.pos,
"Invalid value for num_threads argument, expected an int")
elif self.num_threads <= 0:
error(self.pos,
"argument to num_threads must be greater than 0")
def analyse_sharing_attributes(self, env): def analyse_sharing_attributes(self, env):
""" """
Analyse the privates for this block and set them in self.privates. Analyse the privates for this block and set them in self.privates.
...@@ -6131,11 +6146,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6131,11 +6146,8 @@ class ParallelStatNode(StatNode, ParallelNode):
Write self.num_threads if set as the num_threads OpenMP directive Write self.num_threads if set as the num_threads OpenMP directive
""" """
if self.num_threads is not None: if self.num_threads is not None:
if isinstance(self.num_threads, (int, long)): code.put(" num_threads(%d)" % (self.num_threads,))
code.put(" num_threads(%d)" % (self.num_threads,))
else:
error(self.pos, "Invalid value for num_threads argument, "
"expected an int")
def declare_closure_privates(self, code): def declare_closure_privates(self, code):
""" """
...@@ -6790,11 +6802,11 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6790,11 +6802,11 @@ class ParallelRangeNode(ParallelStatNode):
if not self.is_parallel: if not self.is_parallel:
code.put("#pragma omp for") code.put("#pragma omp for")
self.privatization_insertion_point = code.insertion_point() self.privatization_insertion_point = code.insertion_point()
# reduction_codepoint = self.parent.privatization_insertion_point reduction_codepoint = self.parent.privatization_insertion_point
else: else:
code.put("#pragma omp parallel") code.put("#pragma omp parallel")
self.privatization_insertion_point = code.insertion_point() self.privatization_insertion_point = code.insertion_point()
# reduction_codepoint = self.privatization_insertion_point reduction_codepoint = self.privatization_insertion_point
code.putln("") code.putln("")
code.putln("#endif /* _OPENMP */") code.putln("#endif /* _OPENMP */")
...@@ -6806,11 +6818,6 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6806,11 +6818,6 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#ifdef _OPENMP") code.putln("#ifdef _OPENMP")
code.put("#pragma omp for") code.put("#pragma omp for")
# Nested parallelism is not supported, so we can put reductions on the
# for and not on the parallel (but would be valid, but gcc45 bugs on
# the former)
reduction_codepoint = code
for entry, (op, lastprivate) in self.privates.iteritems(): for entry, (op, lastprivate) in self.privates.iteritems():
# Don't declare the index variable as a reduction # Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry: if op and op in "+*-&^|" and entry != self.target.entry:
......
# mode: error
from cython.parallel cimport parallel, prange
cdef int i
# valid
with nogil, parallel(num_threads=None):
pass
# invalid
with nogil, parallel(num_threads=0):
pass
with nogil, parallel(num_threads=i):
pass
with nogil, parallel(num_threads=2):
for i in prange(10, num_threads=2):
pass
_ERRORS = u"""
e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0
e_invalid_num_threads.pyx:15:20: Invalid value for num_threads argument, expected an int
e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section
"""
...@@ -720,3 +720,15 @@ def test_nogil_cdef_except_clause(): ...@@ -720,3 +720,15 @@ def test_nogil_cdef_except_clause():
for i in prange(10, nogil=True): for i in prange(10, nogil=True):
nogil_cdef_except_clause() nogil_cdef_except_clause()
nogil_cdef_except_star() nogil_cdef_except_star()
def test_num_threads_compile():
cdef int i
for i in prange(10, nogil=True, num_threads=2):
pass
with nogil, cython.parallel.parallel(num_threads=2):
pass
with nogil, cython.parallel.parallel():
for i in prange(10, num_threads=2):
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment