Commit 4657fff3 authored by Mark Florisson's avatar Mark Florisson

Initialize threads when cython.parallel is imported

parent 468c17af
...@@ -6224,7 +6224,8 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6224,7 +6224,8 @@ class ParallelWithBlockNode(ParallelStatNode):
code.put("#pragma omp parallel ") code.put("#pragma omp parallel ")
if self.privates: if self.privates:
privates = [e.cname for e in self.privates] privates = [e.cname for e in self.privates
if not e.type.is_pyobject]
code.put('private(%s)' % ', '.join(privates)) code.put('private(%s)' % ', '.join(privates))
self.privatization_insertion_point = code.insertion_point() self.privatization_insertion_point = code.insertion_point()
...@@ -6239,6 +6240,9 @@ class ParallelWithBlockNode(ParallelStatNode): ...@@ -6239,6 +6240,9 @@ class ParallelWithBlockNode(ParallelStatNode):
self.trap_parallel_exit(code) self.trap_parallel_exit(code)
code.end_block() # end parallel block code.end_block() # end parallel block
# After the parallel block all privates are undefined
self.initialize_privates_to_nan(code)
continue_ = code.label_used(code.continue_label) continue_ = code.label_used(code.continue_label)
break_ = code.label_used(code.break_label) break_ = code.label_used(code.break_label)
...@@ -6489,6 +6493,8 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6489,6 +6493,8 @@ class ParallelRangeNode(ParallelStatNode):
else: else:
if entry == self.target.entry: if entry == self.target.entry:
code.put(" firstprivate(%s)" % entry.cname) code.put(" firstprivate(%s)" % entry.cname)
if not entry.type.is_pyobject:
code.put(" lastprivate(%s)" % entry.cname) code.put(" lastprivate(%s)" % entry.cname)
if self.schedule: if self.schedule:
...@@ -7576,6 +7582,10 @@ proto=""" ...@@ -7576,6 +7582,10 @@ proto="""
#endif #endif
""") """)
init_threads = UtilityCode(
init="PyEval_InitThreads();\n",
)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# Note that cPython ignores PyTrace_EXCEPTION, # Note that cPython ignores PyTrace_EXCEPTION,
......
...@@ -650,6 +650,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -650,6 +650,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
self.wrong_scope_error(node.pos, key, 'module') self.wrong_scope_error(node.pos, key, 'module')
del node.directive_comments[key] del node.directive_comments[key]
self.module_scope = node.scope
directives = copy.deepcopy(Options.directive_defaults) directives = copy.deepcopy(Options.directive_defaults)
directives.update(copy.deepcopy(self.compilation_directive_defaults)) directives.update(copy.deepcopy(self.compilation_directive_defaults))
directives.update(node.directive_comments) directives.update(node.directive_comments)
...@@ -684,6 +686,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -684,6 +686,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directive[-1] not in self.valid_parallel_directives): directive[-1] not in self.valid_parallel_directives):
error(pos, "No such directive: %s" % full_name) error(pos, "No such directive: %s" % full_name)
self.module_scope.use_utility_code(Nodes.init_threads)
return result return result
def visit_CImportStatNode(self, node): def visit_CImportStatNode(self, node):
...@@ -699,6 +703,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -699,6 +703,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
self.cython_module_names.add(u"cython") self.cython_module_names.add(u"cython")
self.parallel_directives[ self.parallel_directives[
u"cython.parallel"] = node.module_name u"cython.parallel"] = node.module_name
self.module_scope.use_utility_code(Nodes.init_threads)
elif node.as_name: elif node.as_name:
self.directive_names[node.as_name] = node.module_name[7:] self.directive_names[node.as_name] = node.module_name[7:]
else: else:
......
...@@ -387,9 +387,6 @@ class PyObjectType(PyrexType): ...@@ -387,9 +387,6 @@ class PyObjectType(PyrexType):
else: else:
return cname return cname
def invalid_value(self):
return "1"
class BuiltinObjectType(PyObjectType): class BuiltinObjectType(PyObjectType):
# objstruct_cname string Name of PyObject struct # objstruct_cname string Name of PyObject struct
......
...@@ -233,8 +233,14 @@ def test_nan_init(): ...@@ -233,8 +233,14 @@ def test_nan_init():
raise Exception("One of the values was not initialized to a maximum " raise Exception("One of the values was not initialized to a maximum "
"or NaN value") "or NaN value")
c1 = 20
with nogil, cython.parallel.parallel():
c1 = 16
assert c1 not in (16, 20), c1
cdef void nogil_print(char *s) with gil: cdef void nogil_print(char *s) with gil:
print s print s.decode('ascii')
def test_else_clause(): def test_else_clause():
""" """
...@@ -343,3 +349,64 @@ def test_return(): ...@@ -343,3 +349,64 @@ def test_return():
""" """
print parallel_return() print parallel_return()
def test_parallel_exceptions():
"""
>>> test_parallel_exceptions()
('I am executed first', 0)
('propagate me',) 0
"""
cdef int i, j, sum = 0
mylist = []
try:
for i in prange(10, nogil=True):
try:
for j in prange(10):
with gil:
raise Exception("propagate me")
sum += i * j
sum += i
finally:
with gil:
mylist.append(("I am executed first", sum))
except Exception, e:
print mylist[0]
print e.args, sum
def test_parallel_with_gil_return():
"""
>>> test_parallel_with_gil_return()
True
45
"""
cdef int i, sum = 0
for i in prange(10, nogil=True):
with gil:
obj = i
sum += obj
print obj in range(10)
with nogil, cython.parallel.parallel():
with gil:
return sum
def test_parallel_with_gil_continue():
"""
>>> test_parallel_with_gil_continue()
20
"""
cdef int i, sum = 0
for i in prange(10, nogil=True):
with cython.parallel.parallel():
with gil:
if i % 2:
continue
sum += i
print sum
...@@ -202,7 +202,7 @@ def test_loops_and_boxing(): ...@@ -202,7 +202,7 @@ def test_loops_and_boxing():
with nogil: with nogil:
with gil: with gil:
print string.decode('ASCII') print string.decode('ascii')
for c in string[4:]: for c in string[4:]:
print "%c" % c print "%c" % c
else: else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment