Commit 0994de57 authored by Mark Florisson's avatar Mark Florisson

Add documentation, more tests, pure mode, fixed some bugs

parent bf8594f8
...@@ -1447,7 +1447,7 @@ class CCodeWriter(object): ...@@ -1447,7 +1447,7 @@ class CCodeWriter(object):
def putln_openmp(self, string): def putln_openmp(self, string):
self.putln("#ifdef _OPENMP") self.putln("#ifdef _OPENMP")
self.putln(string) self.putln(string)
self.putln("#endif") self.putln("#endif /* _OPENMP */")
class PyrexCodeWriter(object): class PyrexCodeWriter(object):
# f file output file # f file output file
......
...@@ -2076,7 +2076,7 @@ class ParallelThreadsAvailableNode(AtomicExprNode): ...@@ -2076,7 +2076,7 @@ class ParallelThreadsAvailableNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = True self.is_temp = True
env.add_include_file("omp.h") # env.add_include_file("omp.h")
return self.type return self.type
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -2101,7 +2101,7 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode): ...@@ -2101,7 +2101,7 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
def analyse_types(self, env): def analyse_types(self, env):
self.is_temp = True self.is_temp = True
env.add_include_file("omp.h") # env.add_include_file("omp.h")
return self.type return self.type
def generate_result_code(self, code): def generate_result_code(self, code):
......
...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.putln('#include "%s"' % byte_decoded_filenname) code.putln('#include "%s"' % byte_decoded_filenname)
code.putln_openmp("#include <omp.h>")
def generate_filename_table(self, code): def generate_filename_table(self, code):
code.putln("") code.putln("")
code.putln("static const char *%s[] = {" % Naming.filetable_cname) code.putln("static const char *%s[] = {" % Naming.filetable_cname)
......
...@@ -5842,7 +5842,7 @@ class ParallelStatNode(StatNode, ParallelNode): ...@@ -5842,7 +5842,7 @@ class ParallelStatNode(StatNode, ParallelNode):
class ParallelWithBlockNode(ParallelStatNode): class ParallelWithBlockNode(ParallelStatNode):
""" """
This node represents a 'with cython.parallel:' block This node represents a 'with cython.parallel.parallel:' block
""" """
nogil_check = None nogil_check = None
...@@ -5883,6 +5883,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -5883,6 +5883,7 @@ class ParallelRangeNode(ParallelStatNode):
start = stop = step = None start = stop = step = None
is_prange = True is_prange = True
is_nogil = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
super(ParallelRangeNode, self).analyse_declarations(env) super(ParallelRangeNode, self).analyse_declarations(env)
...@@ -5918,13 +5919,16 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -5918,13 +5919,16 @@ class ParallelRangeNode(ParallelStatNode):
error(self.pos, "Invalid keyword argument to prange: %s" % kw) error(self.pos, "Invalid keyword argument to prange: %s" % kw)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.target.analyse_target_types(env) if self.target is None:
error(self.pos, "prange() can only be used as part of a for loop")
return
self.target.analyse_target_types(env)
self.index_type = self.target.type self.index_type = self.target.type
if self.index_type.is_pyobject: if self.index_type.is_pyobject:
# nogil_check will catch this # nogil_check will catch this, for now, assume a valid type
return self.index_type = PyrexTypes.c_py_ssize_t_type
# Setup start, stop and step, allocating temps if needed # Setup start, stop and step, allocating temps if needed
self.names = 'start', 'stop', 'step' self.names = 'start', 'stop', 'step'
...@@ -5953,7 +5957,7 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -5953,7 +5957,7 @@ class ParallelRangeNode(ParallelStatNode):
def nogil_check(self, env): def nogil_check(self, env):
names = 'start', 'stop', 'step', 'target' names = 'start', 'stop', 'step', 'target'
nodes = self.start, self.stop, self.step, self.target nodes = self.start, self.stop, self.step, self.target
for name, node in zip(names, nodes): for name, node in zip(names, nodes):
if node is not None and node.type.is_pyobject: if node is not None and node.type.is_pyobject:
error(node.pos, "%s may not be a Python object " error(node.pos, "%s may not be a Python object "
...@@ -6039,9 +6043,8 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6039,9 +6043,8 @@ class ParallelRangeNode(ParallelStatNode):
"(%(start)s > %(stop)s && %(step)s < 0) ) " % fmt_dict) "(%(start)s > %(stop)s && %(step)s < 0) ) " % fmt_dict)
code.begin_block() code.begin_block()
# code.putln_openmp("#pragma omp single") code.putln_openmp("#pragma omp critical")
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict) code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
# code.putln_openmp("#pragma omp barrier")
self.generate_loop(code, fmt_dict) self.generate_loop(code, fmt_dict)
...@@ -6081,7 +6084,10 @@ class ParallelRangeNode(ParallelStatNode): ...@@ -6081,7 +6084,10 @@ class ParallelRangeNode(ParallelStatNode):
if self.schedule: if self.schedule:
code.put(" schedule(%s)" % self.schedule) code.put(" schedule(%s)" % self.schedule)
code.putln(" lastprivate(%s)" % target_index_cname) if self.is_parallel or self.target.entry not in self.parent.assignments:
code.putln(" lastprivate(%s)" % target_index_cname)
else:
code.putln("")
code.putln("#endif") code.putln("#endif")
......
...@@ -669,8 +669,12 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -669,8 +669,12 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
if result: if result:
directive = full_name.rsplit('.', 1) directive = full_name.rsplit('.', 1)
if (len(directive) != 2 or directive[1] not in if len(directive) == 2 and directive[1] == '*':
self.valid_parallel_directives): # star import
for name in self.valid_parallel_directives:
self.parallel_directives[name] = u"cython.parallel.%s" % name
elif (len(directive) != 2 or
directive[1] not in self.valid_parallel_directives):
error(pos, "No such directive: %s" % full_name) error(pos, "No such directive: %s" % full_name)
return result return result
...@@ -682,7 +686,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -682,7 +686,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
if node.module_name.startswith(u"cython.parallel."): if node.module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module") error(node.pos, node.module_name + " is not a module")
if node.module_name == u"cython.parallel": if node.module_name == u"cython.parallel":
if node.as_name: if node.as_name and node.as_name != u"cython":
self.parallel_directives[node.as_name] = node.module_name self.parallel_directives[node.as_name] = node.module_name
else: else:
self.cython_module_names.add(u"cython") self.cython_module_names.add(u"cython")
...@@ -748,28 +752,23 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -748,28 +752,23 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return node return node
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
if (isinstance(node.rhs, ExprNodes.ImportNode) and if isinstance(node.rhs, ExprNodes.ImportNode):
node.rhs.module_name.value in (u'cython', u"cython.parallel")): module_name = node.rhs.module_name.value
is_parallel = (module_name + u".").startswith(u"cython.parallel.")
if module_name != u"cython" and not is_parallel:
return node
module_name = node.rhs.module_name.value module_name = node.rhs.module_name.value
as_name = node.lhs.name as_name = node.lhs.name
if module_name == u"cython.parallel" and as_name == u"cython":
# Be consistent with the cimport variant
as_name = u"cython.parallel"
node = Nodes.CImportStatNode(node.pos, node = Nodes.CImportStatNode(node.pos,
module_name = module_name, module_name = module_name,
as_name = as_name) as_name = as_name)
node = self.visit_CImportStatNode(node)
self.visit_CImportStatNode(node)
if node.module_name == u"cython.parallel":
# This is an import for a fake module, remove it
return None
if node.module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
else: else:
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -1041,7 +1040,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1041,7 +1040,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
return node return node
def visit_CallNode(self, node): def visit_CallNode(self, node):
self.visitchildren(node) self.visit(node.function)
if not self.parallel_directive: if not self.parallel_directive:
return node return node
......
...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types: ...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types:
void = typedef(None) void = typedef(None)
NULL = p_void(0) NULL = p_void(0)
class CythonDotParallel(object):
"""
The cython.parallel module.
"""
__all__ = ['parallel', 'prange', 'threadid']
parallel = nogil
def prange(self, start=0, stop=None, step=1, schedule=None, nogil=False):
if stop is None:
stop = start
start = 0
return range(start, stop, step)
def threadid(self):
return 0
# def threadsavailable(self):
# return 1
import sys
sys.modules['cython.parallel'] = CythonDotParallel()
del sys
\ No newline at end of file
...@@ -18,6 +18,7 @@ Contents: ...@@ -18,6 +18,7 @@ Contents:
limitations limitations
pyrex_differences pyrex_differences
early_binding_for_speed early_binding_for_speed
parallelism
debugging debugging
Indices and tables Indices and tables
......
.. highlight:: cython
.. py:module:: cython.parallel
**********************************
Using Parallelism
**********************************
Cython supports native parallelism through the :py:mod:`cython.parallel`
module. To use this kind of parallelism, the GIL must be released. It
currently supports OpenMP, but later on more backends might be supported.
.. function:: prange([start,] stop[, step], nogil=False, schedule=None)
This function can be used for parallel loops. OpenMP automatically
starts a thread pool and distributes the work according to the schedule
used. ``step`` must not be 0. This function can only be used with the
GIL released. If ``nogil`` is true, the loop will be wrapped in a nogil
section.
Thread-locality and reductions are automatically inferred for variables.
If you assign to a variable, it becomes lastprivate, meaning that the
variable will contain the value from the last iteration. If you use an
inplace operator on a variable, it becomes a reduction, meaning that the
values from the thread-local copies of the variable will be reduced with
the operator and assigned to the original variable after the loop. The
index variable is always lastprivate.
The ``schedule`` is passed to OpenMP and can be one of the following:
+-----------------+------------------------------------------------------+
| Schedule | Description |
+=================+======================================================+
|static | The iteration space is divided into chunks that are |
| | approximately equal in size, and at most one chunk |
| | is distributed to each thread. |
+-----------------+------------------------------------------------------+
|dynamic | The iterations are distributed to threads in the team|
| | as the threads request them, with a chunk size of 1. |
+-----------------+------------------------------------------------------+
|guided | The iterations are distributed to threads in the team|
| | as the threads request them. The size of each chunk |
| | is proportional to the number of unassigned |
| | iterations divided by the number of threads in the |
| | team, decreasing to 1. |
+-----------------+------------------------------------------------------+
|auto | The decision regarding scheduling is delegated to the|
| | compiler and/or runtime system. The programmer gives |
| | the implementation the freedom to choose any possible|
| | mapping of iterations to threads in the team. |
+-----------------+------------------------------------------------------+
|runtime | The schedule and chunk size are taken from the |
| | runtime-scheduling-variable, which can be set through|
| | the ``omp_set_schedule`` function call, or the |
| | ``OMP_SCHEDULE`` environment variable. |
+-----------------+------------------------------------------------------+
The default schedule is implementation defined. For more information consult
the OpenMP specification: [#]_.
Example with a reduction::
from cython.parallel import prange, parallel, threadid
cdef int i
cdef int sum = 0
for i in prange(n, nogil=True):
sum += i
print sum
Example with a shared numpy array::
from cython.parallel import *
def func(np.ndarray[double] x, double alpha):
cdef Py_ssize_t i
for i in prange(x.shape[0]):
x[i] = alpha * x[i]
.. function:: parallel
This directive can be used as part of a ``with`` statement to execute code
sequences in parallel. This is currently useful to setup thread-local
buffers used by a prange. A contained prange will be a worksharing loop
that is not parallel, so any variable assigned to in the parallel section
is also private to the prange. Variables that are private in the parallel
construct are undefined after the parallel block.
Example with thread-local buffers::
from cython.parallel import *
from cython.stdlib cimport abort
cdef Py_ssize_t i, n = 100
cdef int * local_buf
cdef size_t size = 10
with nogil, parallel:
local_buf = malloc(sizeof(int) * size)
if local_buf == NULL:
abort()
# populate our local buffer in a sequential loop
for i in range(size):
local_buf[i] = i * 2
# share the work using the thread-local buffer(s)
for i in prange(n, schedule='guided'):
func(local_buf)
free(local_buf)
Later on sections might be supported in parallel blocks, to distribute
code sections of work among threads.
.. function:: threadid()
Returns the id of the thread. For n threads, the ids will range from 0 to
n.
Compiling
=========
To actually use the OpenMP support, you need to tell the C or C++ compiler to
enable OpenMP. For gcc this can be done as follows in a setup.py::
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
ext_module = Extension(
"hello",
["hello.pyx"],
extra_compile_args=['-fopenmp'],
libraries=['gomp'],
)
setup(
name = 'Hello world app',
cmdclass = {'build_ext': build_ext},
ext_modules = [ext_module],
)
.. rubric:: References
.. [#] http://www.openmp.org/mp-documents/spec30.pdf
# mode: error
cimport cython.parallel.parallel as p
from cython.parallel cimport something
import cython.parallel.parallel as p
from cython.parallel import something
from cython.parallel cimport prange
import cython.parallel
prange(1, 2, 3, schedule='dynamic')
cdef int i
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='invalid_schedule'):
pass
with cython.parallel.parallel:
print "hello world!"
cdef int *x = NULL
with nogil, cython.parallel.parallel:
for j in prange(10):
pass
for x[1] in prange(10):
pass
_ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:6:7: cython.parallel.parallel is not a module
e_cython_parallel.pyx:7:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:13:6: prange() can only be used as part of a for loop
e_cython_parallel.pyx:13:6: prange() can only be used without the GIL
e_cython_parallel.pyx:18:19: Invalid schedule argument to prange: 'invalid_schedule'
e_cython_parallel.pyx:21:5: The parallel section may only be used without the GIL
e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL
e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable
"""
...@@ -5,14 +5,10 @@ ...@@ -5,14 +5,10 @@
cimport cython.parallel cimport cython.parallel
from cython.parallel import prange, threadid from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free from libc.stdlib cimport malloc, free
from libc.stdio cimport puts
cimport numpy as np
cdef extern from "Python.h": import sys
void PyEval_InitThreads()
PyEval_InitThreads()
cdef void print_int(int x) with gil:
print x
#@cython.test_assert_path_exists( #@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']", # "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
...@@ -173,3 +169,64 @@ def test_closure_parallel_privates(): ...@@ -173,3 +169,64 @@ def test_closure_parallel_privates():
g = test_generator() g = test_generator()
print g.next(), x, g.next(), x print g.next(), x, g.next(), x
def test_pure_mode():
"""
>>> test_pure_mode()
0
1
2
3
4
4
3
2
1
0
0
"""
import Cython.Shadow
pure_parallel = sys.modules['cython.parallel']
for i in pure_parallel.prange(5):
print i
for i in pure_parallel.prange(4, -1, -1, schedule='dynamic', nogil=True):
print i
with pure_parallel.parallel:
print pure_parallel.threadid()
@cython.boundscheck(False)
def test_parallel_numpy_arrays():
"""
Disabled for now, need to handle buffer auxiliary variables.
test_parallel_numpy_arrays()
-5
-4
-3
-2
-1
0
1
2
3
4
"""
cdef Py_ssize_t i
cdef np.ndarray[np.int_t] x
try:
import numpy
except ImportError:
for i in range(-5, 5):
print i
return
x = numpy.zeros(10, dtype=np.int)
for i in prange(x.shape[0], nogil=True):
x[i] = i - 5
for i in x:
print x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment