Commit 3537d938 authored by Mark Florisson's avatar Mark Florisson
parents efb0837d 3487c4e0
...@@ -11,6 +11,7 @@ Cython/Runtime/refnanny.c ...@@ -11,6 +11,7 @@ Cython/Runtime/refnanny.c
BUILD/ BUILD/
build/ build/
dist/ dist/
.git/
.gitrev .gitrev
.coverage .coverage
*.orig *.orig
......
...@@ -1444,6 +1444,10 @@ class CCodeWriter(object): ...@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
def put_trace_return(self, retvalue_cname): def put_trace_return(self, retvalue_cname):
self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname) self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname)
def putln_openmp(self, string):
    """
    Write ``string`` as a line of C code that is only compiled when the C
    compiler has OpenMP enabled, by wrapping it in an ``#ifdef _OPENMP``
    / ``#endif`` pair.
    """
    guarded_lines = ("#ifdef _OPENMP", string, "#endif /* _OPENMP */")
    for line in guarded_lines:
        self.putln(line)
class PyrexCodeWriter(object): class PyrexCodeWriter(object):
# f file output file # f file output file
......
...@@ -2059,6 +2059,64 @@ class RawCNameExprNode(ExprNode): ...@@ -2059,6 +2059,64 @@ class RawCNameExprNode(ExprNode):
pass pass
#-------------------------------------------------------------------
#
# Parallel nodes (cython.parallel.thread(savailable|id))
#
#-------------------------------------------------------------------
class ParallelThreadsAvailableNode(AtomicExprNode):
    """
    Note: this is disabled and not a valid directive at this moment

    Implements cython.parallel.threadsavailable(). If we are called from the
    sequential part of the application, we need to call omp_get_max_threads(),
    and in the parallel part we can just call omp_get_num_threads()

    When the C code is compiled without OpenMP the result is always 1.
    """

    type = PyrexTypes.c_int_type

    def analyse_types(self, env):
        # The value is written into a temporary by generate_result_code().
        self.is_temp = True
        # env.add_include_file("omp.h")
        return self.type

    def generate_result_code(self, code):
        code.putln("#ifdef _OPENMP")
        # Inside a parallel region the team already exists, so report its
        # size.  In sequential code omp_get_num_threads() is always 1, so
        # report the size a newly created team could have instead.
        code.putln("if (omp_in_parallel()) %s = omp_get_num_threads();" %
                                                            self.temp_code)
        code.putln("else %s = omp_get_max_threads();" % self.temp_code)
        code.putln("#else")
        code.putln("%s = 1;" % self.temp_code)
        code.putln("#endif")

    def result(self):
        return self.temp_code
class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
    """
    Implements cython.parallel.threadid()

    The generated C assigns omp_get_thread_num() to the node's temporary
    when compiled with OpenMP, and 0 otherwise.
    """

    type = PyrexTypes.c_int_type

    def analyse_types(self, env):
        # The value is written into a temporary by generate_result_code().
        self.is_temp = True
        # env.add_include_file("omp.h")
        return self.type

    def generate_result_code(self, code):
        target = self.temp_code
        c_lines = (
            "#ifdef _OPENMP",
            "%s = omp_get_thread_num();" % target,
            "#else",
            "%s = 0;" % target,
            "#endif",
        )
        for c_line in c_lines:
            code.putln(c_line)

    def result(self):
        return self.temp_code
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
# Trailer nodes # Trailer nodes
...@@ -3465,8 +3523,11 @@ class AttributeNode(ExprNode): ...@@ -3465,8 +3523,11 @@ class AttributeNode(ExprNode):
needs_none_check = True needs_none_check = True
def as_cython_attribute(self): def as_cython_attribute(self):
if isinstance(self.obj, NameNode) and self.obj.is_cython_module: if (isinstance(self.obj, NameNode) and
self.obj.is_cython_module and not
self.attribute == u"parallel"):
return self.attribute return self.attribute
cy = self.obj.as_cython_attribute() cy = self.obj.as_cython_attribute()
if cy: if cy:
return "%s.%s" % (cy, self.attribute) return "%s.%s" % (cy, self.attribute)
......
...@@ -106,7 +106,7 @@ class Context(object): ...@@ -106,7 +106,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
...@@ -135,6 +135,7 @@ class Context(object): ...@@ -135,6 +135,7 @@ class Context(object):
PostParse(self), PostParse(self),
_specific_post_parse, _specific_post_parse,
InterpretCompilerDirectives(self, self.compiler_directives), InterpretCompilerDirectives(self, self.compiler_directives),
ParallelRangeTransform(self),
MarkClosureVisitor(self), MarkClosureVisitor(self),
_align_function_definitions, _align_function_definitions,
ConstantFolding(), ConstantFolding(),
......
...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.putln('#include "%s"' % byte_decoded_filenname) code.putln('#include "%s"' % byte_decoded_filenname)
code.putln_openmp("#include <omp.h>")
def generate_filename_table(self, code): def generate_filename_table(self, code):
code.putln("") code.putln("")
code.putln("static const char *%s[] = {" % Naming.filetable_cname) code.putln("static const char *%s[] = {" % Naming.filetable_cname)
......
This diff is collapsed.
This diff is collapsed.
...@@ -723,6 +723,9 @@ class Scope(object): ...@@ -723,6 +723,9 @@ class Scope(object):
else: else:
return outer.is_cpp() return outer.is_cpp()
def add_include_file(self, filename):
    # This scope does not record include files itself; the request is
    # handed on to the enclosing scope.
    enclosing_scope = self.outer_scope
    enclosing_scope.add_include_file(filename)
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname namespace_cname = Naming.preimport_cname
...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope): ...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
utility_code = e.utility_code) utility_code = e.utility_code)
return scope return scope
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PropertyScope(Scope): class PropertyScope(Scope):
# Scope holding the __get__, __set__ and __del__ methods for # Scope holding the __get__, __set__ and __del__ methods for
......
...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine ...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler import Main
class TestNormalizeTree(TransformTest): class TestNormalizeTree(TransformTest):
...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled! ...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
""", t) """, t)
class TestInterpretCompilerDirectives(TransformTest):
    """
    Checks that every way of (c)importing the cython.parallel directives is
    recorded in the ``parallel_directives`` mapping of the
    InterpretCompilerDirectives transform.
    """

    # All supported spellings of the parallel directive (c)imports
    import_code = u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""

    # Local name -> fully qualified directive expected after the transform
    expected_directives_dict = {
        u'cython.parallel': u'cython.parallel',
        u'par': u'cython.parallel',
        u'par2': u'cython.parallel',
        u'parallel': u'cython.parallel',
        u"tid": u"cython.parallel.threadid",
        u"tavail": u"cython.parallel.threadavailable",
        u"prange": u"cython.parallel.prange",
    }

    def setUp(self):
        super(TestInterpretCompilerDirectives, self).setUp()

        options = Main.CompilationOptions(Main.default_options)
        context = options.create_context()
        transform = InterpretCompilerDirectives(context,
                                                context.compiler_directives)
        self.pipeline = [transform]

        # Remembered so that tearDown() can put the flag back
        self.debug_exception_on_error = DebugFlags.debug_exception_on_error

    def tearDown(self):
        DebugFlags.debug_exception_on_error = self.debug_exception_on_error

    def _parallel_directives_for(self, source):
        # Run the one-transform pipeline and return what it recorded
        self.run_pipeline(self.pipeline, source)
        return self.pipeline[0].parallel_directives

    def test_parallel_directives_cimports(self):
        found = self._parallel_directives_for(self.import_code)
        self.assertEqual(found, self.expected_directives_dict)

    def test_parallel_directives_imports(self):
        found = self._parallel_directives_for(
            self.import_code.replace(u'cimport', u'import'))
        self.assertEqual(found, self.expected_directives_dict)
# TODO: Re-enable once they're more robust. # TODO: Re-enable once they're more robust.
if sys.version_info[:2] >= (2, 5) and False: if sys.version_info[:2] >= (2, 5) and False:
from Cython.Debugger import DebugWriter from Cython.Debugger import DebugWriter
......
...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type) ...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform): class MarkAssignments(CythonTransform):
def mark_assignment(self, lhs, rhs): def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)): if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None: if lhs.entry is None:
# TODO: This shouldn't happen... # TODO: This shouldn't happen...
return return
lhs.entry.assignments.append(rhs) lhs.entry.assignments.append(rhs)
if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1]
parallel_node.assignments[lhs.entry] = (lhs.pos, inplace_op)
elif isinstance(lhs, ExprNodes.SequenceNode): elif isinstance(lhs, ExprNodes.SequenceNode):
for arg in lhs.args: for arg in lhs.args:
self.mark_assignment(arg, object_expr) self.mark_assignment(arg, object_expr)
...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform): ...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
return node return node
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.create_binop_node()) self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform): ...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ParallelStatNode(self, node):
    """
    Link a parallel block to the parallel block enclosing it (if any),
    decide whether it is itself parallel, and visit its children with the
    node pushed on the parallel block stack.
    """
    block_stack = self.parallel_block_stack

    enclosing = None
    if block_stack:
        enclosing = block_stack[-1]
    node.parent = enclosing

    # Only a prange with an enclosing parallel block can end up
    # non-parallel; everything else is parallel.
    is_parallel = True
    if node.is_prange and node.parent:
        outer = node.parent
        is_parallel = (outer.is_prange or not outer.is_parallel)
    node.is_parallel = is_parallel

    block_stack.append(node)
    self.visitchildren(node)
    block_stack.pop()
    return node
class MarkOverflowingArithmetic(CythonTransform): class MarkOverflowingArithmetic(CythonTransform):
# It may be possible to integrate this with the above for # It may be possible to integrate this with the above for
......
cdef extern from "omp.h":
ctypedef struct omp_lock_t
ctypedef struct omp_nest_lock_t
ctypedef enum omp_sched_t:
omp_sched_static = 1,
omp_sched_dynamic = 2,
omp_sched_guided = 3,
omp_sched_auto = 4
extern void omp_set_num_threads(int)
extern int omp_get_num_threads()
extern int omp_get_max_threads()
extern int omp_get_thread_num()
extern int omp_get_num_procs()
extern int omp_in_parallel()
extern void omp_set_dynamic(int)
extern int omp_get_dynamic()
extern void omp_set_nested(int)
extern int omp_get_nested()
extern void omp_init_lock(omp_lock_t *)
extern void omp_destroy_lock(omp_lock_t *)
extern void omp_set_lock(omp_lock_t *)
extern void omp_unset_lock(omp_lock_t *)
extern int omp_test_lock(omp_lock_t *)
extern void omp_init_nest_lock(omp_nest_lock_t *)
extern void omp_destroy_nest_lock(omp_nest_lock_t *)
extern void omp_set_nest_lock(omp_nest_lock_t *)
extern void omp_unset_nest_lock(omp_nest_lock_t *)
extern int omp_test_nest_lock(omp_nest_lock_t *)
extern double omp_get_wtime()
extern double omp_get_wtick()
void omp_set_schedule(omp_sched_t, int)
void omp_get_schedule(omp_sched_t *, int *)
int omp_get_thread_limit()
void omp_set_max_active_levels(int)
int omp_get_max_active_levels()
int omp_get_level()
int omp_get_ancestor_thread_num(int)
int omp_get_team_size(int)
int omp_get_active_level()
...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types: ...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types:
void = typedef(None) void = typedef(None)
NULL = p_void(0) NULL = p_void(0)
class CythonDotParallel(object):
    """
    The cython.parallel module.

    Pure Python stand-ins for the parallel directives: everything runs
    sequentially in the calling thread.
    """

    __all__ = ['parallel', 'prange', 'threadid']

    # 'with parallel:' is the same object as 'with nogil:' here
    parallel = nogil

    def prange(self, start=0, stop=None, step=1, schedule=None, nogil=False):
        # A single positional argument is the stop value, as with range().
        # 'schedule' and 'nogil' are accepted and ignored.
        if stop is None:
            start, stop = 0, start
        return range(start, stop, step)

    def threadid(self):
        # There is only ever one thread in the sequential fallback
        return 0

    # def threadsavailable(self):
    #     return 1
import sys
sys.modules['cython.parallel'] = CythonDotParallel()
del sys
\ No newline at end of file
...@@ -18,6 +18,7 @@ Contents: ...@@ -18,6 +18,7 @@ Contents:
limitations limitations
pyrex_differences pyrex_differences
early_binding_for_speed early_binding_for_speed
parallelism
debugging debugging
Indices and tables Indices and tables
......
.. highlight:: cython
.. py:module:: cython.parallel
**********************************
Using Parallelism
**********************************
Cython supports native parallelism through the :py:mod:`cython.parallel`
module. To use this kind of parallelism, the GIL must be released. It
currently supports OpenMP, but later on more backends might be supported.
.. function:: prange([start,] stop[, step], nogil=False, schedule=None)
This function can be used for parallel loops. OpenMP automatically
starts a thread pool and distributes the work according to the schedule
used. ``step`` must not be 0. This function can only be used with the
GIL released. If ``nogil`` is true, the loop will be wrapped in a nogil
section.
Thread-locality and reductions are automatically inferred for variables.
If you assign to a variable, it becomes lastprivate, meaning that the
variable will contain the value from the last iteration. If you use an
inplace operator on a variable, it becomes a reduction, meaning that the
values from the thread-local copies of the variable will be reduced with
the operator and assigned to the original variable after the loop. The
index variable is always lastprivate.
The ``schedule`` is passed to OpenMP and can be one of the following:
+-----------------+------------------------------------------------------+
| Schedule | Description |
+=================+======================================================+
|static | The iteration space is divided into chunks that are |
| | approximately equal in size, and at most one chunk |
| | is distributed to each thread. |
+-----------------+------------------------------------------------------+
|dynamic | The iterations are distributed to threads in the team|
| | as the threads request them, with a chunk size of 1. |
+-----------------+------------------------------------------------------+
|guided | The iterations are distributed to threads in the team|
| | as the threads request them. The size of each chunk |
| | is proportional to the number of unassigned |
| | iterations divided by the number of threads in the |
| | team, decreasing to 1. |
+-----------------+------------------------------------------------------+
|auto | The decision regarding scheduling is delegated to the|
| | compiler and/or runtime system. The programmer gives |
| | the implementation the freedom to choose any possible|
| | mapping of iterations to threads in the team. |
+-----------------+------------------------------------------------------+
|runtime | The schedule and chunk size are taken from the |
| | runtime-scheduling-variable, which can be set through|
| | the ``omp_set_schedule`` function call, or the |
| | ``OMP_SCHEDULE`` environment variable. |
+-----------------+------------------------------------------------------+
The default schedule is implementation defined. For more information consult
the OpenMP specification: [#]_.
Example with a reduction::
from cython.parallel import prange, parallel, threadid
cdef int i
cdef int sum = 0
for i in prange(n, nogil=True):
sum += i
print sum
Example with a shared numpy array::
from cython.parallel import *
def func(np.ndarray[double] x, double alpha):
cdef Py_ssize_t i
for i in prange(x.shape[0], nogil=True):
x[i] = alpha * x[i]
.. function:: parallel
This directive can be used as part of a ``with`` statement to execute code
sequences in parallel. This is currently useful to set up thread-local
buffers used by a prange. A contained prange will be a worksharing loop
that is not parallel, so any variable assigned to in the parallel section
is also private to the prange. Variables that are private in the parallel
construct are undefined after the parallel block.
Example with thread-local buffers::
from cython.parallel import *
from libc.stdlib cimport abort, malloc, free
cdef Py_ssize_t i, n = 100
cdef int * local_buf
cdef size_t size = 10
with nogil, parallel:
local_buf = <int *> malloc(sizeof(int) * size)
if local_buf == NULL:
abort()
# populate our local buffer in a sequential loop
for i in range(size):
local_buf[i] = i * 2
# share the work using the thread-local buffer(s)
for i in prange(n, schedule='guided'):
func(local_buf)
free(local_buf)
Later on sections might be supported in parallel blocks, to distribute
code sections of work among threads.
.. function:: threadid()
Returns the id of the thread. For n threads, the ids will range from 0 to
n-1.
Compiling
=========
To actually use the OpenMP support, you need to tell the C or C++ compiler to
enable OpenMP. For gcc this can be done as follows in a setup.py::
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
ext_module = Extension(
"hello",
["hello.pyx"],
extra_compile_args=['-fopenmp'],
libraries=['gomp'],
)
setup(
name = 'Hello world app',
cmdclass = {'build_ext': build_ext},
ext_modules = [ext_module],
)
.. rubric:: References
.. [#] http://www.openmp.org/mp-documents/spec30.pdf
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import sys import sys
import re import re
import gc import gc
import locale
import codecs import codecs
import shutil import shutil
import time import time
...@@ -11,6 +12,7 @@ import unittest ...@@ -11,6 +12,7 @@ import unittest
import doctest import doctest
import operator import operator
import tempfile import tempfile
import warnings
import traceback import traceback
try: try:
from StringIO import StringIO from StringIO import StringIO
...@@ -54,6 +56,7 @@ CY3_DIR = None ...@@ -54,6 +56,7 @@ CY3_DIR = None
from distutils.dist import Distribution from distutils.dist import Distribution
from distutils.core import Extension from distutils.core import Extension
from distutils.command.build_ext import build_ext as _build_ext from distutils.command.build_ext import build_ext as _build_ext
from distutils import sysconfig
distutils_distro = Distribution() distutils_distro = Distribution()
if sys.platform == 'win32': if sys.platform == 'win32':
...@@ -78,8 +81,83 @@ def update_numpy_extension(ext): ...@@ -78,8 +81,83 @@ def update_numpy_extension(ext):
import numpy import numpy
ext.include_dirs.append(numpy.get_include()) ext.include_dirs.append(numpy.get_include())
def update_openmp_extension(ext):
    """
    Add the OpenMP compiler and linker flags to the extension 'ext'.

    Returns the updated extension, or EXCLUDE_EXT when no OpenMP flags are
    known for the compiler of the extension's language, so that the test is
    skipped instead of failing to build.
    """
    # The test runner stores C++ on the extension as 'c++', while the flag
    # lookup in this file calls it 'cpp'; accept both spellings so that C++
    # extensions get the flags detected for the C++ compiler.
    if ext.language in ('cpp', 'c++'):
        flags = OPENMP_CPP_COMPILER_FLAGS
    else:
        flags = OPENMP_C_COMPILER_FLAGS

    if not flags:
        return EXCLUDE_EXT

    compile_flags, link_flags = flags
    ext.extra_compile_args.extend(compile_flags.split())
    ext.extra_link_args.extend(link_flags.split())
    return ext
def _warn_compiler_not_found(language, cc):
    # Report the exception currently being handled as a warning.
    # sys.exc_info() is used to be compatible with both Python 2 and 3.
    _, e, _ = sys.exc_info()
    warnings.warn("Unable to find the %s compiler: %s: %s" %
                  (language, os.strerror(e.errno), cc))


def get_openmp_compiler_flags(language):
    """
    As of gcc 4.2, it supports OpenMP 2.5. Gcc 4.4 implements 3.0. We don't
    (currently) check for other compilers.

    language is 'cpp' for the C++ compiler, anything else selects the C
    compiler.

    returns a two-tuple of (CFLAGS, LDFLAGS) to build the OpenMP extension,
    or None if no compiler is configured, it cannot be run, it is not gcc,
    or it is older than gcc 4.2
    """
    if language == 'cpp':
        cc = sysconfig.get_config_var('CXX')
    else:
        cc = sysconfig.get_config_var('CC')

    # The config variable may be missing or empty (e.g. on Windows)
    if not cc or not cc.split():
        return None

    # For some reason, cc can be e.g. 'gcc -pthread'
    cc = cc.split()[0]

    try:
        import subprocess
    except ImportError:
        # Very old Pythons without the subprocess module
        try:
            in_, out, err = os.popen3(cc + " -v")
        except EnvironmentError:
            _warn_compiler_not_found(language, cc)
            return None
        output = out.read() or err.read()
        for pipe in (in_, out, err):
            pipe.close()
    else:
        # Ask for untranslated messages so that the version banner matches
        # the pattern below regardless of the user's locale
        env = dict(os.environ)
        env['LC_ALL'] = 'C'
        try:
            p = subprocess.Popen([cc, "-v"], stdout=subprocess.PIPE,
                                 stderr=subprocess.STDOUT, env=env)
        except EnvironmentError:
            _warn_compiler_not_found(language, cc)
            return None
        # communicate() also waits for the child, so no zombie is left
        output = p.communicate()[0]

    output = output.decode(locale.getpreferredencoding() or 'UTF-8',
                           'replace')

    match = re.search(r"gcc version (\d+)\.(\d+)", output)
    if match is None:
        # Not gcc (or an unrecognised banner): no known OpenMP flags
        return None

    # Compare numerically: as strings, '10' would sort before '4'
    version = (int(match.group(1)), int(match.group(2)))
    if version >= (4, 2):
        return '-fopenmp', '-fopenmp'
    return None
locale.setlocale(locale.LC_ALL, '')
OPENMP_C_COMPILER_FLAGS = get_openmp_compiler_flags('c')
OPENMP_CPP_COMPILER_FLAGS = get_openmp_compiler_flags('cpp')
# Return this from the EXT_EXTRAS matcher callback to exclude the extension
EXCLUDE_EXT = object()
EXT_EXTRAS = { EXT_EXTRAS = {
'tag:numpy' : update_numpy_extension, 'tag:numpy' : update_numpy_extension,
'tag:openmp': update_openmp_extension,
} }
# TODO: use tags # TODO: use tags
...@@ -519,13 +597,21 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -519,13 +597,21 @@ class CythonCompileTestCase(unittest.TestCase):
extra_compile_args = ext_compile_flags, extra_compile_args = ext_compile_flags,
**extra_extension_args **extra_extension_args
) )
if self.language == 'cpp':
# Set the language now as the fixer might need it
extension.language = 'c++'
for matcher, fixer in EXT_EXTRAS.items(): for matcher, fixer in EXT_EXTRAS.items():
if isinstance(matcher, str): if isinstance(matcher, str):
del EXT_EXTRAS[matcher] del EXT_EXTRAS[matcher]
matcher = string_selector(matcher) matcher = string_selector(matcher)
EXT_EXTRAS[matcher] = fixer EXT_EXTRAS[matcher] = fixer
if matcher(module, tags): if matcher(module, tags):
extension = fixer(extension) or extension newext = fixer(extension)
if newext is EXCLUDE_EXT:
return
extension = newext or extension
if self.language == 'cpp': if self.language == 'cpp':
extension.language = 'c++' extension.language = 'c++'
build_extension.extensions = [extension] build_extension.extensions = [extension]
...@@ -646,6 +732,7 @@ def run_forked_test(result, run_func, test_name, fork=True): ...@@ -646,6 +732,7 @@ def run_forked_test(result, run_func, test_name, fork=True):
try: try:
cid, result_code = os.waitpid(child_id, 0) cid, result_code = os.waitpid(child_id, 0)
module_name = test_name.split()[-1]
# os.waitpid returns the child's result code in the # os.waitpid returns the child's result code in the
# upper byte of result_code, and the signal it was # upper byte of result_code, and the signal it was
# killed by in the lower byte # killed by in the lower byte
......
# mode: error
cimport cython.parallel.parallel as p
from cython.parallel cimport something
import cython.parallel.parallel as p
from cython.parallel import something
from cython.parallel cimport prange
import cython.parallel
prange(1, 2, 3, schedule='dynamic')
cdef int i
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='invalid_schedule'):
pass
with cython.parallel.parallel:
print "hello world!"
cdef int *x = NULL
with nogil, cython.parallel.parallel:
for j in prange(10):
pass
for x[1] in prange(10):
pass
for x in prange(10):
pass
with cython.parallel.parallel:
pass
_ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:6:7: cython.parallel.parallel is not a module
e_cython_parallel.pyx:7:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:13:6: prange() can only be used as part of a for loop
e_cython_parallel.pyx:13:6: prange() can only be used without the GIL
e_cython_parallel.pyx:18:19: Invalid schedule argument to prange: invalid_schedule
e_cython_parallel.pyx:21:5: The parallel section may only be used without the GIL
e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL
e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable
e_cython_parallel.pyx:33:10: Must be of numeric type, not int *
e_cython_parallel.pyx:36:24: Closely nested 'with parallel:' blocks are disallowed
"""
# tag: numpy
# tag: openmp
cimport cython
from cython.parallel import prange
cimport numpy as np
@cython.boundscheck(False)
def test_parallel_numpy_arrays():
"""
>>> test_parallel_numpy_arrays()
-5
-4
-3
-2
-1
0
1
2
3
4
"""
cdef Py_ssize_t i
cdef np.ndarray[np.int_t] x
try:
import numpy
except ImportError:
for i in range(-5, 5):
print i
return
x = numpy.zeros(10, dtype=numpy.int)
for i in prange(x.shape[0], nogil=True):
x[i] = i - 5
for i in x:
print i
# tag: run
# tag: openmp
cimport cython.parallel
from cython.parallel import prange, threadid
cimport openmp
from libc.stdlib cimport malloc, free
def test_parallel():
"""
>>> test_parallel()
"""
cdef int maxthreads = openmp.omp_get_max_threads()
cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)
if buf == NULL:
raise MemoryError
with nogil, cython.parallel.parallel:
buf[threadid()] = threadid()
for i in range(maxthreads):
assert buf[i] == i
free(buf)
include "sequential_parallel.pyx"
# tag: run
cimport cython.parallel
from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free, abort
from libc.stdio cimport puts
import sys
try:
from builtins import next # Py3k
except ImportError:
def next(it):
return it.next()
#@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
# "//GILStatNode[@state = 'nogil]//ParallelRangeNode")
def test_prange():
"""
>>> test_prange()
(9, 9, 45, 45)
"""
cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='dynamic'):
sum1 += i
for j in prange(10, nogil=True):
sum2 += j
return i, j, sum1, sum2
def test_descending_prange():
"""
>>> test_descending_prange()
5
"""
cdef int i, start = 5, stop = -5, step = -2
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
sum += i
return sum
def test_propagation():
"""
>>> test_propagation()
(9, 9, 9, 9, 450, 450)
"""
cdef int i, j, x, y
cdef int sum1 = 0, sum2 = 0
for i in prange(10, nogil=True):
for j in prange(10):
sum1 += i
with nogil, cython.parallel.parallel:
for x in prange(10):
with cython.parallel.parallel:
for y in prange(10):
sum2 += y
return i, j, x, y, sum1, sum2
def test_unsigned_operands():
"""
>>> test_unsigned_operands()
10
"""
cdef int i
cdef int start = -5
cdef unsigned int stop = 5
cdef int step = 1
cdef int steps_taken = 0
for i in prange(start, stop, step, nogil=True):
steps_taken += 1
if steps_taken > 10:
abort()
return steps_taken
def test_reassign_start_stop_step():
"""
>>> test_reassign_start_stop_step()
20
"""
cdef int start = 0, stop = 10, step = 2
cdef int i
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
start = -2
stop = 2
step = 0
sum += i
return sum
def test_closure_parallel_privates():
"""
>>> test_closure_parallel_privates()
9 9
45 45
0 0 9 9
"""
cdef int x
def test_target():
nonlocal x
for x in prange(10, nogil=True):
pass
return x
print test_target(), x
def test_reduction():
nonlocal x
cdef int i
x = 0
for i in prange(10, nogil=True):
x += i
return x
print test_reduction(), x
def test_generator():
nonlocal x
cdef int i
x = 0
yield x
x = 2
for i in prange(10, nogil=True):
x = i
yield x
g = test_generator()
print next(g), x, next(g), x
def test_pure_mode():
"""
>>> test_pure_mode()
0
1
2
3
4
4
3
2
1
0
0
"""
import Cython.Shadow
pure_parallel = sys.modules['cython.parallel']
for i in pure_parallel.prange(5):
print i
for i in pure_parallel.prange(4, -1, -1, schedule='dynamic', nogil=True):
print i
with pure_parallel.parallel:
print pure_parallel.threadid()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment