Commit a8d3d9a0 authored by Vitja Makarov's avatar Vitja Makarov

Closure optimization

parent d5ceb489
No related merge requests found
......@@ -4838,10 +4838,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class
#
binding = True
needs_self_code = True
def self_result_code(self):
return "((PyObject*)%s)" % Naming.cur_scope_cname
if self.needs_self_code:
return "((PyObject*)%s)" % (Naming.cur_scope_cname)
return "NULL"
class LambdaNode(InnerFunctionNode):
# Lambda expression node (only used as a function reference)
......@@ -4859,7 +4863,6 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env):
#self.def_node.needs_closure = self.needs_closure
self.def_node.analyse_declarations(env)
self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node)
......
......@@ -134,7 +134,6 @@ class Context(object):
WithTransform(self),
DecoratorTransform(self),
AnalyseDeclarationsTransform(self),
CreateClosureClasses(self),
AutoTestDictTransform(self),
EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary?
......@@ -144,6 +143,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
CreateClosureClasses(self), ## After all lookups and type inference
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(),
......
......@@ -1146,11 +1146,13 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const
# entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope
# directive_locals { string : NameNode } locals defined by cython.locals(...)
py_func = None
assmt = None
needs_closure = False
needs_outer_scope = False
modifiers = []
def analyse_default_values(self, env):
......@@ -1198,7 +1200,7 @@ class FuncDefNode(StatNode, BlockNode):
import Buffer
lenv = self.local_scope
if lenv.is_closure_scope:
if lenv.is_closure_scope and not lenv.is_passthrough:
outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
Naming.outer_scope_cname)
else:
......@@ -1259,10 +1261,13 @@ class FuncDefNode(StatNode, BlockNode):
cenv = env
while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope
if lenv.is_closure_scope:
if self.needs_closure:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
elif cenv.is_closure_scope:
elif self.needs_outer_scope:
if lenv.is_passthrough:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
code.putln(";")
self.generate_argument_declarations(lenv, code)
......@@ -1314,12 +1319,14 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}")
code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point.
if cenv.is_closure_scope:
if self.needs_outer_scope:
code.putln("%s = (%s)%s;" % (
outer_scope_cname,
cenv.scope_class.type.declaration_code(''),
Naming.self_cname))
if self.needs_closure:
if lenv.is_passthrough:
code.putln("%s = %s;" % (Naming.cur_scope_cname, outer_scope_cname));
elif self.needs_closure:
# inner closures own a reference to their outer parent
code.put_incref(outer_scope_cname, cenv.scope_class.type)
code.put_giveref(outer_scope_cname)
......
......@@ -1317,16 +1317,58 @@ class MarkClosureVisitor(CythonTransform):
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
# that need it.
# that really need it.
def __init__(self, context):
super(CreateClosureClasses, self).__init__(context)
self.path = []
self.in_lambda = False
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
def get_scope_use(self, node):
from_closure = []
in_closure = []
for name, entry in node.local_scope.entries.items():
if entry.from_closure:
from_closure.append((name, entry))
elif entry.in_closure and not entry.from_closure:
in_closure.append((name, entry))
return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
from_closure, in_closure = self.get_scope_use(node)
in_closure.sort()
# Now from the begining
node.needs_closure = False
node.needs_outer_scope = False
func_scope = node.local_scope
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if not from_closure and self.path:
if not inner_node:
if not node.assmt:
raise InternalError, "DefNode does not have assignment node"
inner_node = node.assmt.rhs
inner_node.needs_self_code = False
node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure:
return
elif not in_closure:
func_scope.is_passthrough = True
func_scope.scope_class = cscope.scope_class
node.needs_outer_scope = True
return
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
......@@ -1335,34 +1377,41 @@ class CreateClosureClasses(CythonTransform):
class_scope.is_internal = True
class_scope.directives = {'final': True}
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if cscope.is_closure_scope:
if from_closure:
assert cscope.is_closure_scope
class_scope.declare_var(pos=node.pos,
name=Naming.outer_scope_cname, # this could conflict?
name=Naming.outer_scope_cname,
cname=Naming.outer_scope_cname,
type=cscope.scope_class.type,
is_cdef=True)
entries = func_scope.entries.items()
entries.sort()
for name, entry in entries:
# This is wasteful--we should do this later when we know
# which vars are actually being used inside...
#
# Also, this happens before type inference and type
# analysis, so the entries created here may end up having
# incorrect or at least unspecified types.
node.needs_outer_scope = True
for name, entry in in_closure:
class_scope.declare_var(pos=entry.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
node.needs_closure = True
# Do it here because other classes are already checked
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
self.visitchildren(node)
self.in_lambda = was_in_lambda
return node
def visit_FuncDefNode(self, node):
if node.needs_closure:
if self.in_lambda:
self.visitchildren(node)
return node
if node.needs_closure or self.path:
self.create_class_from_scope(node, self.module_scope)
self.path.append(node)
self.visitchildren(node)
self.path.pop()
return node
......
......@@ -211,7 +211,8 @@ class Scope(object):
# return_type PyrexType or None Return type of function owning scope
# is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean
# is_closure_scope boolean Is a closure scope
# is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope
# scope_prefix string Disambiguator for C names
......@@ -228,6 +229,7 @@ class Scope(object):
is_py_class_scope = 0
is_c_class_scope = 0
is_closure_scope = 0
is_passthrough = 0
is_cpp_class_scope = 0
is_property_scope = 0
is_module_scope = 0
......@@ -1121,7 +1123,30 @@ class ModuleScope(Scope):
# Check defined
if not entry.type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % entry.name)
def check_c_class(self, entry):
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_classes(self):
# Performs post-analysis checking and finishing up of extension types
# being implemented in this module. This is called only for the main
......@@ -1144,28 +1169,8 @@ class ModuleScope(Scope):
print("...entry %s %s" % (entry.name, entry))
print("......type = ", entry.type)
print("......visibility = ", entry.visibility)
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
self.check_c_class(entry)
def check_c_functions(self):
# Performs post-analysis checking making sure all
# defined c functions are actually implemented.
......@@ -1253,6 +1258,8 @@ class LocalScope(Scope):
entry = Scope.lookup(self, name)
if entry is not None:
if entry.scope is not self and entry.scope.is_closure_scope:
if hasattr(entry.scope, "scope_class"):
raise InternalError, "lookup() after scope class created."
# The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry
entry.in_closure = True
......@@ -1270,14 +1277,16 @@ class LocalScope(Scope):
for entry in self.entries.values():
if entry.from_closure:
cname = entry.outer_entry.cname
if cname.startswith(Naming.cur_scope_cname):
cname = cname[len(Naming.cur_scope_cname)+2:]
entry.cname = "%s->%s" % (outer_scope_cname, cname)
if self.is_passthrough:
entry.cname = cname
else:
if cname.startswith(Naming.cur_scope_cname):
cname = cname[len(Naming.cur_scope_cname)+2:]
entry.cname = "%s->%s" % (outer_scope_cname, cname)
elif entry.in_closure:
entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(LocalScope):
"""Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all
......
......@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object):
for entry in scope.entries.values():
if entry.type is unspecified_type:
entry.type = py_object_type
if scope.is_closure_scope:
fix_closure_entries(scope)
return
dependancies_by_entry = {} # entry -> dependancies
......@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
#if scope.is_closure_scope:
# fix_closure_entries(scope)
def fix_closure_entries(scope):
"""Temporary work-around to fix field types in the closure class
that were unknown at the time of creation and only determined
during type inference.
"""
closure_entries = scope.scope_class.type.scope.entries
for name, entry in scope.entries.iteritems():
if name in closure_entries:
closure_entry = closure_entries[name]
closure_entry.type = entry.type
def find_spanning_type(type1, type2):
if type1 is type2:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment