Commit 87a9bc29 authored by Stefan Behnel's avatar Stefan Behnel

fix T600: lookup of iterables in genexpr must use outer scope

parent c117dddd
...@@ -2397,20 +2397,26 @@ class NextNode(AtomicExprNode): ...@@ -2397,20 +2397,26 @@ class NextNode(AtomicExprNode):
# #
# iterator IteratorNode # iterator IteratorNode
def __init__(self, iterator): def __init__(self, iterator, lives_in_outer_scope=False):
AtomicExprNode.__init__(self, iterator.pos) AtomicExprNode.__init__(self, iterator.pos)
self.lives_in_outer_scope = lives_in_outer_scope
self.iterator = iterator self.iterator = iterator
def type_dependencies(self, env): def type_dependencies(self, env):
if self.lives_in_outer_scope:
env = Symtab.NonLocalScopeWrapper(env)
return self.iterator.type_dependencies(env) return self.iterator.type_dependencies(env)
def infer_type(self, env, iterator_type = None): def infer_type(self, env, iterator_type=None):
if self.lives_in_outer_scope:
env = Symtab.NonLocalScopeWrapper(env)
if iterator_type is None: if iterator_type is None:
iterator_type = self.iterator.infer_type(env) iterator_type = self.iterator.infer_type(env)
if iterator_type.is_ptr or iterator_type.is_array: if iterator_type.is_ptr or iterator_type.is_array:
return iterator_type.base_type return iterator_type.base_type
elif iterator_type.is_cpp_class: elif iterator_type.is_cpp_class:
item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.return_type item_type = env.lookup_operator_for_types(
self.pos, "*", [iterator_type]).type.return_type
if item_type.is_reference: if item_type.is_reference:
item_type = item_type.ref_base_type item_type = item_type.ref_base_type
if item_type.is_const: if item_type.is_const:
...@@ -2418,7 +2424,8 @@ class NextNode(AtomicExprNode): ...@@ -2418,7 +2424,8 @@ class NextNode(AtomicExprNode):
return item_type return item_type
else: else:
# Avoid duplication of complicated logic. # Avoid duplication of complicated logic.
fake_index_node = IndexNode(self.pos, fake_index_node = IndexNode(
self.pos,
base=self.iterator.sequence, base=self.iterator.sequence,
index=IntNode(self.pos, value='0')) index=IntNode(self.pos, value='0'))
return fake_index_node.infer_type(env) return fake_index_node.infer_type(env)
......
...@@ -10,6 +10,7 @@ import Builtin ...@@ -10,6 +10,7 @@ import Builtin
import ExprNodes import ExprNodes
import Nodes import Nodes
import Options import Options
import Symtab
from PyrexTypes import py_object_type, unspecified_type from PyrexTypes import py_object_type, unspecified_type
import PyrexTypes import PyrexTypes
...@@ -984,11 +985,18 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -984,11 +985,18 @@ class ControlFlowAnalysis(CythonTransform):
next_block = self.flow.newblock() next_block = self.flow.newblock()
# Condition with iterator # Condition with iterator
self.flow.loops.append(LoopDescr(next_block, condition_block)) self.flow.loops.append(LoopDescr(next_block, condition_block))
is_normal_for_loop = isinstance(node, Nodes.ForInStatNode)
if is_normal_for_loop and node.first_in_genexp:
self.env_stack.append(self.env)
self.env = Symtab.NonLocalScopeWrapper(self.env)
self._visit(node.iterator) self._visit(node.iterator)
if is_normal_for_loop and node.first_in_genexp:
self.env = self.env_stack.pop()
# Target assignment # Target assignment
self.flow.nextblock() self.flow.nextblock()
if is_normal_for_loop:
if isinstance(node, Nodes.ForInStatNode):
self.mark_forloop_target(node) self.mark_forloop_target(node)
else: # Parallel else: # Parallel
self.mark_assignment(node.target) self.mark_assignment(node.target)
......
...@@ -20,7 +20,7 @@ import PyrexTypes ...@@ -20,7 +20,7 @@ import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type from PyrexTypes import py_object_type, error_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \ from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CppClassScope StructOrUnionScope, PyClassScope, CppClassScope, NonLocalScopeWrapper
from Code import UtilityCode from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_string_literal from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options import Options
...@@ -5560,12 +5560,14 @@ class DictIterationNextNode(Node): ...@@ -5560,12 +5560,14 @@ class DictIterationNextNode(Node):
target.generate_assignment_code(result, code) target.generate_assignment_code(result, code)
var.release(code) var.release(code)
def ForStatNode(pos, **kw): def ForStatNode(pos, **kw):
if 'iterator' in kw: if 'iterator' in kw:
return ForInStatNode(pos, **kw) return ForInStatNode(pos, **kw)
else: else:
return ForFromStatNode(pos, **kw) return ForFromStatNode(pos, **kw)
class ForInStatNode(LoopNode, StatNode): class ForInStatNode(LoopNode, StatNode):
# for statement # for statement
# #
...@@ -5574,9 +5576,11 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -5574,9 +5576,11 @@ class ForInStatNode(LoopNode, StatNode):
# body StatNode # body StatNode
# else_clause StatNode # else_clause StatNode
# item NextNode used internally # item NextNode used internally
# first_in_genexp bool True for outer loop of a scoped genexp/comprehension
child_attrs = ["target", "iterator", "body", "else_clause"] child_attrs = ["target", "iterator", "body", "else_clause"]
item = None item = None
first_in_genexp = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
import ExprNodes import ExprNodes
...@@ -5584,18 +5588,25 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -5584,18 +5588,25 @@ class ForInStatNode(LoopNode, StatNode):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_declarations(env) self.else_clause.analyse_declarations(env)
self.item = ExprNodes.NextNode(self.iterator) self.item = ExprNodes.NextNode(
self.iterator, lives_in_outer_scope=self.first_in_genexp)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.target = self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.iterator = self.iterator.analyse_expressions(env) iterator_env = env
if self.first_in_genexp:
# outermost iterator in a genexpr or scoped comprehension is
# looked up in the outer scope
iterator_env = NonLocalScopeWrapper(iterator_env)
self.iterator = self.iterator.analyse_expressions(iterator_env)
if self.item is None: if self.item is None:
# Hack. Sometimes analyse_declarations not called. # Hack. Sometimes analyse_declarations not called.
import ExprNodes import ExprNodes
self.item = ExprNodes.NextNode(self.iterator) self.item = ExprNodes.NextNode(
self.iterator, lives_in_outer_scope=self.first_in_genexp)
self.item = self.item.analyse_expressions(env) self.item = self.item.analyse_expressions(env)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \ if ((self.iterator.type.is_ptr or self.iterator.type.is_array) and
self.target.type.assignable_from(self.iterator.type): self.target.type.assignable_from(self.iterator.type)):
# C array slice optimization. # C array slice optimization.
pass pass
else: else:
......
...@@ -67,7 +67,7 @@ cdef bint check_for_non_ascii_characters(unicode string) ...@@ -67,7 +67,7 @@ cdef bint check_for_non_ascii_characters(unicode string)
cdef p_string_literal(PyrexScanner s, kind_override=*) cdef p_string_literal(PyrexScanner s, kind_override=*)
cdef p_list_maker(PyrexScanner s) cdef p_list_maker(PyrexScanner s)
cdef p_comp_iter(PyrexScanner s, body) cdef p_comp_iter(PyrexScanner s, body)
cdef p_comp_for(PyrexScanner s, body) cdef p_comp_for(PyrexScanner s, body, bint first_in_genexp)
cdef p_comp_if(PyrexScanner s, body) cdef p_comp_if(PyrexScanner s, body)
cdef p_dict_or_set_maker(PyrexScanner s) cdef p_dict_or_set_maker(PyrexScanner s)
cdef p_backquote_expr(PyrexScanner s) cdef p_backquote_expr(PyrexScanner s)
......
...@@ -909,13 +909,14 @@ def p_list_maker(s): ...@@ -909,13 +909,14 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = []) return ExprNodes.ListNode(pos, args = [])
expr = p_test(s) expr = p_test(s)
if s.sy == 'for': if s.sy == 'for':
has_local_scope = s.context.language_level >= 3
append = ExprNodes.ComprehensionAppendNode(pos, expr=expr) append = ExprNodes.ComprehensionAppendNode(pos, expr=expr)
loop = p_comp_for(s, append) loop = p_comp_for(s, append, first_in_genexp=has_local_scope)
s.expect(']') s.expect(']')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type = Builtin.list_type, pos, loop=loop, append=append, type=Builtin.list_type,
# list comprehensions leak their loop variable in Py2 # list comprehensions leak their loop variable in Py2
has_local_scope = s.context.language_level >= 3) has_local_scope=has_local_scope)
else: else:
if s.sy == ',': if s.sy == ',':
s.next() s.next()
...@@ -923,23 +924,24 @@ def p_list_maker(s): ...@@ -923,23 +924,24 @@ def p_list_maker(s):
else: else:
exprs = [expr] exprs = [expr]
s.expect(']') s.expect(']')
return ExprNodes.ListNode(pos, args = exprs) return ExprNodes.ListNode(pos, args=exprs)
def p_comp_iter(s, body): def p_comp_iter(s, body):
if s.sy == 'for': if s.sy == 'for':
return p_comp_for(s, body) return p_comp_for(s, body, first_in_genexp=False)
elif s.sy == 'if': elif s.sy == 'if':
return p_comp_if(s, body) return p_comp_if(s, body)
else: else:
# insert the 'append' operation into the loop # insert the 'append' operation into the loop
return body return body
def p_comp_for(s, body): def p_comp_for(s, body, first_in_genexp):
# s.sy == 'for' # s.sy == 'for'
pos = s.position() pos = s.position()
s.next() s.next()
kw = p_for_bounds(s, allow_testlist=False) kw = p_for_bounds(s, allow_testlist=False)
kw.update(else_clause = None, body = p_comp_iter(s, body)) kw.update(else_clause=None, body=p_comp_iter(s, body),
first_in_genexp=first_in_genexp)
return Nodes.ForStatNode(pos, **kw) return Nodes.ForStatNode(pos, **kw)
def p_comp_if(s, body): def p_comp_if(s, body):
...@@ -976,7 +978,7 @@ def p_dict_or_set_maker(s): ...@@ -976,7 +978,7 @@ def p_dict_or_set_maker(s):
# set comprehension # set comprehension
append = ExprNodes.ComprehensionAppendNode( append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item) item.pos, expr=item)
loop = p_comp_for(s, append) loop = p_comp_for(s, append, first_in_genexp=True)
s.expect('}') s.expect('}')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.set_type) pos, loop=loop, append=append, type=Builtin.set_type)
...@@ -989,7 +991,7 @@ def p_dict_or_set_maker(s): ...@@ -989,7 +991,7 @@ def p_dict_or_set_maker(s):
# dict comprehension # dict comprehension
append = ExprNodes.DictComprehensionAppendNode( append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value) item.pos, key_expr=key, value_expr=value)
loop = p_comp_for(s, append) loop = p_comp_for(s, append, first_in_genexp=True)
s.expect('}') s.expect('}')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.dict_type) pos, loop=loop, append=append, type=Builtin.dict_type)
...@@ -1087,8 +1089,9 @@ def p_testlist_comp(s): ...@@ -1087,8 +1089,9 @@ def p_testlist_comp(s):
def p_genexp(s, expr): def p_genexp(s, expr):
# s.sy == 'for' # s.sy == 'for'
loop = p_comp_for(s, Nodes.ExprStatNode( yield_node = Nodes.ExprStatNode(
expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr))) expr.pos, expr=ExprNodes.YieldExprNode(expr.pos, arg=expr))
loop = p_comp_for(s, yield_node, first_in_genexp=True)
return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop) return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
expr_terminators = cython.declare(set, set([ expr_terminators = cython.declare(set, set([
......
...@@ -1564,6 +1564,43 @@ class LocalScope(Scope): ...@@ -1564,6 +1564,43 @@ class LocalScope(Scope):
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class ForeignName(str):
"""
String wrapper to store unnamed entries in Scope.entries dict.
"""
def __hash__(self):
return str.__hash__(self) + 1
def __eq__(self, other):
if self is other:
return True
return type(self) is type(other) and str.__eq__(self, other)
class NonLocalScopeWrapper(object):
"""
Wrapper around a local scope that inherits all names from the outer scope.
Used in generator expressions to analyse the outermost iterable.
"""
def __init__(self, scope):
self._scope = scope
self._lookup_outer = scope.outer_scope.lookup
def lookup(self, name):
entry = self._lookup_outer(name)
if entry and entry.scope.is_closure_scope:
entry.in_closure = True
inner_entry = InnerEntry(entry, self._scope)
inner_entry.is_variable = True
# do not overwrite locally declared names
self._scope.entries[ForeignName(name)] = inner_entry
return inner_entry
return entry
def __getattr__(self, name):
return getattr(self._scope, name)
class GeneratorExpressionScope(Scope): class GeneratorExpressionScope(Scope):
"""Scope for generator expressions and comprehensions. As opposed """Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all to generators, these can be easily inlined in some cases, so all
......
...@@ -7,7 +7,6 @@ missing_baseclass_in_predecl_T262 ...@@ -7,7 +7,6 @@ missing_baseclass_in_predecl_T262
cfunc_call_tuple_args_T408 cfunc_call_tuple_args_T408
cpp_structs cpp_structs
closure_inside_cdef_T554 closure_inside_cdef_T554
genexpr_iterable_lookup_T600
generator_expressions_in_class generator_expressions_in_class
for_from_pyvar_loop_T601 for_from_pyvar_loop_T601
temp_sideeffects_T654 # not really a bug, Cython warns about it temp_sideeffects_T654 # not really a bug, Cython warns about it
......
...@@ -43,3 +43,13 @@ def genexpr_in_listcomp(L): ...@@ -43,3 +43,13 @@ def genexpr_in_listcomp(L):
[[1, 2, 3], [1, 2, 3]] [[1, 2, 3], [1, 2, 3]]
""" """
return [list(d for d in z) for z in L] return [list(d for d in z) for z in L]
@cython.test_assert_path_exists('//ForFromStatNode')
def genexpr_range_in_listcomp(L):
"""
>>> genexpr_range_in_listcomp( [1,2,3] )
[[0], [0, 1], [0, 1, 2]]
"""
cdef int z,d
return [list(d for d in range(z)) for z in L]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment