Commit 475bc21b authored by Stefan Behnel's avatar Stefan Behnel

moved iter-range() optimisation into a transform (worth a review)

parent 66c5a0af
...@@ -82,7 +82,7 @@ class Context: ...@@ -82,7 +82,7 @@ class Context:
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import AlignFunctionDefinitions from ParseTreeTransforms import AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
from Buffer import IntroduceBufferAuxiliaryVars from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations from ModuleNode import check_c_declarations
...@@ -125,7 +125,7 @@ class Context: ...@@ -125,7 +125,7 @@ class Context:
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
FlattenBuiltinTypeCreation(), FlattenBuiltinTypeCreation(),
ConstantFolding(), ConstantFolding(),
DictIterTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
FinalOptimizePhase(self), FinalOptimizePhase(self),
# ClearResultCodes(self), # ClearResultCodes(self),
......
...@@ -3719,7 +3719,7 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -3719,7 +3719,7 @@ class ForInStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
import ExprNodes import ExprNodes
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
if Options.convert_range and self.target.type.is_int: if False: # Options.convert_range and self.target.type.is_int:
sequence = self.iterator.sequence sequence = self.iterator.sequence
if isinstance(sequence, ExprNodes.SimpleCallNode) \ if isinstance(sequence, ExprNodes.SimpleCallNode) \
and sequence.self is None \ and sequence.self is None \
...@@ -3801,7 +3801,11 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -3801,7 +3801,11 @@ class ForFromStatNode(LoopNode, StatNode):
# loopvar_name string # loopvar_name string
# py_loopvar_node PyTempNode or None # py_loopvar_node PyTempNode or None
child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"] child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"]
is_py_target = False
loopvar_name = None
py_loopvar_node = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
...@@ -3866,6 +3870,13 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -3866,6 +3870,13 @@ class ForFromStatNode(LoopNode, StatNode):
self.bound2.release_temp(env) self.bound2.release_temp(env)
if self.step is not None: if self.step is not None:
self.step.release_temp(env) self.step.release_temp(env)
def reanalyse_c_loop(self, env):
# only make sure all subnodes have an integer type
self.bound1 = self.bound1.coerce_to_integer(env)
self.bound2 = self.bound2.coerce_to_integer(env)
if self.step is not None:
self.step = self.step.coerce_to_integer(env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
old_loop_labels = code.new_loop_labels() old_loop_labels = code.new_loop_labels()
......
...@@ -6,6 +6,7 @@ import Builtin ...@@ -6,6 +6,7 @@ import Builtin
import UtilNodes import UtilNodes
import TypeSlots import TypeSlots
import Symtab import Symtab
import Options
from StringEncoding import EncodedString from StringEncoding import EncodedString
from ParseTreeTransforms import SkipDeclarations from ParseTreeTransforms import SkipDeclarations
...@@ -29,8 +30,11 @@ def is_common_value(a, b): ...@@ -29,8 +30,11 @@ def is_common_value(a, b):
return False return False
class DictIterTransform(Visitor.VisitorTransform): class IterationTransform(Visitor.VisitorTransform):
"""Transform a for-in-dict loop into a while loop calling PyDict_Next(). """Transform some common for-in loop patterns into efficient C loops:
- for-in-dict loop becomes a while loop calling PyDict_Next()
- for-in-range loop becomes a plain C for loop
""" """
PyDict_Next_func_type = PyrexTypes.CFuncType( PyDict_Next_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [ PyrexTypes.c_bint_type, [
...@@ -50,6 +54,18 @@ class DictIterTransform(Visitor.VisitorTransform): ...@@ -50,6 +54,18 @@ class DictIterTransform(Visitor.VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ModuleNode(self, node):
self.current_scope = node.scope
self.visitchildren(node)
return node
def visit_DefNode(self, node):
oldscope = self.current_scope
self.current_scope = node.entry.scope
self.visitchildren(node)
self.current_scope = oldscope
return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
iterator = node.iterator.sequence iterator = node.iterator.sequence
...@@ -61,6 +77,7 @@ class DictIterTransform(Visitor.VisitorTransform): ...@@ -61,6 +77,7 @@ class DictIterTransform(Visitor.VisitorTransform):
return node return node
function = iterator.function function = iterator.function
# dict iteration?
if isinstance(function, ExprNodes.AttributeNode) and \ if isinstance(function, ExprNodes.AttributeNode) and \
function.obj.type == Builtin.dict_type: function.obj.type == Builtin.dict_type:
dict_obj = function.obj dict_obj = function.obj
...@@ -77,8 +94,67 @@ class DictIterTransform(Visitor.VisitorTransform): ...@@ -77,8 +94,67 @@ class DictIterTransform(Visitor.VisitorTransform):
return node return node
return self._transform_dict_iteration( return self._transform_dict_iteration(
node, dict_obj, keys, values) node, dict_obj, keys, values)
# range() iteration?
if Options.convert_range and node.target.type.is_int:
if iterator.self is None and \
isinstance(function, ExprNodes.NameNode) and \
function.entry.is_builtin and \
function.name in ('range', 'xrange'):
return self._transform_range_iteration(
node, iterator)
return node return node
def _transform_range_iteration(self, node, range_function):
args = range_function.arg_tuple.args
if len(args) < 3:
step_pos = range_function.pos
step_value = 1
step = ExprNodes.IntNode(step_pos, value=1)
else:
step = args[2]
step_pos = step.pos
if step.constant_result is ExprNodes.not_a_constant:
# cannot determine step direction
return node
try:
# FIXME: check how Python handles rounding here, e.g. from float
step_value = int(step.constant_result)
except:
return node
if not isinstance(step, ExprNodes.IntNode):
step = ExprNodes.IntNode(step_pos, value=step_value)
if step_value > 0:
relation1 = '<='
relation2 = '<'
elif step_value < 0:
step.value = -step_value
relation1 = '>='
relation2 = '>'
else:
return node
if len(args) == 1:
bound1 = ExprNodes.IntNode(range_function.pos, value=0)
bound2 = args[0]
else:
bound1 = args[0]
bound2 = args[1]
for_node = Nodes.ForFromStatNode(
node.pos,
target=node.target,
bound1=bound1, relation1=relation1,
relation2=relation2, bound2=bound2,
step=step, body=node.body,
else_clause=node.else_clause,
loopvar_name = node.target.entry.cname)
for_node.reanalyse_c_loop(self.current_scope)
# for_node.analyse_expressions(self.current_scope)
return for_node
def _transform_dict_iteration(self, node, dict_obj, keys, values): def _transform_dict_iteration(self, node, dict_obj, keys, values):
py_object_ptr = PyrexTypes.c_void_ptr_type py_object_ptr = PyrexTypes.c_void_ptr_type
......
...@@ -12,8 +12,22 @@ __doc__ = u""" ...@@ -12,8 +12,22 @@ __doc__ = u"""
Spam! Spam!
Spam! Spam!
Spam! Spam!
>>> go_c_all()
Spam!
Spam!
Spam!
>>> go_c_all_exprs(1)
Spam!
>>> go_c_all_exprs(3)
Spam!
Spam!
>>> go_c_calc(2)
Spam!
Spam!
>>> go_c_ret() >>> go_c_ret()
2 2
>>> go_c_calc_ret(2)
6
>>> go_list() >>> go_list()
Spam! Spam!
...@@ -54,6 +68,30 @@ def go_c(): ...@@ -54,6 +68,30 @@ def go_c():
for i in range(4): for i in range(4):
print u"Spam!" print u"Spam!"
def go_c_all():
cdef int i
for i in range(8,2,-2):
print u"Spam!"
def go_c_all_exprs(x):
cdef int i
for i in range(4*x,2*x,-3):
print u"Spam!"
def f(x):
return 2*x
def go_c_calc(x):
cdef int i
for i in range(2*f(x),f(x), -2):
print u"Spam!"
def go_c_calc_ret(x):
cdef int i
for i in range(2*f(x),f(x), -2):
if i < 2*f(x):
return i
def go_c_ret(): def go_c_ret():
cdef int i cdef int i
for i in range(4): for i in range(4):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment