Commit a504dae0 authored by Robert Bradshaw's avatar Robert Bradshaw

First pass at closures

parent b712ad56
...@@ -3621,6 +3621,32 @@ class ClassNode(ExprNode): ...@@ -3621,6 +3621,32 @@ class ClassNode(ExprNode):
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
class BoundMethodNode(ExprNode):
# Helper class used in the implementation of Python
# class definitions. Constructs an bound method
# object from a class and a function.
#
# function ExprNode Function object
# self_object ExprNode self object
subexprs = ['function']
def analyse_types(self, env):
self.function.analyse_types(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing an bound method"
def generate_result_code(self, code):
code.putln(
"%s = PyMethod_New(%s, %s, (PyObject*)%s->ob_type); %s" % (
self.result(),
self.function.py_result(),
self.self_object.py_result(),
self.self_object.py_result(),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
class UnboundMethodNode(ExprNode): class UnboundMethodNode(ExprNode):
# Helper class used in the implementation of Python # Helper class used in the implementation of Python
...@@ -3654,6 +3680,9 @@ class PyCFunctionNode(AtomicExprNode): ...@@ -3654,6 +3680,9 @@ class PyCFunctionNode(AtomicExprNode):
# from a PyMethodDef struct. # from a PyMethodDef struct.
# #
# pymethdef_cname string PyMethodDef structure # pymethdef_cname string PyMethodDef structure
# self_object ExprNode or None
self_object = None
def analyse_types(self, env): def analyse_types(self, env):
self.type = py_object_type self.type = py_object_type
...@@ -3662,10 +3691,15 @@ class PyCFunctionNode(AtomicExprNode): ...@@ -3662,10 +3691,15 @@ class PyCFunctionNode(AtomicExprNode):
gil_message = "Constructing Python function" gil_message = "Constructing Python function"
def generate_result_code(self, code): def generate_result_code(self, code):
if self.self_object is None:
self_result = "NULL"
else:
self_result = self.self_object.py_result()
code.putln( code.putln(
"%s = PyCFunction_New(&%s, 0); %s" % ( "%s = PyCFunction_New(&%s, %s); %s" % (
self.result(), self.result(),
self.pymethdef_cname, self.pymethdef_cname,
self_result,
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
......
...@@ -114,11 +114,13 @@ class Context(object): ...@@ -114,11 +114,13 @@ class Context(object):
_specific_post_parse, _specific_post_parse,
InterpretCompilerDirectives(self, self.pragma_overrides), InterpretCompilerDirectives(self, self.pragma_overrides),
_align_function_definitions, _align_function_definitions,
MarkClosureVisitor(self),
ConstantFolding(), ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
CreateClosureClasses(self),
EmbedSignature(self), EmbedSignature(self),
TransformBuiltinMethods(self), TransformBuiltinMethods(self),
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
......
...@@ -44,6 +44,8 @@ vtabptr_prefix = pyrex_prefix + "vtabptr_" ...@@ -44,6 +44,8 @@ vtabptr_prefix = pyrex_prefix + "vtabptr_"
vtabstruct_prefix = pyrex_prefix + "vtabstruct_" vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
opt_arg_prefix = pyrex_prefix + "opt_args_" opt_arg_prefix = pyrex_prefix + "opt_args_"
convert_func_prefix = pyrex_prefix + "convert_" convert_func_prefix = pyrex_prefix + "convert_"
closure_scope_prefix = pyrex_prefix + "scope_"
closure_class_prefix = pyrex_prefix + "scope_struct_"
args_cname = pyrex_prefix + "args" args_cname = pyrex_prefix + "args"
pykwdlist_cname = pyrex_prefix + "pyargnames" pykwdlist_cname = pyrex_prefix + "pyargnames"
...@@ -81,8 +83,6 @@ pymoduledef_cname = pyrex_prefix + "moduledef" ...@@ -81,8 +83,6 @@ pymoduledef_cname = pyrex_prefix + "moduledef"
optional_args_cname = pyrex_prefix + "optional_args" optional_args_cname = pyrex_prefix + "optional_args"
import_star = pyrex_prefix + "import_star" import_star = pyrex_prefix + "import_star"
import_star_set = pyrex_prefix + "import_star_set" import_star_set = pyrex_prefix + "import_star_set"
cur_scope_cname = pyrex_prefix + "cur_scope"
enc_scope_cname = pyrex_prefix + "enc_scope"
line_c_macro = "__LINE__" line_c_macro = "__LINE__"
......
...@@ -11,7 +11,7 @@ import Naming ...@@ -11,7 +11,7 @@ import Naming
import PyrexTypes import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType
from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \ from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CClassScope StructOrUnionScope, PyClassScope, CClassScope
from Cython.Utils import open_new_file, replace_suffix, UtilityCode from Cython.Utils import open_new_file, replace_suffix, UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_docstring from StringEncoding import EncodedString, escape_byte_string, split_docstring
...@@ -977,7 +977,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -977,7 +977,7 @@ class FuncDefNode(StatNode, BlockNode):
while env.is_py_class_scope or env.is_c_class_scope: while env.is_py_class_scope or env.is_c_class_scope:
env = env.outer_scope env = env.outer_scope
if self.needs_closure: if self.needs_closure:
lenv = GeneratorLocalScope(name = self.entry.name, outer_scope = genv) lenv = ClosureScope(name = self.entry.name, scope_name = self.entry.cname, outer_scope = genv)
else: else:
lenv = LocalScope(name = self.entry.name, outer_scope = genv) lenv = LocalScope(name = self.entry.name, outer_scope = genv)
lenv.return_type = self.return_type lenv.return_type = self.return_type
...@@ -992,6 +992,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -992,6 +992,8 @@ class FuncDefNode(StatNode, BlockNode):
import Buffer import Buffer
lenv = self.local_scope lenv = self.local_scope
# Generate closure function definitions
self.body.generate_function_definitions(lenv, code)
is_getbuffer_slot = (self.entry.name == "__getbuffer__" and is_getbuffer_slot = (self.entry.name == "__getbuffer__" and
self.entry.scope.is_c_class_scope) self.entry.scope.is_c_class_scope)
...@@ -1007,16 +1009,23 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1007,16 +1009,23 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("") code.putln("")
if self.py_func: if self.py_func:
self.py_func.generate_function_header(code, self.py_func.generate_function_header(code,
with_pymethdef = env.is_py_class_scope, with_pymethdef = env.is_py_class_scope or env.is_closure_scope,
proto_only=True) proto_only=True)
self.generate_function_header(code, self.generate_function_header(code,
with_pymethdef = env.is_py_class_scope) with_pymethdef = env.is_py_class_scope or env.is_closure_scope)
# ----- Local variable declarations # ----- Local variable declarations
lenv.mangle_closure_cnames(Naming.cur_scope_cname) # lenv.mangle_closure_cnames(Naming.cur_scope_cname)
self.generate_argument_declarations(lenv, code)
if self.needs_closure: if self.needs_closure:
code.putln("/* TODO: declare and create scope object */") code.put(lenv.scope_class.type.declaration_code(lenv.closure_cname))
code.put_var_declarations(lenv.var_entries) code.putln(";")
else:
self.generate_argument_declarations(lenv, code)
code.put_var_declarations(lenv.var_entries)
if env.is_closure_scope:
code.putln("%s = (%s)%s;" % (
env.scope_class.type.declaration_code(env.closure_cname),
env.scope_class.type.declaration_code(''),
Naming.self_cname))
init = "" init = ""
if not self.return_type.is_void: if not self.return_type.is_void:
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
...@@ -1040,6 +1049,21 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1040,6 +1049,21 @@ class FuncDefNode(StatNode, BlockNode):
code.put_setup_refcount_context(self.entry.name) code.put_setup_refcount_context(self.entry.name)
if is_getbuffer_slot: if is_getbuffer_slot:
self.getbuffer_init(code) self.getbuffer_init(code)
# ----- Create closure scope object
if self.needs_closure:
code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
lenv.closure_cname,
lenv.scope_class.type.declaration_code(''),
lenv.scope_class.type.typeptr_cname,
lenv.scope_class.type.typeptr_cname,
Naming.empty_tuple))
# TODO: error handling
# The code below assumes the local variables are innitially NULL
# Note that it is unsafe to decref the scope at this point.
for entry in lenv.arg_entries + lenv.var_entries:
if entry.type.is_pyobject:
code.put_var_decref(entry)
code.putln("%s = NULL;" % entry.cname)
# ----- Fetch arguments # ----- Fetch arguments
self.generate_argument_parsing_code(env, code) self.generate_argument_parsing_code(env, code)
# If an argument is assigned to in the body, we must # If an argument is assigned to in the body, we must
...@@ -1141,13 +1165,16 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1141,13 +1165,16 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.var_entries: for entry in lenv.var_entries:
if lenv.control_flow.get_state((entry.name, 'initalized')) is not True: if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
entry.xdecref_cleanup = 1 entry.xdecref_cleanup = 1
code.put_var_decrefs(lenv.var_entries, used_only = 1)
# Decref any increfed args
for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_decref(entry)
# code.putln("/* TODO: decref scope object */") if self.needs_closure:
code.put_decref(lenv.closure_cname, lenv.scope_class.type)
else:
code.put_var_decrefs(lenv.var_entries, used_only = 1)
# Decref any increfed args
for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_decref(entry)
# ----- Return # ----- Return
# This code is duplicated in ModuleNode.generate_module_init_func # This code is duplicated in ModuleNode.generate_module_init_func
if not lenv.nogil: if not lenv.nogil:
...@@ -1776,16 +1803,25 @@ class DefNode(FuncDefNode): ...@@ -1776,16 +1803,25 @@ class DefNode(FuncDefNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.local_scope.directives = env.directives self.local_scope.directives = env.directives
self.analyse_default_values(env) self.analyse_default_values(env)
if env.is_py_class_scope: if env.is_py_class_scope or env.is_closure_scope:
# Shouldn't we be doing this at the module level too?
self.synthesize_assignment_node(env) self.synthesize_assignment_node(env)
def synthesize_assignment_node(self, env): def synthesize_assignment_node(self, env):
import ExprNodes import ExprNodes
self.assmt = SingleAssignmentNode(self.pos, if env.is_py_class_scope:
lhs = ExprNodes.NameNode(self.pos, name = self.name),
rhs = ExprNodes.UnboundMethodNode(self.pos, rhs = ExprNodes.UnboundMethodNode(self.pos,
function = ExprNodes.PyCFunctionNode(self.pos, function = ExprNodes.PyCFunctionNode(self.pos,
pymethdef_cname = self.entry.pymethdef_cname))) pymethdef_cname = self.entry.pymethdef_cname))
elif env.is_closure_scope:
self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env)
self_object.temp_cname = "((PyObject*)%s)" % env.closure_cname
rhs = ExprNodes.PyCFunctionNode(self.pos,
self_object = self_object,
pymethdef_cname = self.entry.pymethdef_cname)
self.assmt = SingleAssignmentNode(self.pos,
lhs = ExprNodes.NameNode(self.pos, name = self.name),
rhs = rhs)
self.assmt.analyse_declarations(env) self.assmt.analyse_declarations(env)
self.assmt.analyse_expressions(env) self.assmt.analyse_expressions(env)
......
...@@ -864,21 +864,24 @@ class CreateClosureClasses(CythonTransform): ...@@ -864,21 +864,24 @@ class CreateClosureClasses(CythonTransform):
return node return node
def create_class_from_scope(self, node, target_module_scope): def create_class_from_scope(self, node, target_module_scope):
as_name = temp_name_handle("closure") as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
func_scope = node.local_scope func_scope = node.local_scope
entry = target_module_scope.declare_c_class(name = as_name, entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True) pos = node.pos, defining = True, implementing = True)
func_scope.scope_class = entry
class_scope = entry.type.scope class_scope = entry.type.scope
for entry in func_scope.entries.values(): for entry in func_scope.entries.values():
cname = entry.cname[entry.cname.index('->')+2:] # everywhere but here they're attached to this class
class_scope.declare_var(pos=node.pos, class_scope.declare_var(pos=node.pos,
name=entry.name, name=entry.name,
cname=entry.cname, cname=cname,
type=entry.type, type=entry.type,
is_cdef=True) is_cdef=True)
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.create_class_from_scope(node, self.module_scope) if node.needs_closure:
self.create_class_from_scope(node, self.module_scope)
return node return node
......
...@@ -1599,7 +1599,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1599,7 +1599,7 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'):
s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s, decorators) return p_def_statement(s, decorators)
......
...@@ -204,6 +204,7 @@ class Scope(object): ...@@ -204,6 +204,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0
is_module_scope = 0 is_module_scope = 0
scope_prefix = "" scope_prefix = ""
in_cinclude = 0 in_cinclude = 0
...@@ -1071,15 +1072,33 @@ class LocalScope(Scope): ...@@ -1071,15 +1072,33 @@ class LocalScope(Scope):
entry.cname = scope_var + "->" + entry.cname entry.cname = scope_var + "->" + entry.cname
class GeneratorLocalScope(LocalScope): class ClosureScope(LocalScope):
def mangle_closure_cnames(self, scope_var): is_closure_scope = True
def __init__(self, name, scope_name, outer_scope):
LocalScope.__init__(self, name, outer_scope)
self.closure_cname = "%s%s" % (Naming.closure_scope_prefix, scope_name)
# def mangle_closure_cnames(self, scope_var):
# for entry in self.entries.values() + self.temp_entries: # for entry in self.entries.values() + self.temp_entries:
# entry.in_closure = 1 # entry.in_closure = 1
LocalScope.mangle_closure_cnames(self, scope_var) # LocalScope.mangle_closure_cnames(self, scope_var)
# def mangle(self, prefix, name): def mangle(self, prefix, name):
# return "%s->%s" % (Naming.scope_obj_cname, name) return "%s->%s" % (self.closure_cname, name)
def declare_pyfunction(self, name, pos):
# Add an entry for a Python function.
entry = self.lookup_here(name)
if entry and not entry.type.is_cfunction:
# This is legal Python, but for now may produce invalid C.
error(pos, "'%s' already declared" % name)
entry = self.declare_var(name, py_object_type, pos)
entry.signature = pyfunction_signature
self.pyfunc_entries.append(entry)
return entry
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment