Commit 87a9bc29 authored by Stefan Behnel's avatar Stefan Behnel

fix T600: lookup of iterables in genexpr must use outer scope

parent c117dddd
......@@ -2397,20 +2397,26 @@ class NextNode(AtomicExprNode):
#
# iterator IteratorNode
def __init__(self, iterator):
def __init__(self, iterator, lives_in_outer_scope=False):
AtomicExprNode.__init__(self, iterator.pos)
self.lives_in_outer_scope = lives_in_outer_scope
self.iterator = iterator
def type_dependencies(self, env):
if self.lives_in_outer_scope:
env = Symtab.NonLocalScopeWrapper(env)
return self.iterator.type_dependencies(env)
def infer_type(self, env, iterator_type = None):
def infer_type(self, env, iterator_type=None):
if self.lives_in_outer_scope:
env = Symtab.NonLocalScopeWrapper(env)
if iterator_type is None:
iterator_type = self.iterator.infer_type(env)
if iterator_type.is_ptr or iterator_type.is_array:
return iterator_type.base_type
elif iterator_type.is_cpp_class:
item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.return_type
item_type = env.lookup_operator_for_types(
self.pos, "*", [iterator_type]).type.return_type
if item_type.is_reference:
item_type = item_type.ref_base_type
if item_type.is_const:
......@@ -2418,9 +2424,10 @@ class NextNode(AtomicExprNode):
return item_type
else:
# Avoid duplication of complicated logic.
fake_index_node = IndexNode(self.pos,
base=self.iterator.sequence,
index=IntNode(self.pos, value='0'))
fake_index_node = IndexNode(
self.pos,
base=self.iterator.sequence,
index=IntNode(self.pos, value='0'))
return fake_index_node.infer_type(env)
def analyse_types(self, env):
......
......@@ -10,6 +10,7 @@ import Builtin
import ExprNodes
import Nodes
import Options
import Symtab
from PyrexTypes import py_object_type, unspecified_type
import PyrexTypes
......@@ -984,13 +985,20 @@ class ControlFlowAnalysis(CythonTransform):
next_block = self.flow.newblock()
# Condition with iterator
self.flow.loops.append(LoopDescr(next_block, condition_block))
is_normal_for_loop = isinstance(node, Nodes.ForInStatNode)
if is_normal_for_loop and node.first_in_genexp:
self.env_stack.append(self.env)
self.env = Symtab.NonLocalScopeWrapper(self.env)
self._visit(node.iterator)
if is_normal_for_loop and node.first_in_genexp:
self.env = self.env_stack.pop()
# Target assignment
self.flow.nextblock()
if isinstance(node, Nodes.ForInStatNode):
if is_normal_for_loop:
self.mark_forloop_target(node)
else: # Parallel
else: # Parallel
self.mark_assignment(node.target)
# Body block
......
......@@ -20,7 +20,7 @@ import PyrexTypes
import TypeSlots
from PyrexTypes import py_object_type, error_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CppClassScope
StructOrUnionScope, PyClassScope, CppClassScope, NonLocalScopeWrapper
from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options
......@@ -5560,12 +5560,14 @@ class DictIterationNextNode(Node):
target.generate_assignment_code(result, code)
var.release(code)
def ForStatNode(pos, **kw):
if 'iterator' in kw:
return ForInStatNode(pos, **kw)
else:
return ForFromStatNode(pos, **kw)
class ForInStatNode(LoopNode, StatNode):
# for statement
#
......@@ -5574,9 +5576,11 @@ class ForInStatNode(LoopNode, StatNode):
# body StatNode
# else_clause StatNode
# item NextNode used internally
# first_in_genexp bool True for outer loop of a scoped genexp/comprehension
child_attrs = ["target", "iterator", "body", "else_clause"]
item = None
first_in_genexp = False
def analyse_declarations(self, env):
import ExprNodes
......@@ -5584,18 +5588,25 @@ class ForInStatNode(LoopNode, StatNode):
self.body.analyse_declarations(env)
if self.else_clause:
self.else_clause.analyse_declarations(env)
self.item = ExprNodes.NextNode(self.iterator)
self.item = ExprNodes.NextNode(
self.iterator, lives_in_outer_scope=self.first_in_genexp)
def analyse_expressions(self, env):
self.target = self.target.analyse_target_types(env)
self.iterator = self.iterator.analyse_expressions(env)
iterator_env = env
if self.first_in_genexp:
# outermost iterator in a genexpr or scoped comprehension is
# looked up in the outer scope
iterator_env = NonLocalScopeWrapper(iterator_env)
self.iterator = self.iterator.analyse_expressions(iterator_env)
if self.item is None:
# Hack. Sometimes analyse_declarations not called.
import ExprNodes
self.item = ExprNodes.NextNode(self.iterator)
self.item = ExprNodes.NextNode(
self.iterator, lives_in_outer_scope=self.first_in_genexp)
self.item = self.item.analyse_expressions(env)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \
self.target.type.assignable_from(self.iterator.type):
if ((self.iterator.type.is_ptr or self.iterator.type.is_array) and
self.target.type.assignable_from(self.iterator.type)):
# C array slice optimization.
pass
else:
......
......@@ -67,7 +67,7 @@ cdef bint check_for_non_ascii_characters(unicode string)
cdef p_string_literal(PyrexScanner s, kind_override=*)
cdef p_list_maker(PyrexScanner s)
cdef p_comp_iter(PyrexScanner s, body)
cdef p_comp_for(PyrexScanner s, body)
cdef p_comp_for(PyrexScanner s, body, bint first_in_genexp)
cdef p_comp_if(PyrexScanner s, body)
cdef p_dict_or_set_maker(PyrexScanner s)
cdef p_backquote_expr(PyrexScanner s)
......
......@@ -909,13 +909,14 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = [])
expr = p_test(s)
if s.sy == 'for':
has_local_scope = s.context.language_level >= 3
append = ExprNodes.ComprehensionAppendNode(pos, expr=expr)
loop = p_comp_for(s, append)
loop = p_comp_for(s, append, first_in_genexp=has_local_scope)
s.expect(']')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type = Builtin.list_type,
pos, loop=loop, append=append, type=Builtin.list_type,
# list comprehensions leak their loop variable in Py2
has_local_scope = s.context.language_level >= 3)
has_local_scope=has_local_scope)
else:
if s.sy == ',':
s.next()
......@@ -923,23 +924,24 @@ def p_list_maker(s):
else:
exprs = [expr]
s.expect(']')
return ExprNodes.ListNode(pos, args = exprs)
return ExprNodes.ListNode(pos, args=exprs)
def p_comp_iter(s, body):
if s.sy == 'for':
return p_comp_for(s, body)
return p_comp_for(s, body, first_in_genexp=False)
elif s.sy == 'if':
return p_comp_if(s, body)
else:
# insert the 'append' operation into the loop
return body
def p_comp_for(s, body):
def p_comp_for(s, body, first_in_genexp):
# s.sy == 'for'
pos = s.position()
s.next()
kw = p_for_bounds(s, allow_testlist=False)
kw.update(else_clause = None, body = p_comp_iter(s, body))
kw.update(else_clause=None, body=p_comp_iter(s, body),
first_in_genexp=first_in_genexp)
return Nodes.ForStatNode(pos, **kw)
def p_comp_if(s, body):
......@@ -976,7 +978,7 @@ def p_dict_or_set_maker(s):
# set comprehension
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item)
loop = p_comp_for(s, append)
loop = p_comp_for(s, append, first_in_genexp=True)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.set_type)
......@@ -989,7 +991,7 @@ def p_dict_or_set_maker(s):
# dict comprehension
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value)
loop = p_comp_for(s, append)
loop = p_comp_for(s, append, first_in_genexp=True)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.dict_type)
......@@ -1087,8 +1089,9 @@ def p_testlist_comp(s):
def p_genexp(s, expr):
# s.sy == 'for'
loop = p_comp_for(s, Nodes.ExprStatNode(
expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr)))
yield_node = Nodes.ExprStatNode(
expr.pos, expr=ExprNodes.YieldExprNode(expr.pos, arg=expr))
loop = p_comp_for(s, yield_node, first_in_genexp=True)
return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
expr_terminators = cython.declare(set, set([
......
......@@ -1564,6 +1564,43 @@ class LocalScope(Scope):
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class ForeignName(str):
"""
String wrapper to store unnamed entries in Scope.entries dict.
"""
def __hash__(self):
return str.__hash__(self) + 1
def __eq__(self, other):
if self is other:
return True
return type(self) is type(other) and str.__eq__(self, other)
class NonLocalScopeWrapper(object):
"""
Wrapper around a local scope that inherits all names from the outer scope.
Used in generator expressions to analyse the outermost iterable.
"""
def __init__(self, scope):
self._scope = scope
self._lookup_outer = scope.outer_scope.lookup
def lookup(self, name):
entry = self._lookup_outer(name)
if entry and entry.scope.is_closure_scope:
entry.in_closure = True
inner_entry = InnerEntry(entry, self._scope)
inner_entry.is_variable = True
# do not overwrite locally declared names
self._scope.entries[ForeignName(name)] = inner_entry
return inner_entry
return entry
def __getattr__(self, name):
return getattr(self._scope, name)
class GeneratorExpressionScope(Scope):
"""Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all
......
......@@ -7,7 +7,6 @@ missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408
cpp_structs
closure_inside_cdef_T554
genexpr_iterable_lookup_T600
generator_expressions_in_class
for_from_pyvar_loop_T601
temp_sideeffects_T654 # not really a bug, Cython warns about it
......
......@@ -43,3 +43,13 @@ def genexpr_in_listcomp(L):
[[1, 2, 3], [1, 2, 3]]
"""
return [list(d for d in z) for z in L]
@cython.test_assert_path_exists('//ForFromStatNode')
def genexpr_range_in_listcomp(L):
"""
>>> genexpr_range_in_listcomp( [1,2,3] )
[[0], [0, 1], [0, 1, 2]]
"""
cdef int z,d
return [list(d for d in range(z)) for z in L]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment