Commit c68e59af authored by Stefan Behnel's avatar Stefan Behnel

infer Py_ssize_t for enumerate() index variable in simple cases when iterating over builtin types

parent 58788692
...@@ -5,7 +5,7 @@ import Builtin ...@@ -5,7 +5,7 @@ import Builtin
import PyrexTypes import PyrexTypes
from Cython import Utils from Cython import Utils
from PyrexTypes import py_object_type, unspecified_type from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform from Visitor import CythonTransform, EnvTransform
class TypedExprNode(ExprNodes.ExprNode): class TypedExprNode(ExprNodes.ExprNode):
...@@ -15,19 +15,16 @@ class TypedExprNode(ExprNodes.ExprNode): ...@@ -15,19 +15,16 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type) object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform): class MarkAssignments(EnvTransform):
# tells us whether we're in a normal loop # tells us whether we're in a normal loop
in_loop = False in_loop = False
parallel_errors = False parallel_errors = False
def __init__(self, context): def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange()) # Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = [] self.parallel_block_stack = []
return super(MarkAssignments, self).__init__(context)
def mark_assignment(self, lhs, rhs, inplace_op=None): def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)): if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
...@@ -90,25 +87,43 @@ class MarkAssignments(CythonTransform): ...@@ -90,25 +87,43 @@ class MarkAssignments(CythonTransform):
# TODO: Remove redundancy with range optimization... # TODO: Remove redundancy with range optimization...
is_special = False is_special = False
sequence = node.iterator.sequence sequence = node.iterator.sequence
target = node.target
if isinstance(sequence, ExprNodes.SimpleCallNode): if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function function = sequence.function
if sequence.self is None and function.is_name: if sequence.self is None and function.is_name:
if function.name == 'reversed' and len(sequence.args) == 1: entry = self.current_env().lookup(function.name)
sequence = sequence.args[0] if not entry or entry.is_builtin:
if function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
elif function.name == 'enumerate' and len(sequence.args) == 1:
if target.is_sequence_constructor and len(target.args) == 2:
iterator = sequence.args[0]
if iterator.is_name:
iterator_type = iterator.infer_type(self.current_env())
if iterator_type.is_builtin_type:
# assume that builtin types have a length within Py_ssize_t
self.mark_assignment(
target.args[0],
ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
type=PyrexTypes.c_py_ssize_t_type))
target = target.args[1]
sequence = sequence.args[0]
if isinstance(sequence, ExprNodes.SimpleCallNode): if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function function = sequence.function
if sequence.self is None and function.is_name: if sequence.self is None and function.is_name:
if function.name in ('range', 'xrange'): entry = self.current_env().lookup(function.name)
is_special = True if not entry or entry.is_builtin:
for arg in sequence.args[:2]: if function.name in ('range', 'xrange'):
self.mark_assignment(node.target, arg) is_special = True
if len(sequence.args) > 2: for arg in sequence.args[:2]:
self.mark_assignment( self.mark_assignment(target, arg)
node.target, if len(sequence.args) > 2:
ExprNodes.binop_node(node.pos, self.mark_assignment(
'+', target,
sequence.args[0], ExprNodes.binop_node(node.pos,
sequence.args[2])) '+',
sequence.args[0],
sequence.args[2]))
if not is_special: if not is_special:
# A for-loop basically translates to subsequent calls to # A for-loop basically translates to subsequent calls to
...@@ -116,7 +131,7 @@ class MarkAssignments(CythonTransform): ...@@ -116,7 +131,7 @@ class MarkAssignments(CythonTransform):
# naturally infer the base type of pointers, C arrays, # naturally infer the base type of pointers, C arrays,
# Python strings, etc., while correctly falling back to an # Python strings, etc., while correctly falling back to an
# object type when the base type cannot be handled. # object type when the base type cannot be handled.
self.mark_assignment(node.target, ExprNodes.IndexNode( self.mark_assignment(target, ExprNodes.IndexNode(
node.pos, node.pos,
base = sequence, base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0'))) index = ExprNodes.IntNode(node.pos, value = '0')))
...@@ -163,7 +178,7 @@ class MarkAssignments(CythonTransform): ...@@ -163,7 +178,7 @@ class MarkAssignments(CythonTransform):
if node.starstar_arg: if node.starstar_arg:
self.mark_assignment( self.mark_assignment(
node.starstar_arg, TypedExprNode(Builtin.dict_type)) node.starstar_arg, TypedExprNode(Builtin.dict_type))
self.visitchildren(node) EnvTransform.visit_FuncDefNode(self, node)
return node return node
def visit_DelStatNode(self, node): def visit_DelStatNode(self, node):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment