Commit 56e68e46 authored by Tom Niget's avatar Tom Niget

Analyze functions in two passes (signatures, then bodies) to allow for forward use

parent ee666ed5
...@@ -38,8 +38,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -38,8 +38,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Module(self, node: ast.Module): def visit_Module(self, node: ast.Module):
for stmt in node.body: self.visit_block(node.body)
self.visit(stmt)
def get_type(self, node: ast.expr) -> BaseType: def get_type(self, node: ast.expr) -> BaseType:
if type := getattr(node, "type", None): if type := getattr(node, "type", None):
...@@ -98,13 +97,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -98,13 +97,7 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype.optional_at = 1 + len(node.args.args) - len(node.args.defaults) ftype.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body: self.fdecls.append((node, rtype.return_type))
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.return_type.unify(TY_NONE)
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name) ctype = UserType(node.name)
...@@ -115,8 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -115,8 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor):
node.inner_scope = scope node.inner_scope = scope
node.type = ctype node.type = ctype
visitor = ScoperClassVisitor(scope) visitor = ScoperClassVisitor(scope)
for b in node.body: visitor.visit_block(node.body)
visitor.visit(b)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
...@@ -124,13 +116,11 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -124,13 +116,11 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.test) self.expr().visit(node.test)
then_scope = scope.child(ScopeKind.FUNCTION_INNER) then_scope = scope.child(ScopeKind.FUNCTION_INNER)
then_visitor = ScoperBlockVisitor(then_scope, self.root_decls) then_visitor = ScoperBlockVisitor(then_scope, self.root_decls)
for b in node.body: then_visitor.visit_block(node.body)
then_visitor.visit(b)
if node.orelse: if node.orelse:
else_scope = scope.child(ScopeKind.FUNCTION_INNER) else_scope = scope.child(ScopeKind.FUNCTION_INNER)
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls) else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
for b in node.orelse: else_visitor.visit_block(node.orelse.body)
else_visitor.visit(b)
def visit_While(self, node: ast.While): def visit_While(self, node: ast.While):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
...@@ -138,8 +128,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -138,8 +128,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.test) self.expr().visit(node.test)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body: body_visitor.visit_block(node.body)
body_visitor.visit(b)
if node.orelse: if node.orelse:
raise NotImplementedError(node.orelse) raise NotImplementedError(node.orelse)
...@@ -151,8 +140,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -151,8 +140,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.iter) self.expr().visit(node.iter)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body: body_visitor.visit_block(node.body)
body_visitor.visit(b)
if node.orelse: if node.orelse:
raise NotImplementedError(node.orelse) raise NotImplementedError(node.orelse)
...@@ -183,6 +171,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -183,6 +171,8 @@ class ScoperBlockVisitor(ScoperVisitor):
if isinstance(node, ast.AST): if isinstance(node, ast.AST):
super().visit(node) super().visit(node)
node.scope = self.scope node.scope = self.scope
else:
raise NotImplementedError(node)
def visit_Break(self, node: ast.Break): def visit_Break(self, node: ast.Break):
pass # TODO: check in loop pass # TODO: check in loop
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType
@dataclass @dataclass
class ScoperClassVisitor(ScoperVisitor): class ScoperClassVisitor(ScoperVisitor):
fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list)
def visit_AnnAssign(self, node: ast.AnnAssign): def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value" assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)" assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
...@@ -20,6 +23,9 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -20,6 +23,9 @@ class ScoperClassVisitor(ScoperVisitor):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args] argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
argtypes[0].unify(self.scope.obj_type) # self parameter argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns) rtype = self.visit_annotation(node.returns)
inner_rtype = rtype
if node.name != "__init__":
rtype = Promise(rtype, PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype) ftype = FunctionType(argtypes, rtype)
self.scope.obj_type.methods[node.name] = ftype self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION) scope = self.scope.child(ScopeKind.FUNCTION)
...@@ -29,10 +35,4 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -29,10 +35,4 @@ class ScoperClassVisitor(ScoperVisitor):
node.type = ftype node.type = ftype
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body: self.fdecls.append((node, inner_rtype))
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.unify(TY_NONE)
...@@ -4,7 +4,7 @@ from typing import Dict, Optional ...@@ -4,7 +4,7 @@ from typing import Dict, Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
...@@ -18,4 +18,18 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -18,4 +18,18 @@ class ScoperVisitor(NodeVisitorSeq):
return TypeAnnotationVisitor(self.scope) return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType: def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable() return self.anno().visit(expr) if expr else TypeVariable()
\ No newline at end of file
def visit_block(self, block: list[ast.AST]):
from transpiler.phases.typing.block import ScoperBlockVisitor
self.fdecls = []
for b in block:
self.visit(b)
for node, rtype in self.fdecls:
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.visit(b)
b.decls = decls
if not node.inner_scope.has_return:
rtype.unify(TY_NONE)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment