Commit a6f9a9dc authored by Tom Niget's avatar Tom Niget

Add support for nested functions

parent e54a10fc
......@@ -111,6 +111,7 @@ class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto()
DEFINITION = enum.auto()
METHOD = enum.auto()
LAMBDA = enum.auto()
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items)
......
......@@ -56,12 +56,15 @@ class BlockVisitor(NodeVisitor):
yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
if emission == FunctionEmissionKind.LAMBDA:
yield "[&]"
else:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
yield "("
padded_defaults = [None] * (node.type.optional_at or len(node.args.args)) + node.args.defaults
args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
......@@ -89,6 +92,10 @@ class BlockVisitor(NodeVisitor):
yield ";"
return
if emission == FunctionEmissionKind.LAMBDA:
yield "->"
yield from self.visit(node.type.return_type)
yield "{"
class ReturnVisitor(SearchVisitor):
......
......@@ -113,3 +113,9 @@ class FunctionVisitor(BlockVisitor):
#yield from self.visit(handler)
pass
# todo
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "auto"
yield self.fix_name(node.name)
yield "="
yield from self.visit_func_new(node, FunctionEmissionKind.LAMBDA)
yield ";"
......@@ -5,7 +5,7 @@ from typing import Dict, Optional
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global()
......@@ -16,6 +16,10 @@ class ScoperVisitor(NodeVisitorSeq):
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
cur_class: Optional[TypeType] = None
def expr(self) -> "ScoperExprVisitor":
from transpiler.phases.typing.expr import ScoperExprVisitor
return ScoperExprVisitor(self.scope, self.root_decls)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class)
......@@ -58,7 +62,15 @@ class ScoperVisitor(NodeVisitorSeq):
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.fdecls = []
visitor.visit(b)
if len(visitor.fdecls) > 1:
raise NotImplementedError("?")
elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype)
del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls
if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return
......@@ -77,4 +89,7 @@ def get_next(iter_type):
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
return next_type
\ No newline at end of file
return next_type
def is_builtin(x, feature):
return isinstance(x, BuiltinFeature) and x.val == feature
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment