Commit d2ef994b authored by Mark Florisson's avatar Mark Florisson

Issue error for nogil Py->C conversion & support runtime num_threads

parent 27cb4ad5
...@@ -986,11 +986,13 @@ class CCodeWriter(object): ...@@ -986,11 +986,13 @@ class CCodeWriter(object):
if create_from is not None: if create_from is not None:
# Use same global state # Use same global state
self.globalstate = create_from.globalstate self.globalstate = create_from.globalstate
self.funcstate = create_from.funcstate
# Clone formatting state # Clone formatting state
if copy_formatting: if copy_formatting:
self.level = create_from.level self.level = create_from.level
self.bol = create_from.bol self.bol = create_from.bol
self.call_level = create_from.call_level self.call_level = create_from.call_level
if emit_linenums is None and self.globalstate: if emit_linenums is None and self.globalstate:
self.emit_linenums = self.globalstate.emit_linenums self.emit_linenums = self.globalstate.emit_linenums
else: else:
...@@ -1000,7 +1002,8 @@ class CCodeWriter(object): ...@@ -1000,7 +1002,8 @@ class CCodeWriter(object):
def create_new(self, create_from, buffer, copy_formatting): def create_new(self, create_from, buffer, copy_formatting):
# polymorphic constructor -- very slightly more versatile # polymorphic constructor -- very slightly more versatile
# than using __class__ # than using __class__
result = CCodeWriter(create_from, buffer, copy_formatting, c_line_in_traceback=self.c_line_in_traceback) result = CCodeWriter(create_from, buffer, copy_formatting,
c_line_in_traceback=self.c_line_in_traceback)
return result return result
def copyto(self, f): def copyto(self, f):
......
...@@ -7814,6 +7814,9 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -7814,6 +7814,9 @@ class CoerceFromPyTypeNode(CoercionNode):
if self.type.is_pyobject: if self.type.is_pyobject:
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
def nogil_check(self, env):
error(self.pos, "Coercion from Python not allowed without the GIL")
class CoerceToBooleanNode(CoercionNode): class CoerceToBooleanNode(CoercionNode):
# This node is used when a result needs to be used # This node is used when a result needs to be used
......
...@@ -5929,7 +5929,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5929,7 +5929,7 @@ class ParallelStatNode(StatNode, ParallelNode):
construct (replaced by its compile time value) construct (replaced by its compile time value)
""" """
child_attrs = ['body'] child_attrs = ['body', 'num_threads']
body = None body = None
...@@ -5937,6 +5937,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5937,6 +5937,8 @@ class ParallelStatNode(StatNode, ParallelNode):
error_label_used = False error_label_used = False
num_threads = None
parallel_exc = ( parallel_exc = (
Naming.parallel_exc_type, Naming.parallel_exc_type,
Naming.parallel_exc_value, Naming.parallel_exc_value,
...@@ -5977,7 +5979,15 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5977,7 +5979,15 @@ class ParallelStatNode(StatNode, ParallelNode):
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
self.num_threads = None
if self.kwargs: if self.kwargs:
for idx, dictitem in enumerate(self.kwargs.key_value_pairs[:]):
if dictitem.key.value == 'num_threads':
self.num_threads = dictitem.value
del self.kwargs.key_value_pairs[idx]
break
try: try:
self.kwargs = self.kwargs.compile_time_value(env) self.kwargs = self.kwargs.compile_time_value(env)
except Exception, e: except Exception, e:
...@@ -5993,6 +6003,8 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5993,6 +6003,8 @@ class ParallelStatNode(StatNode, ParallelNode):
setattr(self, kw, val) setattr(self, kw, val)
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.num_threads:
self.num_threads.analyse_expressions(env)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env) self.analyse_sharing_attributes(env)
...@@ -6000,13 +6012,18 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6000,13 +6012,18 @@ class ParallelStatNode(StatNode, ParallelNode):
if self.parent and self.parent.num_threads is not None: if self.parent and self.parent.num_threads is not None:
error(self.pos, error(self.pos,
"num_threads already declared in outer section") "num_threads already declared in outer section")
elif not isinstance(self.num_threads, (int, long)): elif self.parent:
error(self.pos, error(self.pos,
"Invalid value for num_threads argument, expected an int") "num_threads must be declared in the parent parallel section")
elif self.num_threads <= 0: elif (self.num_threads.type.is_int and
self.num_threads.is_literal and
self.num_threads.compile_time_value(env) <= 0):
error(self.pos, error(self.pos,
"argument to num_threads must be greater than 0") "argument to num_threads must be greater than 0")
self.num_threads = self.num_threads.coerce_to(
PyrexTypes.c_int_type, env).coerce_to_temp(env)
def analyse_sharing_attributes(self, env): def analyse_sharing_attributes(self, env):
""" """
Analyse the privates for this block and set them in self.privates. Analyse the privates for this block and set them in self.privates.
...@@ -6146,7 +6163,16 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -6146,7 +6163,16 @@ class ParallelStatNode(StatNode, ParallelNode):
Write self.num_threads if set as the num_threads OpenMP directive Write self.num_threads if set as the num_threads OpenMP directive
""" """
if self.num_threads is not None: if self.num_threads is not None:
code.put(" num_threads(%d)" % (self.num_threads,)) c = self.begin_of_parallel_control_block_point
# we need to set the owner to ourselves temporarily, as
# allocate_temp may generate a comment in the middle of our pragma
# otherwise when DebugFlags.debug_temp_code_comments is in effect
owner = c.funcstate.owner
c.funcstate.owner = c
self.num_threads.generate_evaluation_code(c)
c.funcstate.owner = owner
code.put(" num_threads(%s)" % (self.num_threads.result(),))
def declare_closure_privates(self, code): def declare_closure_privates(self, code):
...@@ -6511,8 +6537,6 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6511,8 +6537,6 @@ class ParallelWithBlockNode(ParallelStatNode):
This node represents a 'with cython.parallel.parallel():' block This node represents a 'with cython.parallel.parallel():' block
""" """
nogil_check = None
valid_keyword_arguments = ['num_threads'] valid_keyword_arguments = ['num_threads']
num_threads = None num_threads = None
......
...@@ -19,8 +19,17 @@ with nogil, parallel(num_threads=2): ...@@ -19,8 +19,17 @@ with nogil, parallel(num_threads=2):
for i in prange(10, num_threads=2): for i in prange(10, num_threads=2):
pass pass
with nogil, parallel():
for i in prange(10, num_threads=2):
pass
# this one is valid
for i in prange(10, nogil=True, num_threads=2):
pass
_ERRORS = u""" _ERRORS = u"""
e_invalid_num_threads.pyx:8:33: Coercion from Python not allowed without the GIL
e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0 e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0
e_invalid_num_threads.pyx:15:20: Invalid value for num_threads argument, expected an int
e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section
e_invalid_num_threads.pyx:23:19: num_threads must be declared in the parent parallel section
""" """
...@@ -26,10 +26,18 @@ def test_parallel(): ...@@ -26,10 +26,18 @@ def test_parallel():
free(buf) free(buf)
cdef int get_num_threads() with gil:
print "get_num_threads called"
return 3
def test_num_threads(): def test_num_threads():
""" """
>>> test_num_threads() >>> test_num_threads()
1 1
get_num_threads called
3
get_num_threads called
3
""" """
cdef int dyn = openmp.omp_get_dynamic() cdef int dyn = openmp.omp_get_dynamic()
cdef int num_threads cdef int num_threads
...@@ -40,6 +48,19 @@ def test_num_threads(): ...@@ -40,6 +48,19 @@ def test_num_threads():
with nogil, cython.parallel.parallel(num_threads=1): with nogil, cython.parallel.parallel(num_threads=1):
p[0] = openmp.omp_get_num_threads() p[0] = openmp.omp_get_num_threads()
print num_threads
with nogil, cython.parallel.parallel(num_threads=get_num_threads()):
p[0] = openmp.omp_get_num_threads()
print num_threads
cdef int i
num_threads = 0xbad
for i in prange(1, nogil=True, num_threads=get_num_threads()):
p[0] = openmp.omp_get_num_threads()
break
openmp.omp_set_dynamic(dyn) openmp.omp_set_dynamic(dyn)
return num_threads return num_threads
......
...@@ -729,6 +729,6 @@ def test_num_threads_compile(): ...@@ -729,6 +729,6 @@ def test_num_threads_compile():
with nogil, cython.parallel.parallel(num_threads=2): with nogil, cython.parallel.parallel(num_threads=2):
pass pass
with nogil, cython.parallel.parallel(): with nogil, cython.parallel.parallel(num_threads=2):
for i in prange(10, num_threads=2): for i in prange(10):
pass pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment