Commit 18d8b3fb authored by Mark Florisson's avatar Mark Florisson Committed by Vitja Makarov

Preliminary OpenMP support

parent 8f47d370
...@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
def put_trace_return(self, retvalue_cname):
self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname)
def putln_openmp(self, string):
self.putln("#ifdef _OPENMP")
self.putln(string)
self.putln("#endif")
class PyrexCodeWriter(object):
# f file output file
...
...@@ -2059,6 +2059,62 @@ class RawCNameExprNode(ExprNode):
pass
#-------------------------------------------------------------------
#
# Parallel nodes (cython.parallel.thread(savailable|id))
#
#-------------------------------------------------------------------
class ParallelThreadsAvailableNode(AtomicExprNode):
"""
Implements cython.parallel.threadsavailable(). If we are called from the
sequential part of the application, we need to call omp_get_max_threads(),
and in the parallel part we can just call omp_get_num_threads()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("if (omp_in_parallel()) %s = omp_get_max_threads();" %
self.temp_code)
code.putln("else %s = omp_get_num_threads();" % self.temp_code)
code.putln("#else")
code.putln("%s = 1;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
"""
Implements cython.parallel.threadid()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("%s = omp_get_thread_num();" % self.temp_code)
code.putln("#else")
code.putln("%s = 0;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
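A minimal Cython usage sketch for the thread-id node (a hypothetical function; threadsavailable() is left out because it is still disabled further down in this commit):

    from cython.parallel import prange, threadid

    def last_thread_id(int n):
        cdef int i, tid = -1
        for i in prange(n, nogil=True):
            tid = threadid()    # 0..nthreads-1 under OpenMP, always 0 otherwise
        return tid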
#-------------------------------------------------------------------
#
# Trailer nodes
...@@ -3465,8 +3521,11 @@ class AttributeNode(ExprNode):
needs_none_check = True
def as_cython_attribute(self):
if (isinstance(self.obj, NameNode) and
self.obj.is_cython_module and not
self.attribute == u"parallel"):
return self.attribute
cy = self.obj.as_cython_attribute()
if cy:
return "%s.%s" % (cy, self.attribute)
...
...@@ -106,7 +106,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from ParseTreeTransforms import RemoveUnreachableCode
...@@ -136,6 +136,7 @@ class Context(object):
PostParse(self),
_specific_post_parse,
InterpretCompilerDirectives(self, self.compiler_directives),
ParallelRangeTransform(self),
MarkClosureVisitor(self),
_align_function_definitions,
RemoveUnreachableCode(self),
...
...@@ -2,7 +2,6 @@
#
# Pyrex - Parse tree nodes
#
import cython
from cython import set
cython.declare(sys=object, os=object, time=object, copy=object,
...@@ -4597,14 +4596,12 @@ class ForInStatNode(LoopNode, StatNode):
old_loop_labels = code.new_loop_labels()
self.iterator.allocate_counter_temp(code)
self.iterator.generate_evaluation_code(code)
code.putln("for (;;) {")
self.item.generate_evaluation_code(code)
self.target.generate_assignment_code(self.item, code)
self.body.generate_execution_code(code)
code.put_label(code.continue_label)
code.putln("}")
break_label = code.break_label
code.set_loop_labels(old_loop_labels)
...@@ -5735,6 +5732,371 @@ class FromImportStatNode(StatNode):
self.module.free_temps(code)
class ParallelNode(Node):
"""
Base class for cython.parallel constructs.
"""
nogil_check = None
class ParallelStatNode(StatNode, ParallelNode):
"""
Base class for 'with cython.parallel.parallel:' and 'for i in prange():'.
assignments { Entry(var) : (var.pos, inplace_operator_or_None) }
assignments to variables in this parallel section
parent parent ParallelStatNode or None
is_parallel indicates whether this is a parallel node
is_parallel is true for:
#pragma omp parallel
#pragma omp parallel for
sections, but NOT for
#pragma omp for
We need this to determine the sharing attributes.
"""
child_attrs = ['body']
body = None
is_prange = False
def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs)
self.assignments = kwargs.get('assignments') or {}
# Insertion point before the outermost parallel section
self.before_parallel_section_point = None
# Insertion point after the outermost parallel section
self.post_parallel_section_point = None
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
def analyse_declarations(self, env):
super(ParallelStatNode, self).analyse_declarations(env)
self.body.analyse_declarations(env)
def lookup_assignment(self, entry):
"""
Return an assignment's pos and operator. If the parent has the
assignment, return the parent's assignment, otherwise our own.
"""
parent_assignment = self.parent and self.parent.lookup_assignment(entry)
return parent_assignment or self.assignments.get(entry)
def is_private(self, entry):
"""
True if this scope should declare the variable private, lastprivate
or reduction.
"""
parent_or_our_entry = self.lookup_assignment(entry)
our_entry = self.assignments.get(entry)
return self.is_parallel or parent_or_our_entry == our_entry
def _allocate_closure_temp(self, code, entry):
"""
Helper function that allocates a temporary for a closure variable that
is assigned to.
"""
if self.parent:
return self.parent._allocate_closure_temp(code, entry)
cname = code.funcstate.allocate_temp(entry.type, False)
self.modified_entries.append((entry, entry.cname))
code.putln("%s = %s;" % (cname, entry.cname))
entry.cname = cname
return cname
def declare_closure_privates(self, code):
"""
Set self.privates to a dict mapping C variable names that are to be
declared (first|last)private or reduction, to the reduction operator.
If the private is not a reduction, the operator is None.
This is used by subclasses.
If a variable is in a scope object, we need to allocate a temp and
assign the value from the temp to the variable in the scope object
after the parallel section. This kind of copying should be done only
in the outermost parallel section.
"""
self.privates = {}
self.modified_entries = []
for entry, (pos, op) in self.assignments.iteritems():
cname = entry.cname
if entry.from_closure or entry.in_closure:
cname = self._allocate_closure_temp(code, entry)
if self.is_private(entry):
self.privates[cname] = op
def release_closure_privates(self, code):
"Release any temps used for variables in scope objects"
for entry, original_cname in self.modified_entries:
code.putln("%s = %s;" % (original_cname, entry.cname))
code.funcstate.release_temp(entry.cname)
entry.cname = original_cname
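A Cython-level sketch of the case this machinery handles, modeled on the test_closure_parallel_privates test added later in this commit: the loop target lives in a closure scope object, so it is copied into a plain C temp around the OpenMP section and written back afterwards.

    from cython.parallel import prange

    def outer():
        cdef int x                            # stored in the closure scope object
        def inner():
            nonlocal x
            for x in prange(10, nogil=True):  # x is unpacked into a C temp here
                pass
            return x                          # the temp has been written back by now
        return inner()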
class ParallelWithBlockNode(ParallelStatNode):
"""
This node represents a 'with cython.parallel:' block
"""
nogil_check = None
def generate_execution_code(self, code):
self.declare_closure_privates(code)
code.putln("#ifdef _OPENMP")
code.put("#pragma omp parallel ")
code.putln(' '.join(["private(%s)" % e.cname
for e in self.assignments
if self.is_private(e)]))
code.putln("#endif")
code.begin_block()
self.body.generate_execution_code(code)
code.end_block()
self.release_closure_privates(code)
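For a block that assigns a single local C variable, the emitted code looks roughly like this (a sketch; the actual cname Cython allocates may differ):

    #ifdef _OPENMP
    #pragma omp parallel private(__pyx_v_x)
    #endif
    {
        /* body of the with-block */
    }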
class ParallelRangeNode(ParallelStatNode):
"""
This node represents a 'for i in cython.parallel.prange():' construct.
target NameNode the target iteration variable
else_clause Node or None the else clause of this loop
args tuple the arguments passed to prange()
kwargs DictNode the keyword arguments passed to prange()
(replaced by its compile time value)
is_nogil bool indicates whether this is a nogil prange() node
"""
child_attrs = ['body', 'target', 'else_clause', 'args']
body = target = else_clause = args = None
start = stop = step = None
is_prange = True
def analyse_declarations(self, env):
super(ParallelRangeNode, self).analyse_declarations(env)
self.target.analyse_target_declaration(env)
if self.else_clause is not None:
self.else_clause.analyse_declarations(env)
if not self.args or len(self.args) > 3:
error(self.pos, "Invalid number of positional arguments to prange")
return
if len(self.args) == 1:
self.stop, = self.args
elif len(self.args) == 2:
self.start, self.stop = self.args
else:
self.start, self.stop, self.step = self.args
if self.kwargs:
self.kwargs = self.kwargs.compile_time_value(env)
else:
self.kwargs = {}
self.is_nogil = self.kwargs.pop('nogil', False)
self.schedule = self.kwargs.pop('schedule', None)
if self.schedule not in (None, 'static', 'dynamic', 'guided',
'runtime'):
error(self.pos, "Invalid schedule argument to prange: %r" %
(self.schedule,))
for kw in self.kwargs:
error(self.pos, "Invalid keyword argument to prange: %s" % kw)
def analyse_expressions(self, env):
self.target.analyse_target_types(env)
self.index_type = self.target.type
if self.index_type.is_pyobject:
# nogil_check will catch this
return
# Setup start, stop and step, allocating temps if needed
self.names = 'start', 'stop', 'step'
start_stop_step = self.start, self.stop, self.step
for node, name in zip(start_stop_step, self.names):
if node is not None:
node.analyse_types(env)
if not node.type.is_numeric:
error(node.pos, "%s argument must be numeric or a pointer "
"(perhaps if a numeric literal is too "
"big, use 1000LL)" % name)
if not node.is_literal:
node = node.coerce_to_temp(env)
setattr(self, name, node)
# As we range from 0 to nsteps, computing the index along the
# way, we need a fitting type for 'i' and 'nsteps'
self.index_type = PyrexTypes.widest_numeric_type(
self.index_type, node.type)
self.body.analyse_expressions(env)
if self.else_clause is not None:
self.else_clause.analyse_expressions(env)
def nogil_check(self, env):
names = 'start', 'stop', 'step', 'target'
nodes = self.start, self.stop, self.step, self.target
for name, node in zip(names, nodes):
if node is not None and node.type.is_pyobject:
error(node.pos, "%s may not be a Python object "
"as we don't have the GIL" % name)
def generate_execution_code(self, code):
"""
Generate code in the following steps
1) copy any closure variables determined to be thread-private
into temporaries
2) allocate temps for start, stop and step
3) generate a loop that calculates the total number of steps,
which then computes the target iteration variable for every step:
for i in prange(start, stop, step):
...
becomes
nsteps = (stop - start) / step;
i = start;
#pragma omp parallel for lastprivate(i)
for (temp = 0; temp < nsteps; temp++) {
i = start + step * temp;
...
}
Note that accumulation of 'i' would have a data dependency
between iterations.
Also, you can't do this
for (i = start; i < stop; i += step)
...
as the '<' operator should become '>' for descending loops.
'for i from x < i < y:' does not suffer from this problem
as the relational operator is known at compile time!
4) release our temps and write back any private closure variables
"""
# Ensure to unpack the target index variable if it's a closure temp
self.assignments[self.target.entry] = self.target.pos, None
self.declare_closure_privates(code) #self.insertion_point(code))
# This can only be a NameNode
target_index_cname = self.target.entry.cname
# This will be used as the dict to format our code strings, holding
# the start, stop, step, temps and target cnames
fmt_dict = {
'target': target_index_cname,
}
# Setup start, stop and step, allocating temps if needed
start_stop_step = self.start, self.stop, self.step
defaults = '0', '0', '1'
for node, name, default in zip(start_stop_step, self.names, defaults):
if node is None:
result = default
elif node.is_literal:
result = node.get_constant_c_result_code()
else:
node.generate_evaluation_code(code)
result = node.result()
fmt_dict[name] = result
fmt_dict['i'] = code.funcstate.allocate_temp(self.index_type, False)
fmt_dict['nsteps'] = code.funcstate.allocate_temp(self.index_type, False)
# TODO: check if the step is 0 and if so, raise an exception in a
# 'with gil' block. For now, just abort
code.putln("if (%(step)s == 0) abort();" % fmt_dict)
# Guard for never-ending loops: prange(0, 10, -1) or prange(10, 0, 1)
# range() returns [] in these cases
code.put("if ( (%(start)s < %(stop)s && %(step)s > 0) || "
"(%(start)s > %(stop)s && %(step)s < 0) ) " % fmt_dict)
code.begin_block()
# code.putln_openmp("#pragma omp single")
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
# code.putln_openmp("#pragma omp barrier")
self.generate_loop(code, fmt_dict)
# And finally, release our privates and write back any closure
# variables
for temp in start_stop_step:
if temp is not None:
temp.generate_disposal_code(code)
temp.free_temps(code)
code.funcstate.release_temp(fmt_dict['i'])
code.funcstate.release_temp(fmt_dict['nsteps'])
self.release_closure_privates(code)
# end the 'if' block that guards against infinite loops
code.end_block()
def generate_loop(self, code, fmt_dict):
target_index_cname = fmt_dict['target']
code.putln("#ifdef _OPENMP")
if not self.is_parallel:
code.put("#pragma omp for")
else:
code.put("#pragma omp parallel for")
for private, op in self.privates.iteritems():
# Don't declare the index variable as a reduction
if private != target_index_cname:
if op and op in "+*-&^|":
code.put(" reduction(%s:%s)" % (op, private))
else:
code.put(" lastprivate(%s)" % private)
if self.schedule:
code.put(" schedule(%s)" % self.schedule)
code.putln(" lastprivate(%s)" % target_index_cname)
code.putln("#endif")
code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict)
code.begin_block()
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.body.generate_execution_code(code)
code.end_block()
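Putting generate_execution_code and generate_loop together, a simple 'for i in prange(start, stop, step, nogil=True): ...' comes out roughly as the C below (a sketch; the __pyx_* names stand in for the temps that are actually allocated):

    if (__pyx_step == 0) abort();
    if ( (__pyx_start < __pyx_stop && __pyx_step > 0) ||
         (__pyx_start > __pyx_stop && __pyx_step < 0) ) {
        __pyx_nsteps = (__pyx_stop - __pyx_start) / __pyx_step;
    #ifdef _OPENMP
    #pragma omp parallel for lastprivate(__pyx_v_i)
    #endif
        for (__pyx_t = 0; __pyx_t < __pyx_nsteps; __pyx_t++) {
            __pyx_v_i = __pyx_start + __pyx_step * __pyx_t;
            /* loop body */
        }
    }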
#------------------------------------------------------------------------------------
#
...
...@@ -612,9 +612,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
}
special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL', 'parallel'])
special_methods.update(unop_method_nodes.keys())
valid_parallel_directives = cython.set([
"parallel",
"prange",
"threadid",
# "threadsavailable",
])
def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context)
self.compilation_directive_defaults = {}
...@@ -622,6 +629,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value)
self.cython_module_names = cython.set()
self.directive_names = {}
self.parallel_directives = {}
def check_directive_scope(self, pos, directive, scope):
legal_scopes = Options.directive_scopes.get(directive, None)
...@@ -644,6 +652,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directives.update(node.directive_comments)
self.directives = directives
node.directives = directives
node.parallel_directives = self.parallel_directives
self.visitchildren(node)
node.cython_module_names = self.cython_module_names
return node
...@@ -655,11 +664,31 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
name in self.special_methods or
PyrexTypes.parse_basic_type(name))
def is_parallel_directive(self, full_name, pos):
result = (full_name + ".").startswith("cython.parallel.")
if result:
directive = full_name.rsplit('.', 1)
if (len(directive) != 2 or directive[1] not in
self.valid_parallel_directives):
error(pos, "No such directive: %s" % full_name)
return result
def visit_CImportStatNode(self, node):
if node.module_name == u"cython":
self.cython_module_names.add(node.as_name or u"cython")
elif node.module_name.startswith(u"cython."):
if node.module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
if node.module_name == u"cython.parallel":
if node.as_name:
self.parallel_directives[node.as_name] = node.module_name
else:
self.cython_module_names.add(u"cython")
self.parallel_directives[
u"cython.parallel"] = node.module_name
elif node.as_name:
self.directive_names[node.as_name] = node.module_name[7:]
else:
self.cython_module_names.add(u"cython")
...@@ -673,19 +702,29 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
node.module_name.startswith(u"cython."):
submodule = (node.module_name + u".")[7:]
newimp = []
for pos, name, as_name, kind in node.imported_names:
full_name = submodule + name
qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
# from cython cimport parallel, or
# from cython.parallel cimport parallel, prange, ...
self.parallel_directives[as_name or name] = qualified_name
elif self.is_cython_directive(full_name):
if as_name is None:
as_name = full_name
self.directive_names[as_name] = full_name
if kind is not None:
self.context.nonfatal_error(PostParseError(pos,
"Compiler directive imports must be plain imports"))
else:
newimp.append((pos, name, as_name, kind))
if not newimp:
return None
node.imported_names = newimp
return node
...@@ -696,7 +735,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
newimp = []
for name, name_node in node.items:
full_name = submodule + name
qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
self.parallel_directives[name_node.name] = qualified_name
elif self.is_cython_directive(full_name):
self.directive_names[name_node.name] = full_name
else:
newimp.append((name, name_node))
...@@ -707,11 +749,25 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
def visit_SingleAssignmentNode(self, node):
if (isinstance(node.rhs, ExprNodes.ImportNode) and
node.rhs.module_name.value in (u'cython', u"cython.parallel")):
module_name = node.rhs.module_name.value
as_name = node.lhs.name
if module_name == u"cython.parallel" and as_name == u"cython":
# Be consistent with the cimport variant
as_name = u"cython.parallel"
node = Nodes.CImportStatNode(node.pos,
module_name = module_name,
as_name = as_name)
self.visit_CImportStatNode(node)
if node.module_name == u"cython.parallel":
# This is an import for a fake module, remove it
return None
if node.module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
else:
self.visitchildren(node)
return node
...@@ -897,6 +953,188 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return self.visit_with_directives(node.body, directive_dict)
return self.visit_Node(node)
class ParallelRangeTransform(CythonTransform, SkipDeclarations):
"""
Transform cython.parallel constructs. The parallel_directives come from the
module node, set there by InterpretCompilerDirectives.
x = cython.parallel.threadsavailable() -> ParallelThreadsAvailableNode
with cython.parallel(nogil=True): -> ParallelWithBlockNode
print cython.parallel.threadid() -> ParallelThreadIdNode
for i in cython.parallel.prange(...): -> ParallelRangeNode
...
"""
# a list of names, maps 'cython.parallel.prange' in the code to
# ['cython', 'parallel', 'prange']
parallel_directive = None
# Indicates whether a namenode in an expression is the cython module
namenode_is_cython_module = False
# Keep track of whether we are the context manager of a 'with' statement
in_context_manager_section = False
# Keep track of whether we are in a parallel range section
in_prange = False
directive_to_node = {
u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
# u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
u"cython.parallel.prange": Nodes.ParallelRangeNode,
}
def node_is_parallel_directive(self, node):
return node.name in self.parallel_directives or node.is_cython_module
def get_directive_class_node(self, node):
"""
Figure out which parallel directive was used and return the associated
Node class.
E.g. for a cython.parallel.prange() call we return ParallelRangeNode
Also disallow break, continue and return in a prange section
"""
if self.namenode_is_cython_module:
directive = '.'.join(self.parallel_directive)
else:
directive = self.parallel_directives[self.parallel_directive[0]]
directive = '%s.%s' % (directive,
'.'.join(self.parallel_directive[1:]))
directive = directive.rstrip('.')
cls = self.directive_to_node.get(directive)
if cls is None:
error(node.pos, "Invalid directive: %s" % directive)
self.namenode_is_cython_module = False
self.parallel_directive = None
return cls
def visit_ModuleNode(self, node):
"""
If any parallel directives were imported, copy them over and visit
the AST
"""
if node.parallel_directives:
self.parallel_directives = node.parallel_directives
self.assignment_stack = []
return self.visit_Node(node)
# No parallel directives were imported, so they can't be used :)
return node
def visit_NameNode(self, node):
if self.node_is_parallel_directive(node):
self.parallel_directive = [node.name]
self.namenode_is_cython_module = node.is_cython_module
return node
def visit_AttributeNode(self, node):
self.visitchildren(node)
if self.parallel_directive:
self.parallel_directive.append(node.attribute)
return node
def visit_CallNode(self, node):
self.visitchildren(node)
if not self.parallel_directive:
return node
# We are a parallel directive, replace this node with the
# corresponding ParallelSomethingSomething node
if isinstance(node, ExprNodes.GeneralCallNode):
args = node.positional_args.args
kwargs = node.keyword_args
else:
args = node.args
kwargs = {}
parallel_directive_class = self.get_directive_class_node(node)
if parallel_directive_class:
node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)
return node
def visit_WithStatNode(self, node):
"Rewrite with cython.parallel() blocks"
self.visit(node.manager)
if self.parallel_directive:
parallel_directive_class = self.get_directive_class_node(node)
if not parallel_directive_class:
# There was an error, stop here and now
return None
self.visit(node.body)
newnode = Nodes.ParallelWithBlockNode(node.pos, body=node.body)
else:
newnode = node
self.visit(node.body)
return newnode
def visit_ForInStatNode(self, node):
"Rewrite 'for i in cython.parallel.prange(...):'"
self.visit(node.iterator)
self.visit(node.target)
was_in_prange = self.in_prange
self.in_prange = isinstance(node.iterator.sequence,
Nodes.ParallelRangeNode)
if self.in_prange:
# This will replace the entire ForInStatNode, so copy the
# attributes
parallel_range_node = node.iterator.sequence
parallel_range_node.target = node.target
parallel_range_node.body = node.body
parallel_range_node.else_clause = node.else_clause
node = parallel_range_node
if not isinstance(node.target, ExprNodes.NameNode):
error(node.target.pos,
"Can only iterate over an iteration variable")
self.visit(node.body)
self.in_prange = was_in_prange
self.visit(node.else_clause)
return node
def ensure_not_in_prange(name):
"Creates error checking functions for break, continue and return"
def visit_method(self, node):
if self.in_prange:
error(node.pos,
name + " not allowed in a parallel range section")
# Do a visit for 'return'
self.visitchildren(node)
return node
return visit_method
visit_BreakStatNode = ensure_not_in_prange("break")
visit_ContinueStatNode = ensure_not_in_prange("continue")
visit_ReturnStatNode = ensure_not_in_prange("return")
def visit(self, node):
"Visit a node that may be None"
if node is not None:
super(ParallelRangeTransform, self).visit(node)
class WithTransform(CythonTransform, SkipDeclarations):
def visit_WithStatNode(self, node):
self.visitchildren(node, 'body')
...@@ -1715,22 +1953,54 @@ class GilCheck(VisitorTransform):
self.env_stack.append(node.local_scope)
was_nogil = self.nogil
self.nogil = node.local_scope.nogil
if self.nogil and node.nogil_check:
node.nogil_check(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
self.nogil = was_nogil
return node
def visit_GILStatNode(self, node):
if self.nogil and node.nogil_check:
node.nogil_check()
was_nogil = self.nogil
self.nogil = (node.state == 'nogil')
self.visitchildren(node)
self.nogil = was_nogil
return node
def visit_ParallelRangeNode(self, node):
if node.is_nogil:
node.is_nogil = False
node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
return self.visit_GILStatNode(node)
if not self.nogil:
error(node.pos, "prange() can only be used without the GIL")
# Forget about any GIL-related errors that may occur in the body
return None
node.nogil_check(self.env_stack[-1])
self.visitchildren(node)
return node
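In user code the check above means prange() must run without the GIL, either via nogil=True or inside an enclosing nogil block (a sketch; the function name is illustrative):

    from cython.parallel import prange

    def gil_rules(int n):
        cdef int i, s = 0
        for i in prange(n, nogil=True):   # fine: prange opens its own nogil section
            s += i
        with nogil:
            for i in prange(n):           # fine: the GIL is already released
                s += i
        # for i in prange(n): ...         # rejected: "prange() can only be used without the GIL"
        return s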
def visit_ParallelWithBlockNode(self, node):
if not self.nogil:
error(node.pos, "The parallel section may only be used without "
"the GIL")
return None
if node.nogil_check:
# It does not currently implement this, but test for it anyway to
# avoid potential future surprises
node.nogil_check(self.env_stack[-1])
self.visitchildren(node)
return node
def visit_Node(self, node):
if self.env_stack and self.nogil and node.nogil_check:
node.nogil_check(self.env_stack[-1])
...@@ -1857,8 +2127,7 @@ class TransformBuiltinMethods(EnvTransform):
class DebugTransform(CythonTransform):
"""
Write debug information for this Cython module.
"""
def __init__(self, context, options, result):
...
...@@ -723,6 +723,9 @@ class Scope(object):
else:
return outer.is_cpp()
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
utility_code = e.utility_code)
return scope
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PropertyScope(Scope):
# Scope holding the __get__, __set__ and __del__ methods for
...
...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
from Cython.Compiler import Main
class TestNormalizeTree(TransformTest):
...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
""", t)
class TestInterpretCompilerDirectives(TransformTest):
"""
This class tests the parallel directives AST-rewriting and importing.
"""
# Test the parallel directives (c)importing
import_code = u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""
expected_directives_dict = {
u'cython.parallel': u'cython.parallel',
u'par': u'cython.parallel',
u'par2': u'cython.parallel',
u'parallel': u'cython.parallel',
u"tid": u"cython.parallel.threadid",
u"tavail": u"cython.parallel.threadavailable",
u"prange": u"cython.parallel.prange",
}
def setUp(self):
super(TestInterpretCompilerDirectives, self).setUp()
compilation_options = Main.CompilationOptions(Main.default_options)
ctx = compilation_options.create_context()
self.pipeline = [
InterpretCompilerDirectives(ctx, ctx.compiler_directives),
]
self.debug_exception_on_error = DebugFlags.debug_exception_on_error
def tearDown(self):
DebugFlags.debug_exception_on_error = self.debug_exception_on_error
def test_parallel_directives_cimports(self):
self.run_pipeline(self.pipeline, self.import_code)
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
def test_parallel_directives_imports(self):
self.run_pipeline(self.pipeline,
self.import_code.replace(u'cimport', u'import'))
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust.
if sys.version_info[:2] >= (2, 5) and False:
from Cython.Debugger import DebugWriter
...
...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None:
# TODO: This shouldn't happen...
return
lhs.entry.assignments.append(rhs)
if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1]
parallel_node.assignments[lhs.entry] = (lhs.pos, inplace_op)
elif isinstance(lhs, ExprNodes.SequenceNode):
for arg in lhs.args:
self.mark_assignment(arg, object_expr)
...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
return node
def visit_InPlaceAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
self.visitchildren(node)
return node
...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node)
return node
def visit_ParallelStatNode(self, node):
if self.parallel_block_stack:
node.parent = self.parallel_block_stack[-1]
else:
node.parent = None
if node.is_prange:
if not node.parent:
node.is_parallel = True
else:
node.is_parallel = (node.parent.is_prange or not
node.parent.is_parallel)
else:
node.is_parallel = True
self.parallel_block_stack.append(node)
self.visitchildren(node)
self.parallel_block_stack.pop()
return node
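Combined with the is_parallel rules documented on ParallelStatNode in Nodes.py, the bookkeeping above plays out roughly like this (a Cython sketch; the comments name the pragma each construct ends up with):

    cimport cython.parallel
    from cython.parallel import prange

    def sharing(int n):
        cdef int i, j, s = 0
        with nogil, cython.parallel.parallel:   # -> #pragma omp parallel
            for i in prange(n):                 # parent is the parallel block -> #pragma omp for
                s += i
        for i in prange(n, nogil=True):         # no parent -> #pragma omp parallel for
            for j in prange(n):                 # nested prange is itself treated as parallel
                s += j
        return s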
class MarkOverflowingArithmetic(CythonTransform):
# It may be possible to integrate this with the above for
...
# tag: run
# distutils: libraries = gomp
# distutils: extra_compile_args = -fopenmp
cimport cython.parallel
from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free
cdef extern from "Python.h":
void PyEval_InitThreads()
PyEval_InitThreads()
cdef void print_int(int x) with gil:
print x
#@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
# "//GILStatNode[@state = 'nogil]//ParallelRangeNode")
def test_prange():
"""
>>> test_prange()
(9, 9, 45, 45)
"""
cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='dynamic'):
sum1 += i
for j in prange(10, nogil=True):
sum2 += j
return i, j, sum1, sum2
def test_descending_prange():
"""
>>> test_descending_prange()
5
"""
cdef int i, start = 5, stop = -5, step = -2
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
sum += i
return sum
def test_nested_prange():
"""
Reduction propagation is not (yet) supported.
>>> test_nested_prange()
50
"""
cdef int i, j
cdef int sum = 0
for i in prange(5, nogil=True):
for j in prange(5):
sum += i
# The value of sum is undefined here
sum = 0
for i in prange(5, nogil=True):
for j in prange(5):
sum += i
sum += 0
return sum
# threadsavailable test, disabled for now as it won't compile
#def test_parallel():
# """
# >>> test_parallel()
# """
# cdef int *buf = <int *> malloc(sizeof(int) * threadsavailable())
#
# if buf == NULL:
# raise MemoryError
#
# with nogil, cython.parallel.parallel:
# buf[threadid()] = threadid()
#
# for i in range(threadsavailable()):
# assert buf[i] == i
#
# free(buf)
def test_unsigned_operands():
"""
This test is disabled, as this currently does not work (neither does it
for 'for i from x < i < y:'). I'm not sure we should strive to support
this; at least the C compiler gives a warning.
test_unsigned_operands()
10
"""
cdef int i
cdef int start = -5
cdef unsigned int stop = 5
cdef int step = 1
cdef int steps_taken = 0
for i in prange(start, stop, step, nogil=True):
steps_taken += 1
return steps_taken
def test_reassign_start_stop_step():
"""
>>> test_reassign_start_stop_step()
20
"""
cdef int start = 0, stop = 10, step = 2
cdef int i
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
start = -2
stop = 2
step = 0
sum += i
return sum
def test_closure_parallel_privates():
"""
>>> test_closure_parallel_privates()
9 9
45 45
0 0 9 9
"""
cdef int x
def test_target():
nonlocal x
for x in prange(10, nogil=True):
pass
return x
print test_target(), x
def test_reduction():
nonlocal x
cdef int i
x = 0
for i in prange(10, nogil=True):
x += i
return x
print test_reduction(), x
def test_generator():
nonlocal x
cdef int i
x = 0
yield x
x = 2
for i in prange(10, nogil=True):
x = i
yield x
g = test_generator()
print g.next(), x, g.next(), x