Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Boxiang Sun
cython
Commits
18d8b3fb
Commit
18d8b3fb
authored
Apr 15, 2011
by
Mark Florisson
Committed by
Vitja Makarov
May 05, 2011
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Preliminary OpenMP support
parent
8f47d370
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
983 additions
and
22 deletions
+983
-22
Cython/Compiler/Code.py
Cython/Compiler/Code.py
+4
-0
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+60
-1
Cython/Compiler/Main.py
Cython/Compiler/Main.py
+2
-1
Cython/Compiler/Nodes.py
Cython/Compiler/Nodes.py
+367
-5
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/ParseTreeTransforms.py
+280
-11
Cython/Compiler/Symtab.py
Cython/Compiler/Symtab.py
+3
-2
Cython/Compiler/Tests/TestParseTreeTransforms.py
Cython/Compiler/Tests/TestParseTreeTransforms.py
+57
-0
Cython/Compiler/TypeInference.py
Cython/Compiler/TypeInference.py
+35
-2
tests/run/parallel.pyx
tests/run/parallel.pyx
+175
-0
No files found.
Cython/Compiler/Code.py
View file @
18d8b3fb
...
@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
...
@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
def
put_trace_return
(
self
,
retvalue_cname
):
def
put_trace_return
(
self
,
retvalue_cname
):
self
.
putln
(
"__Pyx_TraceReturn(%s);"
%
retvalue_cname
)
self
.
putln
(
"__Pyx_TraceReturn(%s);"
%
retvalue_cname
)
def
putln_openmp
(
self
,
string
):
self
.
putln
(
"#ifdef _OPENMP"
)
self
.
putln
(
string
)
self
.
putln
(
"#endif"
)
class
PyrexCodeWriter
(
object
):
class
PyrexCodeWriter
(
object
):
# f file output file
# f file output file
...
...
Cython/Compiler/ExprNodes.py
View file @
18d8b3fb
...
@@ -2059,6 +2059,62 @@ class RawCNameExprNode(ExprNode):
...
@@ -2059,6 +2059,62 @@ class RawCNameExprNode(ExprNode):
pass
pass
#-------------------------------------------------------------------
#
# Parallel nodes (cython.parallel.thread(savailable|id))
#
#-------------------------------------------------------------------
class
ParallelThreadsAvailableNode
(
AtomicExprNode
):
"""
Implements cython.parallel.threadsavailable(). If we are called from the
sequential part of the application, we need to call omp_get_max_threads(),
and in the parallel part we can just call omp_get_num_threads()
"""
type
=
PyrexTypes
.
c_int_type
def
analyse_types
(
self
,
env
):
self
.
is_temp
=
True
env
.
add_include_file
(
"omp.h"
)
return
self
.
type
def
generate_result_code
(
self
,
code
):
code
.
putln
(
"#ifdef _OPENMP"
)
code
.
putln
(
"if (omp_in_parallel()) %s = omp_get_max_threads();"
%
self
.
temp_code
)
code
.
putln
(
"else %s = omp_get_num_threads();"
%
self
.
temp_code
)
code
.
putln
(
"#else"
)
code
.
putln
(
"%s = 1;"
%
self
.
temp_code
)
code
.
putln
(
"#endif"
)
def
result
(
self
):
return
self
.
temp_code
class
ParallelThreadIdNode
(
AtomicExprNode
):
#, Nodes.ParallelNode):
"""
Implements cython.parallel.threadid()
"""
type
=
PyrexTypes
.
c_int_type
def
analyse_types
(
self
,
env
):
self
.
is_temp
=
True
env
.
add_include_file
(
"omp.h"
)
return
self
.
type
def
generate_result_code
(
self
,
code
):
code
.
putln
(
"#ifdef _OPENMP"
)
code
.
putln
(
"%s = omp_get_thread_num();"
%
self
.
temp_code
)
code
.
putln
(
"#else"
)
code
.
putln
(
"%s = 0;"
%
self
.
temp_code
)
code
.
putln
(
"#endif"
)
def
result
(
self
):
return
self
.
temp_code
#-------------------------------------------------------------------
#-------------------------------------------------------------------
#
#
# Trailer nodes
# Trailer nodes
...
@@ -3465,8 +3521,11 @@ class AttributeNode(ExprNode):
...
@@ -3465,8 +3521,11 @@ class AttributeNode(ExprNode):
needs_none_check
=
True
needs_none_check
=
True
def
as_cython_attribute
(
self
):
def
as_cython_attribute
(
self
):
if
isinstance
(
self
.
obj
,
NameNode
)
and
self
.
obj
.
is_cython_module
:
if
(
isinstance
(
self
.
obj
,
NameNode
)
and
self
.
obj
.
is_cython_module
and
not
self
.
attribute
==
u"parallel"
):
return
self
.
attribute
return
self
.
attribute
cy
=
self
.
obj
.
as_cython_attribute
()
cy
=
self
.
obj
.
as_cython_attribute
()
if
cy
:
if
cy
:
return
"%s.%s"
%
(
cy
,
self
.
attribute
)
return
"%s.%s"
%
(
cy
,
self
.
attribute
)
...
...
Cython/Compiler/Main.py
View file @
18d8b3fb
...
@@ -106,7 +106,7 @@ class Context(object):
...
@@ -106,7 +106,7 @@ class Context(object):
from
ParseTreeTransforms
import
AnalyseDeclarationsTransform
,
AnalyseExpressionsTransform
from
ParseTreeTransforms
import
AnalyseDeclarationsTransform
,
AnalyseExpressionsTransform
from
ParseTreeTransforms
import
CreateClosureClasses
,
MarkClosureVisitor
,
DecoratorTransform
from
ParseTreeTransforms
import
CreateClosureClasses
,
MarkClosureVisitor
,
DecoratorTransform
from
ParseTreeTransforms
import
InterpretCompilerDirectives
,
TransformBuiltinMethods
from
ParseTreeTransforms
import
InterpretCompilerDirectives
,
TransformBuiltinMethods
from
ParseTreeTransforms
import
ExpandInplaceOperators
from
ParseTreeTransforms
import
ExpandInplaceOperators
,
ParallelRangeTransform
from
TypeInference
import
MarkAssignments
,
MarkOverflowingArithmetic
from
TypeInference
import
MarkAssignments
,
MarkOverflowingArithmetic
from
ParseTreeTransforms
import
AlignFunctionDefinitions
,
GilCheck
from
ParseTreeTransforms
import
AlignFunctionDefinitions
,
GilCheck
from
ParseTreeTransforms
import
RemoveUnreachableCode
from
ParseTreeTransforms
import
RemoveUnreachableCode
...
@@ -136,6 +136,7 @@ class Context(object):
...
@@ -136,6 +136,7 @@ class Context(object):
PostParse
(
self
),
PostParse
(
self
),
_specific_post_parse
,
_specific_post_parse
,
InterpretCompilerDirectives
(
self
,
self
.
compiler_directives
),
InterpretCompilerDirectives
(
self
,
self
.
compiler_directives
),
ParallelRangeTransform
(
self
),
MarkClosureVisitor
(
self
),
MarkClosureVisitor
(
self
),
_align_function_definitions
,
_align_function_definitions
,
RemoveUnreachableCode
(
self
),
RemoveUnreachableCode
(
self
),
...
...
Cython/Compiler/Nodes.py
View file @
18d8b3fb
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#
#
# Pyrex - Parse tree nodes
# Pyrex - Parse tree nodes
#
#
import
cython
import
cython
from
cython
import
set
from
cython
import
set
cython
.
declare
(
sys
=
object
,
os
=
object
,
time
=
object
,
copy
=
object
,
cython
.
declare
(
sys
=
object
,
os
=
object
,
time
=
object
,
copy
=
object
,
...
@@ -4597,14 +4596,12 @@ class ForInStatNode(LoopNode, StatNode):
...
@@ -4597,14 +4596,12 @@ class ForInStatNode(LoopNode, StatNode):
old_loop_labels
=
code
.
new_loop_labels
()
old_loop_labels
=
code
.
new_loop_labels
()
self
.
iterator
.
allocate_counter_temp
(
code
)
self
.
iterator
.
allocate_counter_temp
(
code
)
self
.
iterator
.
generate_evaluation_code
(
code
)
self
.
iterator
.
generate_evaluation_code
(
code
)
code
.
putln
(
code
.
putln
(
"for (;;) {"
)
"for (;;) {"
)
self
.
item
.
generate_evaluation_code
(
code
)
self
.
item
.
generate_evaluation_code
(
code
)
self
.
target
.
generate_assignment_code
(
self
.
item
,
code
)
self
.
target
.
generate_assignment_code
(
self
.
item
,
code
)
self
.
body
.
generate_execution_code
(
code
)
self
.
body
.
generate_execution_code
(
code
)
code
.
put_label
(
code
.
continue_label
)
code
.
put_label
(
code
.
continue_label
)
code
.
putln
(
code
.
putln
(
"}"
)
"}"
)
break_label
=
code
.
break_label
break_label
=
code
.
break_label
code
.
set_loop_labels
(
old_loop_labels
)
code
.
set_loop_labels
(
old_loop_labels
)
...
@@ -5735,6 +5732,371 @@ class FromImportStatNode(StatNode):
...
@@ -5735,6 +5732,371 @@ class FromImportStatNode(StatNode):
self
.
module
.
free_temps
(
code
)
self
.
module
.
free_temps
(
code
)
class
ParallelNode
(
Node
):
"""
Base class for cython.parallel constructs.
"""
nogil_check
=
None
class
ParallelStatNode
(
StatNode
,
ParallelNode
):
"""
Base class for 'with cython.parallel.parallel:' and 'for i in prange():'.
assignments { Entry(var) : (var.pos, inplace_operator_or_None) }
assignments to variables in this parallel section
parent parent ParallelStatNode or None
is_parallel indicates whether this is a parallel node
is_parallel is true for:
#pragma omp parallel
#pragma omp parallel for
sections, but NOT for
#pragma omp for
We need this to determine the sharing attributes.
"""
child_attrs
=
[
'body'
]
body
=
None
is_prange
=
False
def
__init__
(
self
,
pos
,
**
kwargs
):
super
(
ParallelStatNode
,
self
).
__init__
(
pos
,
**
kwargs
)
self
.
assignments
=
kwargs
.
get
(
'assignments'
)
or
{}
# Insertion point before the outermost parallel section
self
.
before_parallel_section_point
=
None
# Insertion point after the outermost parallel section
self
.
post_parallel_section_point
=
None
def
analyse_expressions
(
self
,
env
):
self
.
body
.
analyse_expressions
(
env
)
def
analyse_declarations
(
self
,
env
):
super
(
ParallelStatNode
,
self
).
analyse_declarations
(
env
)
self
.
body
.
analyse_declarations
(
env
)
def
lookup_assignment
(
self
,
entry
):
"""
Return an assignment's pos and operator. If the parent has the
assignment, return the parent's assignment, otherwise our own.
"""
parent_assignment
=
self
.
parent
and
self
.
parent
.
lookup_assignment
(
entry
)
return
parent_assignment
or
self
.
assignments
.
get
(
entry
)
def
is_private
(
self
,
entry
):
"""
True if this scope should declare the variable private, lastprivate
or reduction.
"""
parent_or_our_entry
=
self
.
lookup_assignment
(
entry
)
our_entry
=
self
.
assignments
.
get
(
entry
)
return
self
.
is_parallel
or
parent_or_our_entry
==
our_entry
def
_allocate_closure_temp
(
self
,
code
,
entry
):
"""
Helper function that allocate a temporary for a closure variable that
is assigned to.
"""
if
self
.
parent
:
return
self
.
parent
.
_allocate_closure_temp
(
code
,
entry
)
cname
=
code
.
funcstate
.
allocate_temp
(
entry
.
type
,
False
)
self
.
modified_entries
.
append
((
entry
,
entry
.
cname
))
code
.
putln
(
"%s = %s;"
%
(
cname
,
entry
.
cname
))
entry
.
cname
=
cname
return
cname
def
declare_closure_privates
(
self
,
code
):
"""
Set self.privates to a dict mapping C variable names that are to be
declared (first|last)private or reduction, to the reduction operator.
If the private is not a reduction, the operator is None.
This is used by subclasses.
If a variable is in a scope object, we need to allocate a temp and
assign the value from the temp to the variable in the scope object
after the parallel section. This kind of copying should be done only
in the outermost parallel section.
"""
self
.
privates
=
{}
self
.
modified_entries
=
[]
for
entry
,
(
pos
,
op
)
in
self
.
assignments
.
iteritems
():
cname
=
entry
.
cname
if
entry
.
from_closure
or
entry
.
in_closure
:
cname
=
self
.
_allocate_closure_temp
(
code
,
entry
)
if
self
.
is_private
(
entry
):
self
.
privates
[
cname
]
=
op
def
release_closure_privates
(
self
,
code
):
"Release any temps used for variables in scope objects"
for
entry
,
original_cname
in
self
.
modified_entries
:
code
.
putln
(
"%s = %s;"
%
(
original_cname
,
entry
.
cname
))
code
.
funcstate
.
release_temp
(
entry
.
cname
)
entry
.
cname
=
original_cname
class
ParallelWithBlockNode
(
ParallelStatNode
):
"""
This node represents a 'with cython.parallel:' block
"""
nogil_check
=
None
def
generate_execution_code
(
self
,
code
):
self
.
declare_closure_privates
(
code
)
code
.
putln
(
"#ifdef _OPENMP"
)
code
.
put
(
"#pragma omp parallel "
)
code
.
putln
(
' '
.
join
([
"private(%s)"
%
e
.
cname
for
e
in
self
.
assignments
if
self
.
is_private
(
e
)]))
code
.
putln
(
"#endif"
)
code
.
begin_block
()
self
.
body
.
generate_execution_code
(
code
)
code
.
end_block
()
self
.
release_closure_privates
(
code
)
class
ParallelRangeNode
(
ParallelStatNode
):
"""
This node represents a 'for i in cython.parallel.prange():' construct.
target NameNode the target iteration variable
else_clause Node or None the else clause of this loop
args tuple the arguments passed to prange()
kwargs DictNode the keyword arguments passed to prange()
(replaced by its compile time value)
is_nogil bool indicates whether this is a nogil prange() node
"""
child_attrs
=
[
'body'
,
'target'
,
'else_clause'
,
'args'
]
body
=
target
=
else_clause
=
args
=
None
start
=
stop
=
step
=
None
is_prange
=
True
def
analyse_declarations
(
self
,
env
):
super
(
ParallelRangeNode
,
self
).
analyse_declarations
(
env
)
self
.
target
.
analyse_target_declaration
(
env
)
if
self
.
else_clause
is
not
None
:
self
.
else_clause
.
analyse_declarations
(
env
)
if
not
self
.
args
or
len
(
self
.
args
)
>
3
:
error
(
self
.
pos
,
"Invalid number of positional arguments to prange"
)
return
if
len
(
self
.
args
)
==
1
:
self
.
stop
,
=
self
.
args
elif
len
(
self
.
args
)
==
2
:
self
.
start
,
self
.
stop
=
self
.
args
else
:
self
.
start
,
self
.
stop
,
self
.
step
=
self
.
args
if
self
.
kwargs
:
self
.
kwargs
=
self
.
kwargs
.
compile_time_value
(
env
)
else
:
self
.
kwargs
=
{}
self
.
is_nogil
=
self
.
kwargs
.
pop
(
'nogil'
,
False
)
self
.
schedule
=
self
.
kwargs
.
pop
(
'schedule'
,
None
)
if
self
.
schedule
not
in
(
None
,
'static'
,
'dynamic'
,
'guided'
,
'runtime'
):
error
(
self
.
pos
,
"Invalid schedule argument to prange: %r"
%
(
self
.
schedule
,))
for
kw
in
self
.
kwargs
:
error
(
self
.
pos
,
"Invalid keyword argument to prange: %s"
%
kw
)
def
analyse_expressions
(
self
,
env
):
self
.
target
.
analyse_target_types
(
env
)
self
.
index_type
=
self
.
target
.
type
if
self
.
index_type
.
is_pyobject
:
# nogil_check will catch this
return
# Setup start, stop and step, allocating temps if needed
self
.
names
=
'start'
,
'stop'
,
'step'
start_stop_step
=
self
.
start
,
self
.
stop
,
self
.
step
for
node
,
name
in
zip
(
start_stop_step
,
self
.
names
):
if
node
is
not
None
:
node
.
analyse_types
(
env
)
if
not
node
.
type
.
is_numeric
:
error
(
node
.
pos
,
"%s argument must be numeric or a pointer "
"(perhaps if a numeric literal is too "
"big, use 1000LL)"
%
name
)
if
not
node
.
is_literal
:
node
=
node
.
coerce_to_temp
(
env
)
setattr
(
self
,
name
,
node
)
# As we range from 0 to nsteps, computing the index along the
# way, we need a fitting type for 'i' and 'nsteps'
self
.
index_type
=
PyrexTypes
.
widest_numeric_type
(
self
.
index_type
,
node
.
type
)
self
.
body
.
analyse_expressions
(
env
)
if
self
.
else_clause
is
not
None
:
self
.
else_clause
.
analyse_expressions
(
env
)
def
nogil_check
(
self
,
env
):
names
=
'start'
,
'stop'
,
'step'
,
'target'
nodes
=
self
.
start
,
self
.
stop
,
self
.
step
,
self
.
target
for
name
,
node
in
zip
(
names
,
nodes
):
if
node
is
not
None
and
node
.
type
.
is_pyobject
:
error
(
node
.
pos
,
"%s may not be a Python object "
"as we don't have the GIL"
%
name
)
def
generate_execution_code
(
self
,
code
):
"""
Generate code in the following steps
1) copy any closure variables determined thread-private
into temporaries
2) allocate temps for start, stop and step
3) generate a loop that calculates the total number of steps,
which then computes the target iteration variable for every step:
for i in prange(start, stop, step):
...
becomes
nsteps = (stop - start) / step;
i = start;
#pragma omp parallel for lastprivate(i)
for (temp = 0; temp < nsteps; temp++) {
i = start + step * temp;
...
}
Note that accumulation of 'i' would have a data dependency
between iterations.
Also, you can't do this
for (i = start; i < stop; i += step)
...
as the '<' operator should become '>' for descending loops.
'for i from x < i < y:' does not suffer from this problem
as the relational operator is known at compile time!
4) release our temps and write back any private closure variables
"""
# Ensure to unpack the target index variable if it's a closure temp
self
.
assignments
[
self
.
target
.
entry
]
=
self
.
target
.
pos
,
None
self
.
declare_closure_privates
(
code
)
#self.insertion_point(code))
# This can only be a NameNode
target_index_cname
=
self
.
target
.
entry
.
cname
# This will be used as the dict to format our code strings, holding
# the start, stop , step, temps and target cnames
fmt_dict
=
{
'target'
:
target_index_cname
,
}
# Setup start, stop and step, allocating temps if needed
start_stop_step
=
self
.
start
,
self
.
stop
,
self
.
step
defaults
=
'0'
,
'0'
,
'1'
for
node
,
name
,
default
in
zip
(
start_stop_step
,
self
.
names
,
defaults
):
if
node
is
None
:
result
=
default
elif
node
.
is_literal
:
result
=
node
.
get_constant_c_result_code
()
else
:
node
.
generate_evaluation_code
(
code
)
result
=
node
.
result
()
fmt_dict
[
name
]
=
result
fmt_dict
[
'i'
]
=
code
.
funcstate
.
allocate_temp
(
self
.
index_type
,
False
)
fmt_dict
[
'nsteps'
]
=
code
.
funcstate
.
allocate_temp
(
self
.
index_type
,
False
)
# TODO: check if the step is 0 and if so, raise an exception in a
# 'with gil' block. For now, just abort
code
.
putln
(
"if (%(step)s == 0) abort();"
%
fmt_dict
)
# Guard for never-ending loops: prange(0, 10, -1) or prange(10, 0, 1)
# range() returns [] in these cases
code
.
put
(
"if ( (%(start)s < %(stop)s && %(step)s > 0) || "
"(%(start)s > %(stop)s && %(step)s < 0) ) "
%
fmt_dict
)
code
.
begin_block
()
# code.putln_openmp("#pragma omp single")
code
.
putln
(
"%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;"
%
fmt_dict
)
# code.putln_openmp("#pragma omp barrier")
self
.
generate_loop
(
code
,
fmt_dict
)
# And finally, release our privates and write back any closure
# variables
for
temp
in
start_stop_step
:
if
temp
is
not
None
:
temp
.
generate_disposal_code
(
code
)
temp
.
free_temps
(
code
)
code
.
funcstate
.
release_temp
(
fmt_dict
[
'i'
])
code
.
funcstate
.
release_temp
(
fmt_dict
[
'nsteps'
])
self
.
release_closure_privates
(
code
)
# end the 'if' block that guards against infinite loops
code
.
end_block
()
def
generate_loop
(
self
,
code
,
fmt_dict
):
target_index_cname
=
fmt_dict
[
'target'
]
code
.
putln
(
"#ifdef _OPENMP"
)
if
not
self
.
is_parallel
:
code
.
put
(
"#pragma omp for"
)
else
:
code
.
put
(
"#pragma omp parallel for"
)
for
private
,
op
in
self
.
privates
.
iteritems
():
# Don't declare the index variable as a reduction
if
private
!=
target_index_cname
:
if
op
and
op
in
"+*-&^|"
:
code
.
put
(
" reduction(%s:%s)"
%
(
op
,
private
))
else
:
code
.
put
(
" lastprivate(%s)"
%
private
)
if
self
.
schedule
:
code
.
put
(
" schedule(%s)"
%
self
.
schedule
)
code
.
putln
(
" lastprivate(%s)"
%
target_index_cname
)
code
.
putln
(
"#endif"
)
code
.
put
(
"for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)"
%
fmt_dict
)
code
.
begin_block
()
code
.
putln
(
"%(target)s = %(start)s + %(step)s * %(i)s;"
%
fmt_dict
)
self
.
body
.
generate_execution_code
(
code
)
code
.
end_block
()
#------------------------------------------------------------------------------------
#------------------------------------------------------------------------------------
#
#
...
...
Cython/Compiler/ParseTreeTransforms.py
View file @
18d8b3fb
...
@@ -612,9 +612,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -612,9 +612,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
}
}
special_methods
=
cython
.
set
([
'declare'
,
'union'
,
'struct'
,
'typedef'
,
'sizeof'
,
special_methods
=
cython
.
set
([
'declare'
,
'union'
,
'struct'
,
'typedef'
,
'sizeof'
,
'cast'
,
'pointer'
,
'compiled'
,
'NULL'
])
'cast'
,
'pointer'
,
'compiled'
,
'NULL'
,
'parallel'
])
special_methods
.
update
(
unop_method_nodes
.
keys
())
special_methods
.
update
(
unop_method_nodes
.
keys
())
valid_parallel_directives
=
cython
.
set
([
"parallel"
,
"prange"
,
"threadid"
,
# "threadsavailable",
])
def
__init__
(
self
,
context
,
compilation_directive_defaults
):
def
__init__
(
self
,
context
,
compilation_directive_defaults
):
super
(
InterpretCompilerDirectives
,
self
).
__init__
(
context
)
super
(
InterpretCompilerDirectives
,
self
).
__init__
(
context
)
self
.
compilation_directive_defaults
=
{}
self
.
compilation_directive_defaults
=
{}
...
@@ -622,6 +629,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -622,6 +629,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
self
.
compilation_directive_defaults
[
unicode
(
key
)]
=
copy
.
deepcopy
(
value
)
self
.
compilation_directive_defaults
[
unicode
(
key
)]
=
copy
.
deepcopy
(
value
)
self
.
cython_module_names
=
cython
.
set
()
self
.
cython_module_names
=
cython
.
set
()
self
.
directive_names
=
{}
self
.
directive_names
=
{}
self
.
parallel_directives
=
{}
def
check_directive_scope
(
self
,
pos
,
directive
,
scope
):
def
check_directive_scope
(
self
,
pos
,
directive
,
scope
):
legal_scopes
=
Options
.
directive_scopes
.
get
(
directive
,
None
)
legal_scopes
=
Options
.
directive_scopes
.
get
(
directive
,
None
)
...
@@ -644,6 +652,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -644,6 +652,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directives
.
update
(
node
.
directive_comments
)
directives
.
update
(
node
.
directive_comments
)
self
.
directives
=
directives
self
.
directives
=
directives
node
.
directives
=
directives
node
.
directives
=
directives
node
.
parallel_directives
=
self
.
parallel_directives
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
node
.
cython_module_names
=
self
.
cython_module_names
node
.
cython_module_names
=
self
.
cython_module_names
return
node
return
node
...
@@ -655,11 +664,31 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -655,11 +664,31 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
name
in
self
.
special_methods
or
name
in
self
.
special_methods
or
PyrexTypes
.
parse_basic_type
(
name
))
PyrexTypes
.
parse_basic_type
(
name
))
def
is_parallel_directive
(
self
,
full_name
,
pos
):
result
=
(
full_name
+
"."
).
startswith
(
"cython.parallel."
)
if
result
:
directive
=
full_name
.
rsplit
(
'.'
,
1
)
if
(
len
(
directive
)
!=
2
or
directive
[
1
]
not
in
self
.
valid_parallel_directives
):
error
(
pos
,
"No such directive: %s"
%
full_name
)
return
result
def
visit_CImportStatNode
(
self
,
node
):
def
visit_CImportStatNode
(
self
,
node
):
if
node
.
module_name
==
u"cython"
:
if
node
.
module_name
==
u"cython"
:
self
.
cython_module_names
.
add
(
node
.
as_name
or
u"cython"
)
self
.
cython_module_names
.
add
(
node
.
as_name
or
u"cython"
)
elif
node
.
module_name
.
startswith
(
u"cython."
):
elif
node
.
module_name
.
startswith
(
u"cython."
):
if
node
.
module_name
.
startswith
(
u"cython.parallel."
):
error
(
node
.
pos
,
node
.
module_name
+
" is not a module"
)
if
node
.
module_name
==
u"cython.parallel"
:
if
node
.
as_name
:
if
node
.
as_name
:
self
.
parallel_directives
[
node
.
as_name
]
=
node
.
module_name
else
:
self
.
cython_module_names
.
add
(
u"cython"
)
self
.
parallel_directives
[
u"cython.parallel"
]
=
node
.
module_name
elif
node
.
as_name
:
self
.
directive_names
[
node
.
as_name
]
=
node
.
module_name
[
7
:]
self
.
directive_names
[
node
.
as_name
]
=
node
.
module_name
[
7
:]
else
:
else
:
self
.
cython_module_names
.
add
(
u"cython"
)
self
.
cython_module_names
.
add
(
u"cython"
)
...
@@ -673,19 +702,29 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -673,19 +702,29 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
node
.
module_name
.
startswith
(
u"cython."
):
node
.
module_name
.
startswith
(
u"cython."
):
submodule
=
(
node
.
module_name
+
u"."
)[
7
:]
submodule
=
(
node
.
module_name
+
u"."
)[
7
:]
newimp
=
[]
newimp
=
[]
for
pos
,
name
,
as_name
,
kind
in
node
.
imported_names
:
for
pos
,
name
,
as_name
,
kind
in
node
.
imported_names
:
full_name
=
submodule
+
name
full_name
=
submodule
+
name
if
self
.
is_cython_directive
(
full_name
):
qualified_name
=
u"cython."
+
full_name
if
self
.
is_parallel_directive
(
qualified_name
,
node
.
pos
):
# from cython cimport parallel, or
# from cython.parallel cimport parallel, prange, ...
self
.
parallel_directives
[
as_name
or
name
]
=
qualified_name
elif
self
.
is_cython_directive
(
full_name
):
if
as_name
is
None
:
if
as_name
is
None
:
as_name
=
full_name
as_name
=
full_name
self
.
directive_names
[
as_name
]
=
full_name
self
.
directive_names
[
as_name
]
=
full_name
if
kind
is
not
None
:
if
kind
is
not
None
:
self
.
context
.
nonfatal_error
(
PostParseError
(
pos
,
self
.
context
.
nonfatal_error
(
PostParseError
(
pos
,
"Compiler directive imports must be plain imports"
))
"Compiler directive imports must be plain imports"
))
else
:
else
:
newimp
.
append
((
pos
,
name
,
as_name
,
kind
))
newimp
.
append
((
pos
,
name
,
as_name
,
kind
))
if
not
newimp
:
if
not
newimp
:
return
None
return
None
node
.
imported_names
=
newimp
node
.
imported_names
=
newimp
return
node
return
node
...
@@ -696,7 +735,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -696,7 +735,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
newimp
=
[]
newimp
=
[]
for
name
,
name_node
in
node
.
items
:
for
name
,
name_node
in
node
.
items
:
full_name
=
submodule
+
name
full_name
=
submodule
+
name
if
self
.
is_cython_directive
(
full_name
):
qualified_name
=
u"cython."
+
full_name
if
self
.
is_parallel_directive
(
qualified_name
,
node
.
pos
):
self
.
parallel_directives
[
name_node
.
name
]
=
qualified_name
elif
self
.
is_cython_directive
(
full_name
):
self
.
directive_names
[
name_node
.
name
]
=
full_name
self
.
directive_names
[
name_node
.
name
]
=
full_name
else
:
else
:
newimp
.
append
((
name
,
name_node
))
newimp
.
append
((
name
,
name_node
))
...
@@ -707,11 +749,25 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -707,11 +749,25 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
def
visit_SingleAssignmentNode
(
self
,
node
):
def
visit_SingleAssignmentNode
(
self
,
node
):
if
(
isinstance
(
node
.
rhs
,
ExprNodes
.
ImportNode
)
and
if
(
isinstance
(
node
.
rhs
,
ExprNodes
.
ImportNode
)
and
node
.
rhs
.
module_name
.
value
==
u'cython'
):
node
.
rhs
.
module_name
.
value
in
(
u'cython'
,
u"cython.parallel"
)):
module_name
=
node
.
rhs
.
module_name
.
value
as_name
=
node
.
lhs
.
name
if
module_name
==
u"cython.parallel"
and
as_name
==
u"cython"
:
# Be consistent with the cimport variant
as_name
=
u"cython.parallel"
node
=
Nodes
.
CImportStatNode
(
node
.
pos
,
node
=
Nodes
.
CImportStatNode
(
node
.
pos
,
module_name
=
u'cython'
,
module_name
=
module_name
,
as_name
=
node
.
lhs
.
name
)
as_name
=
as_name
)
self
.
visit_CImportStatNode
(
node
)
self
.
visit_CImportStatNode
(
node
)
if
node
.
module_name
==
u"cython.parallel"
:
# This is an import for a fake module, remove it
return
None
if
node
.
module_name
.
startswith
(
u"cython.parallel."
):
error
(
node
.
pos
,
node
.
module_name
+
" is not a module"
)
else
:
else
:
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
return
node
return
node
...
@@ -897,6 +953,188 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
...
@@ -897,6 +953,188 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return
self
.
visit_with_directives
(
node
.
body
,
directive_dict
)
return
self
.
visit_with_directives
(
node
.
body
,
directive_dict
)
return
self
.
visit_Node
(
node
)
return
self
.
visit_Node
(
node
)
class
ParallelRangeTransform
(
CythonTransform
,
SkipDeclarations
):
"""
Transform cython.parallel stuff. The parallel_directives come from the
module node, set there by InterpretCompilerDirectives.
x = cython.parallel.threadavailable() -> ParallelThreadAvailableNode
with cython.parallel(nogil=True): -> ParallelWithBlockNode
print cython.parallel.threadid() -> ParallelThreadIdNode
for i in cython.parallel.prange(...): -> ParallelRangeNode
...
"""
# a list of names, maps 'cython.parallel.prange' in the code to
# ['cython', 'parallel', 'prange']
parallel_directive
=
None
# Indicates whether a namenode in an expression is the cython module
namenode_is_cython_module
=
False
# Keep track of whether we are the context manager of a 'with' statement
in_context_manager_section
=
False
# Keep track of whether we are in a parallel range section
in_prange
=
False
directive_to_node
=
{
u"cython.parallel.parallel"
:
Nodes
.
ParallelWithBlockNode
,
# u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
u"cython.parallel.threadid"
:
ExprNodes
.
ParallelThreadIdNode
,
u"cython.parallel.prange"
:
Nodes
.
ParallelRangeNode
,
}
def
node_is_parallel_directive
(
self
,
node
):
return
node
.
name
in
self
.
parallel_directives
or
node
.
is_cython_module
def
get_directive_class_node
(
self
,
node
):
"""
Figure out which parallel directive was used and return the associated
Node class.
E.g. for a cython.parallel.prange() call we return ParallelRangeNode
Also disallow break, continue and return in a prange section
"""
if
self
.
namenode_is_cython_module
:
directive
=
'.'
.
join
(
self
.
parallel_directive
)
else
:
directive
=
self
.
parallel_directives
[
self
.
parallel_directive
[
0
]]
directive
=
'%s.%s'
%
(
directive
,
'.'
.
join
(
self
.
parallel_directive
[
1
:]))
directive
=
directive
.
rstrip
(
'.'
)
cls
=
self
.
directive_to_node
.
get
(
directive
)
if
cls
is
None
:
error
(
node
.
pos
,
"Invalid directive: %s"
%
directive
)
self
.
namenode_is_cython_module
=
False
self
.
parallel_directive
=
None
return
cls
def
visit_ModuleNode
(
self
,
node
):
"""
If any parallel directives were imported, copy them over and visit
the AST
"""
if
node
.
parallel_directives
:
self
.
parallel_directives
=
node
.
parallel_directives
self
.
assignment_stack
=
[]
return
self
.
visit_Node
(
node
)
# No parallel directives were imported, so they can't be used :)
return
node
def
visit_NameNode
(
self
,
node
):
if
self
.
node_is_parallel_directive
(
node
):
self
.
parallel_directive
=
[
node
.
name
]
self
.
namenode_is_cython_module
=
node
.
is_cython_module
return
node
def
visit_AttributeNode
(
self
,
node
):
self
.
visitchildren
(
node
)
if
self
.
parallel_directive
:
self
.
parallel_directive
.
append
(
node
.
attribute
)
return
node
def
visit_CallNode
(
self
,
node
):
self
.
visitchildren
(
node
)
if
not
self
.
parallel_directive
:
return
node
# We are a parallel directive, replace this node with the
# corresponding ParallelSomethingSomething node
if
isinstance
(
node
,
ExprNodes
.
GeneralCallNode
):
args
=
node
.
positional_args
.
args
kwargs
=
node
.
keyword_args
else
:
args
=
node
.
args
kwargs
=
{}
parallel_directive_class
=
self
.
get_directive_class_node
(
node
)
if
parallel_directive_class
:
node
=
parallel_directive_class
(
node
.
pos
,
args
=
args
,
kwargs
=
kwargs
)
return
node
def
visit_WithStatNode
(
self
,
node
):
"Rewrite with cython.parallel() blocks"
self
.
visit
(
node
.
manager
)
if
self
.
parallel_directive
:
parallel_directive_class
=
self
.
get_directive_class_node
(
node
)
if
not
parallel_directive_class
:
# There was an error, stop here and now
return
None
self
.
visit
(
node
.
body
)
newnode
=
Nodes
.
ParallelWithBlockNode
(
node
.
pos
,
body
=
node
.
body
)
else
:
newnode
=
node
self
.
visit
(
node
.
body
)
return
newnode
def
visit_ForInStatNode
(
self
,
node
):
"Rewrite 'for i in cython.parallel.prange(...):'"
self
.
visit
(
node
.
iterator
)
self
.
visit
(
node
.
target
)
was_in_prange
=
self
.
in_prange
self
.
in_prange
=
isinstance
(
node
.
iterator
.
sequence
,
Nodes
.
ParallelRangeNode
)
if
self
.
in_prange
:
# This will replace the entire ForInStatNode, so copy the
# attributes
parallel_range_node
=
node
.
iterator
.
sequence
parallel_range_node
.
target
=
node
.
target
parallel_range_node
.
body
=
node
.
body
parallel_range_node
.
else_clause
=
node
.
else_clause
node
=
parallel_range_node
if
not
isinstance
(
node
.
target
,
ExprNodes
.
NameNode
):
error
(
node
.
target
.
pos
,
"Can only iterate over an iteration variable"
)
self
.
visit
(
node
.
body
)
self
.
in_prange
=
was_in_prange
self
.
visit
(
node
.
else_clause
)
return
node
def
ensure_not_in_prange
(
name
):
"Creates error checking functions for break, continue and return"
def
visit_method
(
self
,
node
):
if
self
.
in_prange
:
error
(
node
.
pos
,
name
+
" not allowed in a parallel range section"
)
# Do a visit for 'return'
self
.
visitchildren
(
node
)
return
node
return
visit_method
visit_BreakStatNode
=
ensure_not_in_prange
(
"break"
)
visit_ContinueStatNode
=
ensure_not_in_prange
(
"continue"
)
visit_ReturnStatNode
=
ensure_not_in_prange
(
"return"
)
def
visit
(
self
,
node
):
"Visit a node that may be None"
if
node
is
not
None
:
super
(
ParallelRangeTransform
,
self
).
visit
(
node
)
class
WithTransform
(
CythonTransform
,
SkipDeclarations
):
class
WithTransform
(
CythonTransform
,
SkipDeclarations
):
def
visit_WithStatNode
(
self
,
node
):
def
visit_WithStatNode
(
self
,
node
):
self
.
visitchildren
(
node
,
'body'
)
self
.
visitchildren
(
node
,
'body'
)
...
@@ -1715,22 +1953,54 @@ class GilCheck(VisitorTransform):
...
@@ -1715,22 +1953,54 @@ class GilCheck(VisitorTransform):
self
.
env_stack
.
append
(
node
.
local_scope
)
self
.
env_stack
.
append
(
node
.
local_scope
)
was_nogil
=
self
.
nogil
was_nogil
=
self
.
nogil
self
.
nogil
=
node
.
local_scope
.
nogil
self
.
nogil
=
node
.
local_scope
.
nogil
if
self
.
nogil
and
node
.
nogil_check
:
if
self
.
nogil
and
node
.
nogil_check
:
node
.
nogil_check
(
node
.
local_scope
)
node
.
nogil_check
(
node
.
local_scope
)
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
self
.
env_stack
.
pop
()
self
.
env_stack
.
pop
()
self
.
nogil
=
was_nogil
self
.
nogil
=
was_nogil
return
node
return
node
def
visit_GILStatNode
(
self
,
node
):
def
visit_GILStatNode
(
self
,
node
):
env
=
self
.
env_stack
[
-
1
]
if
self
.
nogil
and
node
.
nogil_check
:
if
self
.
nogil
and
node
.
nogil_check
:
node
.
nogil_check
()
node
.
nogil_check
()
was_nogil
=
self
.
nogil
was_nogil
=
self
.
nogil
self
.
nogil
=
(
node
.
state
==
'nogil'
)
self
.
nogil
=
(
node
.
state
==
'nogil'
)
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
self
.
nogil
=
was_nogil
self
.
nogil
=
was_nogil
return
node
return
node
def
visit_ParallelRangeNode
(
self
,
node
):
if
node
.
is_nogil
:
node
.
is_nogil
=
False
node
=
Nodes
.
GILStatNode
(
node
.
pos
,
state
=
'nogil'
,
body
=
node
)
return
self
.
visit_GILStatNode
(
node
)
if
not
self
.
nogil
:
error
(
node
.
pos
,
"prange() can only be used without the GIL"
)
# Forget about any GIL-related errors that may occur in the body
return
None
node
.
nogil_check
(
self
.
env_stack
[
-
1
])
self
.
visitchildren
(
node
)
return
node
def
visit_ParallelWithBlockNode
(
self
,
node
):
if
not
self
.
nogil
:
error
(
node
.
pos
,
"The parallel section may only be used without "
"the GIL"
)
return
None
if
node
.
nogil_check
:
# It does not currently implement this, but test for it anyway to
# avoid potential future surprises
node
.
nogil_check
(
self
.
env_stack
[
-
1
])
self
.
visitchildren
(
node
)
return
node
def
visit_Node
(
self
,
node
):
def
visit_Node
(
self
,
node
):
if
self
.
env_stack
and
self
.
nogil
and
node
.
nogil_check
:
if
self
.
env_stack
and
self
.
nogil
and
node
.
nogil_check
:
node
.
nogil_check
(
self
.
env_stack
[
-
1
])
node
.
nogil_check
(
self
.
env_stack
[
-
1
])
...
@@ -1857,8 +2127,7 @@ class TransformBuiltinMethods(EnvTransform):
...
@@ -1857,8 +2127,7 @@ class TransformBuiltinMethods(EnvTransform):
class
DebugTransform
(
CythonTransform
):
class
DebugTransform
(
CythonTransform
):
"""
"""
Create debug information and all functions' visibility to extern in order
Write debug information for this Cython module.
to enable debugging.
"""
"""
def
__init__
(
self
,
context
,
options
,
result
):
def
__init__
(
self
,
context
,
options
,
result
):
...
...
Cython/Compiler/Symtab.py
View file @
18d8b3fb
...
@@ -723,6 +723,9 @@ class Scope(object):
...
@@ -723,6 +723,9 @@ class Scope(object):
else:
else:
return outer.is_cpp()
return outer.is_cpp()
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PreImportScope(Scope):
class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
namespace_cname = Naming.preimport_cname
...
@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
...
@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
utility_code
=
e
.
utility_code
)
utility_code
=
e
.
utility_code
)
return
scope
return
scope
def
add_include_file
(
self
,
filename
):
self
.
outer_scope
.
add_include_file
(
filename
)
class
PropertyScope
(
Scope
):
class
PropertyScope
(
Scope
):
# Scope holding the __get__, __set__ and __del__ methods for
# Scope holding the __get__, __set__ and __del__ methods for
...
...
Cython/Compiler/Tests/TestParseTreeTransforms.py
View file @
18d8b3fb
...
@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
...
@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
from
Cython.TestUtils
import
TransformTest
from
Cython.TestUtils
import
TransformTest
from
Cython.Compiler.ParseTreeTransforms
import
*
from
Cython.Compiler.ParseTreeTransforms
import
*
from
Cython.Compiler.Nodes
import
*
from
Cython.Compiler.Nodes
import
*
from
Cython.Compiler
import
Main
class
TestNormalizeTree
(
TransformTest
):
class
TestNormalizeTree
(
TransformTest
):
...
@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
...
@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
"""
,
t
)
"""
,
t
)
class
TestInterpretCompilerDirectives
(
TransformTest
):
"""
This class tests the parallel directives AST-rewriting and importing.
"""
# Test the parallel directives (c)importing
import_code
=
u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""
expected_directives_dict
=
{
u'cython.parallel'
:
u'cython.parallel'
,
u'par'
:
u'cython.parallel'
,
u'par2'
:
u'cython.parallel'
,
u'parallel'
:
u'cython.parallel'
,
u"tid"
:
u"cython.parallel.threadid"
,
u"tavail"
:
u"cython.parallel.threadavailable"
,
u"prange"
:
u"cython.parallel.prange"
,
}
def
setUp
(
self
):
super
(
TestInterpretCompilerDirectives
,
self
).
setUp
()
compilation_options
=
Main
.
CompilationOptions
(
Main
.
default_options
)
ctx
=
compilation_options
.
create_context
()
self
.
pipeline
=
[
InterpretCompilerDirectives
(
ctx
,
ctx
.
compiler_directives
),
]
self
.
debug_exception_on_error
=
DebugFlags
.
debug_exception_on_error
def
tearDown
(
self
):
DebugFlags
.
debug_exception_on_error
=
self
.
debug_exception_on_error
def
test_parallel_directives_cimports
(
self
):
self
.
run_pipeline
(
self
.
pipeline
,
self
.
import_code
)
parallel_directives
=
self
.
pipeline
[
0
].
parallel_directives
self
.
assertEqual
(
parallel_directives
,
self
.
expected_directives_dict
)
def
test_parallel_directives_imports
(
self
):
self
.
run_pipeline
(
self
.
pipeline
,
self
.
import_code
.
replace
(
u'cimport'
,
u'import'
))
parallel_directives
=
self
.
pipeline
[
0
].
parallel_directives
self
.
assertEqual
(
parallel_directives
,
self
.
expected_directives_dict
)
# TODO: Re-enable once they're more robust.
# TODO: Re-enable once they're more robust.
if
sys
.
version_info
[:
2
]
>=
(
2
,
5
)
and
False
:
if
sys
.
version_info
[:
2
]
>=
(
2
,
5
)
and
False
:
from
Cython.Debugger
import
DebugWriter
from
Cython.Debugger
import
DebugWriter
...
...
Cython/Compiler/TypeInference.py
View file @
18d8b3fb
...
@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
...
@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
class
MarkAssignments
(
CythonTransform
):
class
MarkAssignments
(
CythonTransform
):
def
mark_assignment
(
self
,
lhs
,
rhs
):
def
__init__
(
self
,
context
):
super
(
CythonTransform
,
self
).
__init__
()
self
.
context
=
context
# Track the parallel block scopes (with parallel, for i in prange())
self
.
parallel_block_stack
=
[]
def
mark_assignment
(
self
,
lhs
,
rhs
,
inplace_op
=
None
):
if
isinstance
(
lhs
,
(
ExprNodes
.
NameNode
,
Nodes
.
PyArgDeclNode
)):
if
isinstance
(
lhs
,
(
ExprNodes
.
NameNode
,
Nodes
.
PyArgDeclNode
)):
if
lhs
.
entry
is
None
:
if
lhs
.
entry
is
None
:
# TODO: This shouldn't happen...
# TODO: This shouldn't happen...
return
return
lhs
.
entry
.
assignments
.
append
(
rhs
)
lhs
.
entry
.
assignments
.
append
(
rhs
)
if
self
.
parallel_block_stack
:
parallel_node
=
self
.
parallel_block_stack
[
-
1
]
parallel_node
.
assignments
[
lhs
.
entry
]
=
(
lhs
.
pos
,
inplace_op
)
elif
isinstance
(
lhs
,
ExprNodes
.
SequenceNode
):
elif
isinstance
(
lhs
,
ExprNodes
.
SequenceNode
):
for
arg
in
lhs
.
args
:
for
arg
in
lhs
.
args
:
self
.
mark_assignment
(
arg
,
object_expr
)
self
.
mark_assignment
(
arg
,
object_expr
)
...
@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
...
@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
return
node
return
node
def
visit_InPlaceAssignmentNode
(
self
,
node
):
def
visit_InPlaceAssignmentNode
(
self
,
node
):
self
.
mark_assignment
(
node
.
lhs
,
node
.
create_binop_node
())
self
.
mark_assignment
(
node
.
lhs
,
node
.
create_binop_node
()
,
node
.
operator
)
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
return
node
return
node
...
@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
...
@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
return
node
return
node
def
visit_ParallelStatNode
(
self
,
node
):
if
self
.
parallel_block_stack
:
node
.
parent
=
self
.
parallel_block_stack
[
-
1
]
else
:
node
.
parent
=
None
if
node
.
is_prange
:
if
not
node
.
parent
:
node
.
is_parallel
=
True
else
:
node
.
is_parallel
=
(
node
.
parent
.
is_prange
or
not
node
.
parent
.
is_parallel
)
else
:
node
.
is_parallel
=
True
self
.
parallel_block_stack
.
append
(
node
)
self
.
visitchildren
(
node
)
self
.
parallel_block_stack
.
pop
()
return
node
class
MarkOverflowingArithmetic
(
CythonTransform
):
class
MarkOverflowingArithmetic
(
CythonTransform
):
# It may be possible to integrate this with the above for
# It may be possible to integrate this with the above for
...
...
tests/run/parallel.pyx
0 → 100644
View file @
18d8b3fb
# tag: run
# distutils: libraries = gomp
# distutils: extra_compile_args = -fopenmp
cimport
cython.parallel
from
cython.parallel
import
prange
,
threadid
from
libc.stdlib
cimport
malloc
,
free
cdef
extern
from
"Python.h"
:
void
PyEval_InitThreads
()
PyEval_InitThreads
()
cdef
void
print_int
(
int
x
)
with
gil
:
print
x
#@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
# "//GILStatNode[@state = 'nogil]//ParallelRangeNode")
def
test_prange
():
"""
>>> test_prange()
(9, 9, 45, 45)
"""
cdef
Py_ssize_t
i
,
j
,
sum1
=
0
,
sum2
=
0
with
nogil
,
cython
.
parallel
.
parallel
:
for
i
in
prange
(
10
,
schedule
=
'dynamic'
):
sum1
+=
i
for
j
in
prange
(
10
,
nogil
=
True
):
sum2
+=
j
return
i
,
j
,
sum1
,
sum2
def
test_descending_prange
():
"""
>>> test_descending_prange()
5
"""
cdef
int
i
,
start
=
5
,
stop
=
-
5
,
step
=
-
2
cdef
int
sum
=
0
for
i
in
prange
(
start
,
stop
,
step
,
nogil
=
True
):
sum
+=
i
return
sum
def
test_nested_prange
():
"""
Reduction propagation is not (yet) supported.
>>> test_nested_prange()
50
"""
cdef
int
i
,
j
cdef
int
sum
=
0
for
i
in
prange
(
5
,
nogil
=
True
):
for
j
in
prange
(
5
):
sum
+=
i
# The value of sum is undefined here
sum
=
0
for
i
in
prange
(
5
,
nogil
=
True
):
for
j
in
prange
(
5
):
sum
+=
i
sum
+=
0
return
sum
# threadsavailable test, disable this for now as it won't compile
#def test_parallel():
# """
# >>> test_parallel()
# """
# cdef int *buf = <int *> malloc(sizeof(int) * threadsavailable())
#
# if buf == NULL:
# raise MemoryError
#
# with nogil, cython.parallel.parallel:
# buf[threadid()] = threadid()
#
# for i in range(threadsavailable()):
# assert buf[i] == i
#
# free(buf)
def
test_unsigned_operands
():
"""
This test is disabled, as this currently does not work (neither does it
for 'for i from x < i < y:'. I'm not sure we should strife to support
this, at least the C compiler gives a warning.
test_unsigned_operands()
10
"""
cdef
int
i
cdef
int
start
=
-
5
cdef
unsigned
int
stop
=
5
cdef
int
step
=
1
cdef
int
steps_taken
=
0
for
i
in
prange
(
start
,
stop
,
step
,
nogil
=
True
):
steps_taken
+=
1
return
steps_taken
def
test_reassign_start_stop_step
():
"""
>>> test_reassign_start_stop_step()
20
"""
cdef
int
start
=
0
,
stop
=
10
,
step
=
2
cdef
int
i
cdef
int
sum
=
0
for
i
in
prange
(
start
,
stop
,
step
,
nogil
=
True
):
start
=
-
2
stop
=
2
step
=
0
sum
+=
i
return
sum
def
test_closure_parallel_privates
():
"""
>>> test_closure_parallel_privates()
9 9
45 45
0 0 9 9
"""
cdef
int
x
def
test_target
():
nonlocal
x
for
x
in
prange
(
10
,
nogil
=
True
):
pass
return
x
print
test_target
(),
x
def
test_reduction
():
nonlocal
x
cdef
int
i
x
=
0
for
i
in
prange
(
10
,
nogil
=
True
):
x
+=
i
return
x
print
test_reduction
(),
x
def
test_generator
():
nonlocal
x
cdef
int
i
x
=
0
yield
x
x
=
2
for
i
in
prange
(
10
,
nogil
=
True
):
x
=
i
yield
x
g
=
test_generator
()
print
g
.
next
(),
x
,
g
.
next
(),
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment