Commit 18d8b3fb authored by Mark Florisson's avatar Mark Florisson Committed by Vitja Makarov

Preliminary OpenMP support

parent 8f47d370
...@@ -1444,6 +1444,10 @@ class CCodeWriter(object): ...@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
def put_trace_return(self, retvalue_cname): def put_trace_return(self, retvalue_cname):
self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname) self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname)
def putln_openmp(self, string):
self.putln("#ifdef _OPENMP")
self.putln(string)
self.putln("#endif")
class PyrexCodeWriter(object): class PyrexCodeWriter(object):
# f file output file # f file output file
......
...@@ -2059,6 +2059,62 @@ class RawCNameExprNode(ExprNode): ...@@ -2059,6 +2059,62 @@ class RawCNameExprNode(ExprNode):
pass pass
#-------------------------------------------------------------------
#
# Parallel nodes (cython.parallel.thread(savailable|id))
#
#-------------------------------------------------------------------
class ParallelThreadsAvailableNode(AtomicExprNode):
"""
Implements cython.parallel.threadsavailable(). If we are called from the
sequential part of the application, we need to call omp_get_max_threads(),
and in the parallel part we can just call omp_get_num_threads()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("if (omp_in_parallel()) %s = omp_get_max_threads();" %
self.temp_code)
code.putln("else %s = omp_get_num_threads();" % self.temp_code)
code.putln("#else")
code.putln("%s = 1;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
"""
Implements cython.parallel.threadid()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("%s = omp_get_thread_num();" % self.temp_code)
code.putln("#else")
code.putln("%s = 0;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
# Trailer nodes # Trailer nodes
...@@ -3465,8 +3521,11 @@ class AttributeNode(ExprNode): ...@@ -3465,8 +3521,11 @@ class AttributeNode(ExprNode):
needs_none_check = True needs_none_check = True
def as_cython_attribute(self): def as_cython_attribute(self):
if isinstance(self.obj, NameNode) and self.obj.is_cython_module: if (isinstance(self.obj, NameNode) and
self.obj.is_cython_module and not
self.attribute == u"parallel"):
return self.attribute return self.attribute
cy = self.obj.as_cython_attribute() cy = self.obj.as_cython_attribute()
if cy: if cy:
return "%s.%s" % (cy, self.attribute) return "%s.%s" % (cy, self.attribute)
......
...@@ -106,7 +106,7 @@ class Context(object): ...@@ -106,7 +106,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from ParseTreeTransforms import RemoveUnreachableCode from ParseTreeTransforms import RemoveUnreachableCode
...@@ -136,6 +136,7 @@ class Context(object): ...@@ -136,6 +136,7 @@ class Context(object):
PostParse(self), PostParse(self),
_specific_post_parse, _specific_post_parse,
InterpretCompilerDirectives(self, self.compiler_directives), InterpretCompilerDirectives(self, self.compiler_directives),
ParallelRangeTransform(self),
MarkClosureVisitor(self), MarkClosureVisitor(self),
_align_function_definitions, _align_function_definitions,
RemoveUnreachableCode(self), RemoveUnreachableCode(self),
......
This diff is collapsed.
This diff is collapsed.
...@@ -723,6 +723,9 @@ class Scope(object): ...@@ -723,6 +723,9 @@ class Scope(object):
else: else:
return outer.is_cpp() return outer.is_cpp()
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname namespace_cname = Naming.preimport_cname
...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope): ...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
utility_code = e.utility_code) utility_code = e.utility_code)
return scope return scope
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PropertyScope(Scope): class PropertyScope(Scope):
# Scope holding the __get__, __set__ and __del__ methods for # Scope holding the __get__, __set__ and __del__ methods for
......
...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine ...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler import Main
class TestNormalizeTree(TransformTest): class TestNormalizeTree(TransformTest):
...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled! ...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
""", t) """, t)
class TestInterpretCompilerDirectives(TransformTest):
"""
This class tests the parallel directives AST-rewriting and importing.
"""
# Test the parallel directives (c)importing
import_code = u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""
expected_directives_dict = {
u'cython.parallel': u'cython.parallel',
u'par': u'cython.parallel',
u'par2': u'cython.parallel',
u'parallel': u'cython.parallel',
u"tid": u"cython.parallel.threadid",
u"tavail": u"cython.parallel.threadavailable",
u"prange": u"cython.parallel.prange",
}
def setUp(self):
super(TestInterpretCompilerDirectives, self).setUp()
compilation_options = Main.CompilationOptions(Main.default_options)
ctx = compilation_options.create_context()
self.pipeline = [
InterpretCompilerDirectives(ctx, ctx.compiler_directives),
]
self.debug_exception_on_error = DebugFlags.debug_exception_on_error
def tearDown(self):
DebugFlags.debug_exception_on_error = self.debug_exception_on_error
def test_parallel_directives_cimports(self):
self.run_pipeline(self.pipeline, self.import_code)
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
def test_parallel_directives_imports(self):
self.run_pipeline(self.pipeline,
self.import_code.replace(u'cimport', u'import'))
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust. # TODO: Re-enable once they're more robust.
if sys.version_info[:2] >= (2, 5) and False: if sys.version_info[:2] >= (2, 5) and False:
from Cython.Debugger import DebugWriter from Cython.Debugger import DebugWriter
......
...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type) ...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform): class MarkAssignments(CythonTransform):
def mark_assignment(self, lhs, rhs): def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)): if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None: if lhs.entry is None:
# TODO: This shouldn't happen... # TODO: This shouldn't happen...
return return
lhs.entry.assignments.append(rhs) lhs.entry.assignments.append(rhs)
if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1]
parallel_node.assignments[lhs.entry] = (lhs.pos, inplace_op)
elif isinstance(lhs, ExprNodes.SequenceNode): elif isinstance(lhs, ExprNodes.SequenceNode):
for arg in lhs.args: for arg in lhs.args:
self.mark_assignment(arg, object_expr) self.mark_assignment(arg, object_expr)
...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform): ...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
return node return node
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.create_binop_node()) self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform): ...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ParallelStatNode(self, node):
if self.parallel_block_stack:
node.parent = self.parallel_block_stack[-1]
else:
node.parent = None
if node.is_prange:
if not node.parent:
node.is_parallel = True
else:
node.is_parallel = (node.parent.is_prange or not
node.parent.is_parallel)
else:
node.is_parallel = True
self.parallel_block_stack.append(node)
self.visitchildren(node)
self.parallel_block_stack.pop()
return node
class MarkOverflowingArithmetic(CythonTransform): class MarkOverflowingArithmetic(CythonTransform):
# It may be possible to integrate this with the above for # It may be possible to integrate this with the above for
......
# tag: run
# distutils: libraries = gomp
# distutils: extra_compile_args = -fopenmp
cimport cython.parallel
from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free
cdef extern from "Python.h":
void PyEval_InitThreads()
PyEval_InitThreads()
cdef void print_int(int x) with gil:
print x
#@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
# "//GILStatNode[@state = 'nogil]//ParallelRangeNode")
def test_prange():
"""
>>> test_prange()
(9, 9, 45, 45)
"""
cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='dynamic'):
sum1 += i
for j in prange(10, nogil=True):
sum2 += j
return i, j, sum1, sum2
def test_descending_prange():
"""
>>> test_descending_prange()
5
"""
cdef int i, start = 5, stop = -5, step = -2
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
sum += i
return sum
def test_nested_prange():
"""
Reduction propagation is not (yet) supported.
>>> test_nested_prange()
50
"""
cdef int i, j
cdef int sum = 0
for i in prange(5, nogil=True):
for j in prange(5):
sum += i
# The value of sum is undefined here
sum = 0
for i in prange(5, nogil=True):
for j in prange(5):
sum += i
sum += 0
return sum
# threadsavailable test, disable this for now as it won't compile
#def test_parallel():
# """
# >>> test_parallel()
# """
# cdef int *buf = <int *> malloc(sizeof(int) * threadsavailable())
#
# if buf == NULL:
# raise MemoryError
#
# with nogil, cython.parallel.parallel:
# buf[threadid()] = threadid()
#
# for i in range(threadsavailable()):
# assert buf[i] == i
#
# free(buf)
def test_unsigned_operands():
"""
This test is disabled, as this currently does not work (neither does it
for 'for i from x < i < y:'. I'm not sure we should strife to support
this, at least the C compiler gives a warning.
test_unsigned_operands()
10
"""
cdef int i
cdef int start = -5
cdef unsigned int stop = 5
cdef int step = 1
cdef int steps_taken = 0
for i in prange(start, stop, step, nogil=True):
steps_taken += 1
return steps_taken
def test_reassign_start_stop_step():
"""
>>> test_reassign_start_stop_step()
20
"""
cdef int start = 0, stop = 10, step = 2
cdef int i
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
start = -2
stop = 2
step = 0
sum += i
return sum
def test_closure_parallel_privates():
"""
>>> test_closure_parallel_privates()
9 9
45 45
0 0 9 9
"""
cdef int x
def test_target():
nonlocal x
for x in prange(10, nogil=True):
pass
return x
print test_target(), x
def test_reduction():
nonlocal x
cdef int i
x = 0
for i in prange(10, nogil=True):
x += i
return x
print test_reduction(), x
def test_generator():
nonlocal x
cdef int i
x = 0
yield x
x = 2
for i in prange(10, nogil=True):
x = i
yield x
g = test_generator()
print g.next(), x, g.next(), x
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment