Commit a6f9a9dc authored by Tom Niget's avatar Tom Niget

Add support for nested functions

parent e54a10fc
...@@ -111,6 +111,7 @@ class FunctionEmissionKind(enum.Enum): ...@@ -111,6 +111,7 @@ class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto() DECLARATION = enum.auto()
DEFINITION = enum.auto() DEFINITION = enum.auto()
METHOD = enum.auto() METHOD = enum.auto()
LAMBDA = enum.auto()
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items) items = iter(items)
......
...@@ -56,6 +56,9 @@ class BlockVisitor(NodeVisitor): ...@@ -56,6 +56,9 @@ class BlockVisitor(NodeVisitor):
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]: def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.LAMBDA:
yield "[&]"
else:
if emission == FunctionEmissionKind.METHOD: if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>" yield "template <typename Self>"
yield from self.visit(node.type.return_type) yield from self.visit(node.type.return_type)
...@@ -89,6 +92,10 @@ class BlockVisitor(NodeVisitor): ...@@ -89,6 +92,10 @@ class BlockVisitor(NodeVisitor):
yield ";" yield ";"
return return
if emission == FunctionEmissionKind.LAMBDA:
yield "->"
yield from self.visit(node.type.return_type)
yield "{" yield "{"
class ReturnVisitor(SearchVisitor): class ReturnVisitor(SearchVisitor):
......
...@@ -113,3 +113,9 @@ class FunctionVisitor(BlockVisitor): ...@@ -113,3 +113,9 @@ class FunctionVisitor(BlockVisitor):
#yield from self.visit(handler) #yield from self.visit(handler)
pass pass
# todo # todo
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "auto"
yield self.fix_name(node.name)
yield "="
yield from self.visit_func_new(node, FunctionEmissionKind.LAMBDA)
yield ";"
...@@ -5,7 +5,7 @@ from typing import Dict, Optional ...@@ -5,7 +5,7 @@ from typing import Dict, Optional
from transpiler.utils import highlight from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
...@@ -16,6 +16,10 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -16,6 +16,10 @@ class ScoperVisitor(NodeVisitorSeq):
root_decls: Dict[str, VarDecl] = field(default_factory=dict) root_decls: Dict[str, VarDecl] = field(default_factory=dict)
cur_class: Optional[TypeType] = None cur_class: Optional[TypeType] = None
def expr(self) -> "ScoperExprVisitor":
from transpiler.phases.typing.expr import ScoperExprVisitor
return ScoperExprVisitor(self.scope, self.root_decls)
def anno(self) -> "TypeAnnotationVisitor": def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class) return TypeAnnotationVisitor(self.scope, self.cur_class)
...@@ -58,7 +62,15 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -58,7 +62,15 @@ class ScoperVisitor(NodeVisitorSeq):
for b in node.body: for b in node.body:
decls = {} decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls) visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.fdecls = []
visitor.visit(b) visitor.visit(b)
if len(visitor.fdecls) > 1:
raise NotImplementedError("?")
elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype)
del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls b.decls = decls
if not node.inner_scope.has_return: if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return rtype.unify(TY_NONE) # todo: properly indicate missing return
...@@ -78,3 +90,6 @@ def get_next(iter_type): ...@@ -78,3 +90,6 @@ def get_next(iter_type):
from transpiler.phases.typing.exceptions import NotIteratorError from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type) raise NotIteratorError(iter_type)
return next_type return next_type
def is_builtin(x, feature):
return isinstance(x, BuiltinFeature) and x.val == feature
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment