Commit 614e0729 authored by Stefan Behnel's avatar Stefan Behnel

repair scoped comprehensions inside of generators and async functions by...

repair scoped comprehensions inside of generators and async functions by allowing their loop variables to be added to closures when necessary
parent 36d85bbc
...@@ -7878,9 +7878,8 @@ class ScopedExprNode(ExprNode): ...@@ -7878,9 +7878,8 @@ class ScopedExprNode(ExprNode):
code.putln('{ /* enter inner scope */') code.putln('{ /* enter inner scope */')
py_entries = [] py_entries = []
for entry in self.expr_scope.var_entries: for _, entry in sorted(item for item in self.expr_scope.entries.items() if item[0]):
if not entry.in_closure: if not entry.in_closure:
code.put_var_declaration(entry)
if entry.type.is_pyobject and entry.used: if entry.type.is_pyobject and entry.used:
py_entries.append(entry) py_entries.append(entry)
if not py_entries: if not py_entries:
...@@ -7890,14 +7889,14 @@ class ScopedExprNode(ExprNode): ...@@ -7890,14 +7889,14 @@ class ScopedExprNode(ExprNode):
return return
# must free all local Python references at each exit point # must free all local Python references at each exit point
old_loop_labels = tuple(code.new_loop_labels()) old_loop_labels = code.new_loop_labels()
old_error_label = code.new_error_label() old_error_label = code.new_error_label()
generate_inner_evaluation_code(code) generate_inner_evaluation_code(code)
# normal (non-error) exit # normal (non-error) exit
for entry in py_entries: for entry in py_entries:
code.put_var_decref(entry) code.put_var_decref_clear(entry)
# error/loop body exit points # error/loop body exit points
exit_scope = code.new_label('exit_scope') exit_scope = code.new_label('exit_scope')
...@@ -7907,7 +7906,7 @@ class ScopedExprNode(ExprNode): ...@@ -7907,7 +7906,7 @@ class ScopedExprNode(ExprNode):
if code.label_used(label): if code.label_used(label):
code.put_label(label) code.put_label(label)
for entry in py_entries: for entry in py_entries:
code.put_var_decref(entry) code.put_var_decref_clear(entry)
code.put_goto(old_label) code.put_goto(old_label)
code.put_label(exit_scope) code.put_label(exit_scope)
code.putln('} /* exit inner scope */') code.putln('} /* exit inner scope */')
......
...@@ -4084,7 +4084,7 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4084,7 +4084,7 @@ class GeneratorBodyDefNode(DefNode):
code.putln("PyObject *%s = NULL;" % Naming.retval_cname) code.putln("PyObject *%s = NULL;" % Naming.retval_cname)
tempvardecl_code = code.insertion_point() tempvardecl_code = code.insertion_point()
code.put_declare_refcount_context() code.put_declare_refcount_context()
code.put_setup_refcount_context(self.entry.name) code.put_setup_refcount_context(self.entry.name or self.entry.qualified_name)
profile = code.globalstate.directives['profile'] profile = code.globalstate.directives['profile']
linetrace = code.globalstate.directives['linetrace'] linetrace = code.globalstate.directives['linetrace']
if profile or linetrace: if profile or linetrace:
...@@ -4119,7 +4119,7 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4119,7 +4119,7 @@ class GeneratorBodyDefNode(DefNode):
# ----- Function body # ----- Function body
self.generate_function_body(env, code) self.generate_function_body(env, code)
# ----- Closure initialization # ----- Closure initialization
if lenv.scope_class.type.scope.entries: if lenv.scope_class.type.scope.var_entries:
closure_init_code.putln('%s = %s;' % ( closure_init_code.putln('%s = %s;' % (
lenv.scope_class.type.declaration_code(Naming.cur_scope_cname), lenv.scope_class.type.declaration_code(Naming.cur_scope_cname),
lenv.scope_class.type.cast_code('%s->closure' % lenv.scope_class.type.cast_code('%s->closure' %
......
...@@ -2599,24 +2599,28 @@ class CreateClosureClasses(CythonTransform): ...@@ -2599,24 +2599,28 @@ class CreateClosureClasses(CythonTransform):
def find_entries_used_in_closures(self, node): def find_entries_used_in_closures(self, node):
from_closure = [] from_closure = []
in_closure = [] in_closure = []
for name, entry in node.local_scope.entries.items(): for scope in node.local_scope.iter_local_scopes():
if entry.from_closure: for name, entry in scope.entries.items():
from_closure.append((name, entry)) if not name:
elif entry.in_closure: continue
in_closure.append((name, entry)) if entry.from_closure:
from_closure.append((name, entry))
elif entry.in_closure:
in_closure.append((name, entry))
return from_closure, in_closure return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None): def create_class_from_scope(self, node, target_module_scope, inner_node=None):
# move local variables into closure # move local variables into closure
if node.is_generator: if node.is_generator:
for entry in node.local_scope.entries.values(): for scope in node.local_scope.iter_local_scopes():
if not entry.from_closure: for entry in scope.entries.values():
entry.in_closure = True if not entry.from_closure:
entry.in_closure = True
from_closure, in_closure = self.find_entries_used_in_closures(node) from_closure, in_closure = self.find_entries_used_in_closures(node)
in_closure.sort() in_closure.sort()
# Now from the begining # Now from the beginning
node.needs_closure = False node.needs_closure = False
node.needs_outer_scope = False node.needs_outer_scope = False
...@@ -2668,11 +2672,12 @@ class CreateClosureClasses(CythonTransform): ...@@ -2668,11 +2672,12 @@ class CreateClosureClasses(CythonTransform):
is_cdef=True) is_cdef=True)
node.needs_outer_scope = True node.needs_outer_scope = True
for name, entry in in_closure: for name, entry in in_closure:
closure_entry = class_scope.declare_var(pos=entry.pos, closure_entry = class_scope.declare_var(
name=entry.name, pos=entry.pos,
cname=entry.cname, name=entry.name if not entry.in_subscope else None,
type=entry.type, cname=entry.cname,
is_cdef=True) type=entry.type,
is_cdef=True)
if entry.is_declared_generic: if entry.is_declared_generic:
closure_entry.is_declared_generic = 1 closure_entry.is_declared_generic = 1
node.needs_closure = True node.needs_closure = True
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
from __future__ import absolute_import from __future__ import absolute_import
import copy
import re import re
import copy
import operator
try: try:
import __builtin__ as builtins import __builtin__ as builtins
...@@ -88,6 +89,7 @@ class Entry(object): ...@@ -88,6 +89,7 @@ class Entry(object):
# is_arg boolean Is the arg of a method # is_arg boolean Is the arg of a method
# is_local boolean Is a local variable # is_local boolean Is a local variable
# in_closure boolean Is referenced in an inner scope # in_closure boolean Is referenced in an inner scope
# in_subscope boolean Belongs to a generator expression scope
# is_readonly boolean Can't be assigned to # is_readonly boolean Can't be assigned to
# func_cname string C func implementing Python func # func_cname string C func implementing Python func
# func_modifiers [string] C function modifiers ('inline') # func_modifiers [string] C function modifiers ('inline')
...@@ -163,6 +165,7 @@ class Entry(object): ...@@ -163,6 +165,7 @@ class Entry(object):
is_local = 0 is_local = 0
in_closure = 0 in_closure = 0
from_closure = 0 from_closure = 0
in_subscope = 0
is_declared_generic = 0 is_declared_generic = 0
is_readonly = 0 is_readonly = 0
pyfunc_cname = None pyfunc_cname = None
...@@ -299,6 +302,7 @@ class Scope(object): ...@@ -299,6 +302,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0 is_closure_scope = 0
is_genexpr_scope = 0
is_passthrough = 0 is_passthrough = 0
is_cpp_class_scope = 0 is_cpp_class_scope = 0
is_property_scope = 0 is_property_scope = 0
...@@ -308,6 +312,7 @@ class Scope(object): ...@@ -308,6 +312,7 @@ class Scope(object):
in_cinclude = 0 in_cinclude = 0
nogil = 0 nogil = 0
fused_to_specific = None fused_to_specific = None
return_type = None
def __init__(self, name, outer_scope, parent_scope): def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain. # The outer_scope is the next scope in the lookup chain.
...@@ -324,6 +329,7 @@ class Scope(object): ...@@ -324,6 +329,7 @@ class Scope(object):
self.qualified_name = EncodedString(name) self.qualified_name = EncodedString(name)
self.scope_prefix = mangled_name self.scope_prefix = mangled_name
self.entries = {} self.entries = {}
self.subscopes = set()
self.const_entries = [] self.const_entries = []
self.type_entries = [] self.type_entries = []
self.sue_entries = [] self.sue_entries = []
...@@ -341,7 +347,6 @@ class Scope(object): ...@@ -341,7 +347,6 @@ class Scope(object):
self.obj_to_entry = {} self.obj_to_entry = {}
self.buffer_entries = [] self.buffer_entries = []
self.lambda_defs = [] self.lambda_defs = []
self.return_type = None
self.id_counters = {} self.id_counters = {}
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
...@@ -419,6 +424,12 @@ class Scope(object): ...@@ -419,6 +424,12 @@ class Scope(object):
""" Return the module-level scope containing this scope. """ """ Return the module-level scope containing this scope. """
return self.outer_scope.builtin_scope() return self.outer_scope.builtin_scope()
def iter_local_scopes(self):
yield self
if self.subscopes:
for scope in sorted(self.subscopes, key=operator.attrgetter('scope_prefix')):
yield scope
def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0): def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
...@@ -1690,18 +1701,19 @@ class LocalScope(Scope): ...@@ -1690,18 +1701,19 @@ class LocalScope(Scope):
return entry return entry
def mangle_closure_cnames(self, outer_scope_cname): def mangle_closure_cnames(self, outer_scope_cname):
for entry in self.entries.values(): for scope in self.iter_local_scopes():
if entry.from_closure: for entry in scope.entries.values():
cname = entry.outer_entry.cname if entry.from_closure:
if self.is_passthrough: cname = entry.outer_entry.cname
entry.cname = cname if self.is_passthrough:
else: entry.cname = cname
if cname.startswith(Naming.cur_scope_cname): else:
cname = cname[len(Naming.cur_scope_cname)+2:] if cname.startswith(Naming.cur_scope_cname):
entry.cname = "%s->%s" % (outer_scope_cname, cname) cname = cname[len(Naming.cur_scope_cname)+2:]
elif entry.in_closure: entry.cname = "%s->%s" % (outer_scope_cname, cname)
entry.original_cname = entry.cname elif entry.in_closure:
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(Scope): class GeneratorExpressionScope(Scope):
...@@ -1709,12 +1721,19 @@ class GeneratorExpressionScope(Scope): ...@@ -1709,12 +1721,19 @@ class GeneratorExpressionScope(Scope):
to generators, these can be easily inlined in some cases, so all to generators, these can be easily inlined in some cases, so all
we really need is a scope that holds the loop variable(s). we really need is a scope that holds the loop variable(s).
""" """
is_genexpr_scope = True
def __init__(self, outer_scope): def __init__(self, outer_scope):
name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref) name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
Scope.__init__(self, name, outer_scope, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
self.var_entries = outer_scope.var_entries # keep declarations outside
self.directives = outer_scope.directives self.directives = outer_scope.directives
self.genexp_prefix = "%s%d%s" % (Naming.pyrex_prefix, len(name), name) self.genexp_prefix = "%s%d%s" % (Naming.pyrex_prefix, len(name), name)
while outer_scope.is_genexpr_scope:
outer_scope = outer_scope.outer_scope
outer_scope.subscopes.add(self)
def mangle(self, prefix, name): def mangle(self, prefix, name):
return '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(prefix, name)) return '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(prefix, name))
...@@ -1730,8 +1749,9 @@ class GeneratorExpressionScope(Scope): ...@@ -1730,8 +1749,9 @@ class GeneratorExpressionScope(Scope):
# this scope must hold its name exclusively # this scope must hold its name exclusively
cname = '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(Naming.var_prefix, name or self.next_id())) cname = '%s%s' % (self.genexp_prefix, self.parent_scope.mangle(Naming.var_prefix, name or self.next_id()))
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1 entry.is_variable = True
entry.is_local = 1 entry.is_local = True
entry.in_subscope = True
self.var_entries.append(entry) self.var_entries.append(entry)
self.entries[name] = entry self.entries[name] = entry
return entry return entry
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment