Commit 747a0110 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Generates closure classes for all functions

parent 7308254f
...@@ -336,6 +336,7 @@ def create_generate_code(context, options, result): ...@@ -336,6 +336,7 @@ def create_generate_code(context, options, result):
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse from ParseTreeTransforms import WithTransform, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
...@@ -345,6 +346,7 @@ def create_default_pipeline(context, options, result): ...@@ -345,6 +346,7 @@ def create_default_pipeline(context, options, result):
AnalyseDeclarationsTransform(), AnalyseDeclarationsTransform(),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(), AnalyseExpressionsTransform(),
CreateClosureClasses(),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
class PostParse(VisitorTransform): class PostParse(VisitorTransform):
""" """
This transform fixes up a few things after parsing This transform fixes up a few things after parsing
...@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform): ...@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform):
class AnalyseExpressionsTransform(VisitorTransform): class AnalyseExpressionsTransform(VisitorTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -185,3 +184,35 @@ class AnalyseExpressionsTransform(VisitorTransform): ...@@ -185,3 +184,35 @@ class AnalyseExpressionsTransform(VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class CreateClosureClasses(VisitorTransform):
# Output closure classes in module scope for all functions
# that need it.
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = temp_name_handle("closure")
func_scope = node.local_scope
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
class_scope = entry.type.scope
for entry in func_scope.entries.values():
class_scope.declare_var(pos=node.pos,
name=entry.name,
cname=entry.cname,
type=entry.type,
is_cdef=True)
def visit_FuncDefNode(self, node):
self.create_class_from_scope(node, self.module_scope)
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
...@@ -1386,7 +1386,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1386,7 +1386,7 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'function', 'property'):
s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s) return p_def_statement(s)
......
...@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor): ...@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor):
self.visitchildren(node) self.visitchildren(node)
self._indents -= 1 self._indents -= 1
def treetypes(root):
"""Returns a string representing the tree by class names.
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class CythonTest(unittest.TestCase): class CythonTest(unittest.TestCase):
def assertLines(self, expected, result): def assertLines(self, expected, result):
...@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase): ...@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase):
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds)
def treetypes(self, root): def treetypes(self, root):
"""Returns a string representing the tree by class names. return treetypes(root)
There's a leading and trailing whitespace so that it can be
compared by simple string comparison while still making test
cases look ok."""
w = NodeTypeWriter()
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment