Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Xavier Thompson
cython
Commits
475bc21b
Commit
475bc21b
authored
Dec 17, 2008
by
Stefan Behnel
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
moved iter-range() optimisation into a transform (worth a review)
parent
66c5a0af
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
6 deletions
+131
-6
Cython/Compiler/Main.py
Cython/Compiler/Main.py
+2
-2
Cython/Compiler/Nodes.py
Cython/Compiler/Nodes.py
+13
-2
Cython/Compiler/Optimize.py
Cython/Compiler/Optimize.py
+78
-2
tests/run/r_forloop.pyx
tests/run/r_forloop.pyx
+38
-0
No files found.
Cython/Compiler/Main.py
View file @
475bc21b
...
@@ -82,7 +82,7 @@ class Context:
...
@@ -82,7 +82,7 @@ class Context:
from
ParseTreeTransforms
import
InterpretCompilerDirectives
,
TransformBuiltinMethods
from
ParseTreeTransforms
import
InterpretCompilerDirectives
,
TransformBuiltinMethods
from
ParseTreeTransforms
import
AlignFunctionDefinitions
from
ParseTreeTransforms
import
AlignFunctionDefinitions
from
AutoDocTransforms
import
EmbedSignature
from
AutoDocTransforms
import
EmbedSignature
from
Optimize
import
FlattenInListTransform
,
SwitchTransform
,
DictIter
Transform
from
Optimize
import
FlattenInListTransform
,
SwitchTransform
,
Iteration
Transform
from
Optimize
import
FlattenBuiltinTypeCreation
,
ConstantFolding
,
FinalOptimizePhase
from
Optimize
import
FlattenBuiltinTypeCreation
,
ConstantFolding
,
FinalOptimizePhase
from
Buffer
import
IntroduceBufferAuxiliaryVars
from
Buffer
import
IntroduceBufferAuxiliaryVars
from
ModuleNode
import
check_c_declarations
from
ModuleNode
import
check_c_declarations
...
@@ -125,7 +125,7 @@ class Context:
...
@@ -125,7 +125,7 @@ class Context:
AnalyseExpressionsTransform
(
self
),
AnalyseExpressionsTransform
(
self
),
FlattenBuiltinTypeCreation
(),
FlattenBuiltinTypeCreation
(),
ConstantFolding
(),
ConstantFolding
(),
DictIter
Transform
(),
Iteration
Transform
(),
SwitchTransform
(),
SwitchTransform
(),
FinalOptimizePhase
(
self
),
FinalOptimizePhase
(
self
),
# ClearResultCodes(self),
# ClearResultCodes(self),
...
...
Cython/Compiler/Nodes.py
View file @
475bc21b
...
@@ -3719,7 +3719,7 @@ class ForInStatNode(LoopNode, StatNode):
...
@@ -3719,7 +3719,7 @@ class ForInStatNode(LoopNode, StatNode):
def
analyse_expressions
(
self
,
env
):
def
analyse_expressions
(
self
,
env
):
import
ExprNodes
import
ExprNodes
self
.
target
.
analyse_target_types
(
env
)
self
.
target
.
analyse_target_types
(
env
)
if
Options
.
convert_range
and
self
.
target
.
type
.
is_int
:
if
False
:
#
Options.convert_range and self.target.type.is_int:
sequence
=
self
.
iterator
.
sequence
sequence
=
self
.
iterator
.
sequence
if
isinstance
(
sequence
,
ExprNodes
.
SimpleCallNode
)
\
if
isinstance
(
sequence
,
ExprNodes
.
SimpleCallNode
)
\
and
sequence
.
self
is
None
\
and
sequence
.
self
is
None
\
...
@@ -3801,7 +3801,11 @@ class ForFromStatNode(LoopNode, StatNode):
...
@@ -3801,7 +3801,11 @@ class ForFromStatNode(LoopNode, StatNode):
# loopvar_name string
# loopvar_name string
# py_loopvar_node PyTempNode or None
# py_loopvar_node PyTempNode or None
child_attrs
=
[
"target"
,
"bound1"
,
"bound2"
,
"step"
,
"body"
,
"else_clause"
]
child_attrs
=
[
"target"
,
"bound1"
,
"bound2"
,
"step"
,
"body"
,
"else_clause"
]
is_py_target
=
False
loopvar_name
=
None
py_loopvar_node
=
None
def
analyse_declarations
(
self
,
env
):
def
analyse_declarations
(
self
,
env
):
self
.
target
.
analyse_target_declaration
(
env
)
self
.
target
.
analyse_target_declaration
(
env
)
self
.
body
.
analyse_declarations
(
env
)
self
.
body
.
analyse_declarations
(
env
)
...
@@ -3866,6 +3870,13 @@ class ForFromStatNode(LoopNode, StatNode):
...
@@ -3866,6 +3870,13 @@ class ForFromStatNode(LoopNode, StatNode):
self
.
bound2
.
release_temp
(
env
)
self
.
bound2
.
release_temp
(
env
)
if
self
.
step
is
not
None
:
if
self
.
step
is
not
None
:
self
.
step
.
release_temp
(
env
)
self
.
step
.
release_temp
(
env
)
def
reanalyse_c_loop
(
self
,
env
):
# only make sure all subnodes have an integer type
self
.
bound1
=
self
.
bound1
.
coerce_to_integer
(
env
)
self
.
bound2
=
self
.
bound2
.
coerce_to_integer
(
env
)
if
self
.
step
is
not
None
:
self
.
step
=
self
.
step
.
coerce_to_integer
(
env
)
def
generate_execution_code
(
self
,
code
):
def
generate_execution_code
(
self
,
code
):
old_loop_labels
=
code
.
new_loop_labels
()
old_loop_labels
=
code
.
new_loop_labels
()
...
...
Cython/Compiler/Optimize.py
View file @
475bc21b
...
@@ -6,6 +6,7 @@ import Builtin
...
@@ -6,6 +6,7 @@ import Builtin
import
UtilNodes
import
UtilNodes
import
TypeSlots
import
TypeSlots
import
Symtab
import
Symtab
import
Options
from
StringEncoding
import
EncodedString
from
StringEncoding
import
EncodedString
from
ParseTreeTransforms
import
SkipDeclarations
from
ParseTreeTransforms
import
SkipDeclarations
...
@@ -29,8 +30,11 @@ def is_common_value(a, b):
...
@@ -29,8 +30,11 @@ def is_common_value(a, b):
return
False
return
False
class
DictIterTransform
(
Visitor
.
VisitorTransform
):
class
IterationTransform
(
Visitor
.
VisitorTransform
):
"""Transform a for-in-dict loop into a while loop calling PyDict_Next().
"""Transform some common for-in loop patterns into efficient C loops:
- for-in-dict loop becomes a while loop calling PyDict_Next()
- for-in-range loop becomes a plain C for loop
"""
"""
PyDict_Next_func_type
=
PyrexTypes
.
CFuncType
(
PyDict_Next_func_type
=
PyrexTypes
.
CFuncType
(
PyrexTypes
.
c_bint_type
,
[
PyrexTypes
.
c_bint_type
,
[
...
@@ -50,6 +54,18 @@ class DictIterTransform(Visitor.VisitorTransform):
...
@@ -50,6 +54,18 @@ class DictIterTransform(Visitor.VisitorTransform):
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
return
node
return
node
def
visit_ModuleNode
(
self
,
node
):
self
.
current_scope
=
node
.
scope
self
.
visitchildren
(
node
)
return
node
def
visit_DefNode
(
self
,
node
):
oldscope
=
self
.
current_scope
self
.
current_scope
=
node
.
entry
.
scope
self
.
visitchildren
(
node
)
self
.
current_scope
=
oldscope
return
node
def
visit_ForInStatNode
(
self
,
node
):
def
visit_ForInStatNode
(
self
,
node
):
self
.
visitchildren
(
node
)
self
.
visitchildren
(
node
)
iterator
=
node
.
iterator
.
sequence
iterator
=
node
.
iterator
.
sequence
...
@@ -61,6 +77,7 @@ class DictIterTransform(Visitor.VisitorTransform):
...
@@ -61,6 +77,7 @@ class DictIterTransform(Visitor.VisitorTransform):
return
node
return
node
function
=
iterator
.
function
function
=
iterator
.
function
# dict iteration?
if
isinstance
(
function
,
ExprNodes
.
AttributeNode
)
and
\
if
isinstance
(
function
,
ExprNodes
.
AttributeNode
)
and
\
function
.
obj
.
type
==
Builtin
.
dict_type
:
function
.
obj
.
type
==
Builtin
.
dict_type
:
dict_obj
=
function
.
obj
dict_obj
=
function
.
obj
...
@@ -77,8 +94,67 @@ class DictIterTransform(Visitor.VisitorTransform):
...
@@ -77,8 +94,67 @@ class DictIterTransform(Visitor.VisitorTransform):
return
node
return
node
return
self
.
_transform_dict_iteration
(
return
self
.
_transform_dict_iteration
(
node
,
dict_obj
,
keys
,
values
)
node
,
dict_obj
,
keys
,
values
)
# range() iteration?
if
Options
.
convert_range
and
node
.
target
.
type
.
is_int
:
if
iterator
.
self
is
None
and
\
isinstance
(
function
,
ExprNodes
.
NameNode
)
and
\
function
.
entry
.
is_builtin
and
\
function
.
name
in
(
'range'
,
'xrange'
):
return
self
.
_transform_range_iteration
(
node
,
iterator
)
return
node
return
node
def
_transform_range_iteration
(
self
,
node
,
range_function
):
args
=
range_function
.
arg_tuple
.
args
if
len
(
args
)
<
3
:
step_pos
=
range_function
.
pos
step_value
=
1
step
=
ExprNodes
.
IntNode
(
step_pos
,
value
=
1
)
else
:
step
=
args
[
2
]
step_pos
=
step
.
pos
if
step
.
constant_result
is
ExprNodes
.
not_a_constant
:
# cannot determine step direction
return
node
try
:
# FIXME: check how Python handles rounding here, e.g. from float
step_value
=
int
(
step
.
constant_result
)
except
:
return
node
if
not
isinstance
(
step
,
ExprNodes
.
IntNode
):
step
=
ExprNodes
.
IntNode
(
step_pos
,
value
=
step_value
)
if
step_value
>
0
:
relation1
=
'<='
relation2
=
'<'
elif
step_value
<
0
:
step
.
value
=
-
step_value
relation1
=
'>='
relation2
=
'>'
else
:
return
node
if
len
(
args
)
==
1
:
bound1
=
ExprNodes
.
IntNode
(
range_function
.
pos
,
value
=
0
)
bound2
=
args
[
0
]
else
:
bound1
=
args
[
0
]
bound2
=
args
[
1
]
for_node
=
Nodes
.
ForFromStatNode
(
node
.
pos
,
target
=
node
.
target
,
bound1
=
bound1
,
relation1
=
relation1
,
relation2
=
relation2
,
bound2
=
bound2
,
step
=
step
,
body
=
node
.
body
,
else_clause
=
node
.
else_clause
,
loopvar_name
=
node
.
target
.
entry
.
cname
)
for_node
.
reanalyse_c_loop
(
self
.
current_scope
)
# for_node.analyse_expressions(self.current_scope)
return
for_node
def
_transform_dict_iteration
(
self
,
node
,
dict_obj
,
keys
,
values
):
def
_transform_dict_iteration
(
self
,
node
,
dict_obj
,
keys
,
values
):
py_object_ptr
=
PyrexTypes
.
c_void_ptr_type
py_object_ptr
=
PyrexTypes
.
c_void_ptr_type
...
...
tests/run/r_forloop.pyx
View file @
475bc21b
...
@@ -12,8 +12,22 @@ __doc__ = u"""
...
@@ -12,8 +12,22 @@ __doc__ = u"""
Spam!
Spam!
Spam!
Spam!
Spam!
Spam!
>>> go_c_all()
Spam!
Spam!
Spam!
>>> go_c_all_exprs(1)
Spam!
>>> go_c_all_exprs(3)
Spam!
Spam!
>>> go_c_calc(2)
Spam!
Spam!
>>> go_c_ret()
>>> go_c_ret()
2
2
>>> go_c_calc_ret(2)
6
>>> go_list()
>>> go_list()
Spam!
Spam!
...
@@ -54,6 +68,30 @@ def go_c():
...
@@ -54,6 +68,30 @@ def go_c():
for
i
in
range
(
4
):
for
i
in
range
(
4
):
print
u"Spam!"
print
u"Spam!"
def
go_c_all
():
cdef
int
i
for
i
in
range
(
8
,
2
,
-
2
):
print
u"Spam!"
def
go_c_all_exprs
(
x
):
cdef
int
i
for
i
in
range
(
4
*
x
,
2
*
x
,
-
3
):
print
u"Spam!"
def
f
(
x
):
return
2
*
x
def
go_c_calc
(
x
):
cdef
int
i
for
i
in
range
(
2
*
f
(
x
),
f
(
x
),
-
2
):
print
u"Spam!"
def
go_c_calc_ret
(
x
):
cdef
int
i
for
i
in
range
(
2
*
f
(
x
),
f
(
x
),
-
2
):
if
i
<
2
*
f
(
x
):
return
i
def
go_c_ret
():
def
go_c_ret
():
cdef
int
i
cdef
int
i
for
i
in
range
(
4
):
for
i
in
range
(
4
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment