Commit 3da937ba authored by Tom Niget's avatar Tom Niget

Unify function and method parsing code

parent 014ef7ed
...@@ -11,7 +11,7 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER ...@@ -11,7 +11,7 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT
from transpiler.phases.utils import PlainBlock, AnnotationName from transpiler.phases.utils import PlainBlock, AnnotationName
...@@ -140,30 +140,12 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -140,30 +140,12 @@ class ScoperBlockVisitor(ScoperVisitor):
else: else:
raise NotImplementedError(ast.unparse(target)) raise NotImplementedError(ast.unparse(target))
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None:
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
else:
return self.visit_annotation(arg.annotation)
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.annotate_arg(arg) for arg in node.args.args] ftype = self.parse_function(node)
rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK) ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = len(node.args.args) - len(node.args.defaults)
for ty, default in zip(argtypes[ftype.optional_at:], node.args.defaults):
self.expr().visit(default).unify(ty)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, rtype.return_type))
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name) ctype = UserType(node.name)
...@@ -206,13 +188,43 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -206,13 +188,43 @@ class ScoperBlockVisitor(ScoperVisitor):
) )
_, rtype = visitor.visit_FunctionDef(init_method) _, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype) visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
else: else:
raise NotImplementedError(deco) raise NotImplementedError(deco)
for base in node.bases: for base in node.bases:
base = self.expr().visit(base) base = self.expr().visit(base)
if is_builtin(base, "Enum"): if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT)
for k in ctype.members: for k in ctype.members:
ctype.members[k] = ctype ctype.members[k] = ctype
ctype.members["value"] = TY_INT
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), ast.arg(arg="value")],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")],
value=ast.Name(id="value"),
**lnd
)
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
ctype.is_enum = True
else: else:
raise NotImplementedError(base) raise NotImplementedError(base)
...@@ -228,6 +240,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -228,6 +240,8 @@ class ScoperBlockVisitor(ScoperVisitor):
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls) else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
else_visitor.visit_block(node.orelse) else_visitor.visit_block(node.orelse)
node.orelse_scope = else_scope node.orelse_scope = else_scope
if then_scope.diverges and else_scope.diverges:
self.scope.diverges = True
def visit_While(self, node: ast.While): def visit_While(self, node: ast.While):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
...@@ -281,7 +295,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -281,7 +295,8 @@ class ScoperBlockVisitor(ScoperVisitor):
assert isinstance(ftype, FunctionType) assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else TY_NONE vtype = self.expr().visit(node.value) if node.value else TY_NONE
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type) vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
fct.has_return = True self.scope.diverges = True
#fct.has_return = True
def visit_Global(self, node: ast.Global): def visit_Global(self, node: ast.Global):
for name in node.names: for name in node.names:
...@@ -348,6 +363,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -348,6 +363,7 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node.finalbody) raise NotImplementedError(node.finalbody)
def visit_Raise(self, node: ast.Raise): def visit_Raise(self, node: ast.Raise):
self.scope.diverges = True
if node.exc: if node.exc:
self.expr().visit(node.exc) self.expr().visit(node.exc)
if node.cause: if node.cause:
......
...@@ -26,23 +26,11 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -26,23 +26,11 @@ class ScoperClassVisitor(ScoperVisitor):
self.scope.obj_type.members[node.targets[0].id] = valtype self.scope.obj_type.members[node.targets[0].id] = valtype
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
from transpiler.phases.typing.block import ScoperBlockVisitor ftype = self.parse_function(node)
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef ftype.parameters[0].unify(self.scope.obj_type)
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args] inner = ftype.return_type
argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns)
inner_rtype = rtype
if node.name != "__init__": if node.name != "__init__":
rtype = Promise(rtype, PromiseKind.TASK) ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype) ftype.is_method = True
self.scope.obj_type.methods[node.name] = ftype self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION) return (node, inner)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
res = (node, inner_rtype)
self.fdecls.append(res)
return res
...@@ -4,9 +4,10 @@ from typing import Dict, Optional ...@@ -4,9 +4,10 @@ from typing import Dict, Optional
from transpiler.utils import highlight from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature, FunctionType, \
from transpiler.phases.utils import NodeVisitorSeq Promise, PromiseKind
from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
...@@ -28,6 +29,31 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -28,6 +29,31 @@ class ScoperVisitor(NodeVisitorSeq):
assert not isinstance(res, TypeType) assert not isinstance(res, TypeType)
return res return res
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None:
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
else:
return self.visit_annotation(arg.annotation)
def parse_function(self, node: ast.FunctionDef):
argtypes = [self.annotate_arg(arg) for arg in node.args.args]
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = len(node.args.args) - len(node.args.defaults)
for ty, default in zip(argtypes[ftype.optional_at:], node.args.defaults):
self.expr().visit(default).unify(ty)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, rtype))
return ftype
def visit_block(self, block: list[ast.AST]): def visit_block(self, block: list[ast.AST]):
if not block: if not block:
return return
...@@ -69,11 +95,16 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -69,11 +95,16 @@ class ScoperVisitor(NodeVisitorSeq):
elif len(visitor.fdecls) == 1: elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0] fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype) self.visit_function_definition(fnode, frtype)
del node.inner_scope.vars[fnode.name] #del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type) visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls b.decls = decls
if not node.inner_scope.has_return: if not node.inner_scope.diverges and not (isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR):
rtype.unify(TY_NONE) # todo: properly indicate missing return from transpiler.phases.typing.exceptions import TypeMismatchError
try:
rtype.unify(TY_NONE)
except TypeMismatchError as e:
from transpiler.phases.typing.exceptions import MissingReturnError
raise MissingReturnError(node) from e
def get_iter(seq_type): def get_iter(seq_type):
try: try:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment