Commit c68e59af authored by Stefan Behnel's avatar Stefan Behnel

infer Py_ssize_t for enumerate() index variable in simple cases when iterating over builtin types

parent 58788692
......@@ -5,7 +5,7 @@ import Builtin
import PyrexTypes
from Cython import Utils
from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform
from Visitor import CythonTransform, EnvTransform
class TypedExprNode(ExprNodes.ExprNode):
......@@ -15,19 +15,16 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
class MarkAssignments(EnvTransform):
# tells us whether we're in a normal loop
in_loop = False
parallel_errors = False
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
return super(MarkAssignments, self).__init__(context)
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
......@@ -90,21 +87,39 @@ class MarkAssignments(CythonTransform):
# TODO: Remove redundancy with range optimization...
is_special = False
sequence = node.iterator.sequence
target = node.target
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.current_env().lookup(function.name)
if not entry or entry.is_builtin:
if function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
elif function.name == 'enumerate' and len(sequence.args) == 1:
if target.is_sequence_constructor and len(target.args) == 2:
iterator = sequence.args[0]
if iterator.is_name:
iterator_type = iterator.infer_type(self.current_env())
if iterator_type.is_builtin_type:
# assume that builtin types have a length within Py_ssize_t
self.mark_assignment(
target.args[0],
ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
type=PyrexTypes.c_py_ssize_t_type))
target = target.args[1]
sequence = sequence.args[0]
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.current_env().lookup(function.name)
if not entry or entry.is_builtin:
if function.name in ('range', 'xrange'):
is_special = True
for arg in sequence.args[:2]:
self.mark_assignment(node.target, arg)
self.mark_assignment(target, arg)
if len(sequence.args) > 2:
self.mark_assignment(
node.target,
target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
......@@ -116,7 +131,7 @@ class MarkAssignments(CythonTransform):
# naturally infer the base type of pointers, C arrays,
# Python strings, etc., while correctly falling back to an
# object type when the base type cannot be handled.
self.mark_assignment(node.target, ExprNodes.IndexNode(
self.mark_assignment(target, ExprNodes.IndexNode(
node.pos,
base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0')))
......@@ -163,7 +178,7 @@ class MarkAssignments(CythonTransform):
if node.starstar_arg:
self.mark_assignment(
node.starstar_arg, TypedExprNode(Builtin.dict_type))
self.visitchildren(node)
EnvTransform.visit_FuncDefNode(self, node)
return node
def visit_DelStatNode(self, node):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment