Commit 28ff809a authored by Tom Niget's avatar Tom Niget

Add initial support for reference-counted user-defined classes

parent e4d6f647
# coding: utf-8
class Person:
name: str
age: int
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def afficher(self):
print(self.name, self.age)
def creer():
return Person("jean", 123)
if __name__ == "__main__":
x = creer()
print(x.name)
print(x.age)
x.afficher()
......@@ -6,7 +6,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType
from transpiler.utils import UnsupportedNodeError
class UniversalVisitor:
......@@ -55,6 +55,10 @@ class NodeVisitor(UniversalVisitor):
yield "bool"
elif node is TY_NONE:
yield "void"
elif node is TY_STR:
yield "std::string"
elif isinstance(node, UserType):
yield f"std::shared_ptr<decltype({node.name})::type>"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
......
......@@ -22,6 +22,9 @@ class BlockVisitor(NodeVisitor):
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator)
def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
if getattr(node, "is_main", False):
# Special case handling for Python's interesting way of defining an entry point.
......@@ -83,11 +86,14 @@ class BlockVisitor(NodeVisitor):
yield "}"
yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef) -> Iterable[str]:
def visit_func_new(self, node: ast.FunctionDef, skip_first_arg: bool = False) -> Iterable[str]:
yield from self.visit(node.type.return_type)
yield "operator()"
yield "("
for i, (arg, argty) in enumerate(zip(node.args.args, node.type.parameters)):
args_iter = zip(node.args.args, node.type.parameters)
if skip_first_arg:
next(args_iter)
for i, (arg, argty) in enumerate(args_iter):
if i != 0:
yield ", "
yield from self.visit(argty)
......@@ -241,6 +247,8 @@ class BlockVisitor(NodeVisitor):
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
......
# coding: utf-8
import ast
from typing import Iterable
from dataclasses import dataclass
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import NodeVisitor
class ClassVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield "struct {"
yield "struct type {"
inner = ClassInnerVisitor(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield "template<typename... T> type(T&&... args) {"
yield "__init__(std::forward<T>(args)...);"
yield "}"
yield "type(const type&) = delete;"
yield "type(type&&) = delete;"
yield "};"
yield "template<typename... T> auto operator()(T&&... args) {"
yield "return std::make_shared<type>(std::forward<T>(args)...);"
yield "}"
outer = ClassOuterVisitor(node.inner_scope)
for stmt in node.body:
yield from outer.visit(stmt)
yield f"}} {node.name};"
@dataclass
class ClassInnerVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.members[node.target.id]
yield from self.visit(member)
yield node.target.id
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "type* self;"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, True)
yield f"}} {node.name} {{ this }};"
@dataclass
class ClassOuterVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "template<typename... T>"
yield "auto operator()(type& self, T&&... args) {"
yield f"return self.{node.name}(std::forward<T>(args)...);"
yield "}"
yield f"}} {node.name};"
......@@ -3,6 +3,7 @@ import ast
from dataclasses import dataclass, field
from typing import List, Iterable
from transpiler.phases.typing.types import UserType
from transpiler.utils import compare_ast
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
......@@ -166,7 +167,10 @@ class ExpressionVisitor(NodeVisitor):
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.prec(".").visit(node.value)
yield "."
if isinstance(node.value.type, UserType):
yield "->"
else:
yield "."
yield node.attr
def visit_List(self, node: ast.List) -> Iterable[str]:
......
......@@ -4,6 +4,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast
......@@ -37,3 +38,6 @@ class ModuleVisitor(BlockVisitor):
yield f"//{node.value.s}"
else:
raise NotImplementedError(node)
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield from ClassVisitor().visit(node)
import ast
from dataclasses import dataclass
from typing import Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise, TY_NONE, PromiseKind, TupleType
Promise, TY_NONE, PromiseKind, TupleType, UserType
@dataclass
......@@ -17,6 +16,9 @@ class ScoperBlockVisitor(ScoperVisitor):
def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope, self.root_decls)
def visit_Pass(self, node: ast.Pass):
pass
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, None)
......@@ -51,11 +53,12 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node)
target = node.targets[0]
ty = self.get_type(node.value)
target.type = ty
node.is_declare = self.visit_assign_target(target, ty)
target.type.unify(ty)
def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name):
target.type = decl_val
if vdecl := self.scope.get(target.id):
vdecl.type.unify(decl_val)
return False
......@@ -68,15 +71,13 @@ class ScoperBlockVisitor(ScoperVisitor):
if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)):
raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}")
return any(self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args))
elif isinstance(target, ast.Attribute):
attr_type = self.expr().visit(target)
attr_type.unify(decl_val)
return False
else:
raise NotImplementedError(target)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK)
......@@ -97,6 +98,18 @@ class ScoperBlockVisitor(ScoperVisitor):
if not scope.has_return:
rtype.return_type.unify(TY_NONE)
def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ctype)
scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = ctype
scope.class_ = scope
node.inner_scope = scope
node.type = ctype
visitor = ScoperClassVisitor(scope)
for b in node.body:
visitor.visit(b)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
......
# coding: utf-8
import ast
from dataclasses import dataclass
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor
@dataclass
class ScoperClassVisitor(ScoperVisitor):
def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name)
self.scope.obj_type.members[node.target.id] = self.visit_annotation(node.annotation)
def visit_FunctionDef(self, node: ast.FunctionDef):
from transpiler.phases.typing.block import ScoperBlockVisitor
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.unify(TY_NONE)
import ast
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global()
......@@ -9,4 +12,10 @@ PRELUDE = Scope.make_global()
@dataclass
class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
\ No newline at end of file
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
\ No newline at end of file
......@@ -6,7 +6,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType
DUNDER = {
ast.Eq: "eq",
......@@ -105,6 +105,10 @@ class ScoperExprVisitor(ScoperVisitor):
return actual
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if isinstance(ftype, UserType):
init: FunctionType = self.visit_getattr(ftype, "__init__")
ctor = FunctionType(init.args[1:], ftype)
return self.visit_function_call(ctor, arguments)
if not isinstance(ftype, FunctionType):
raise IncompatibleTypesError(f"Cannot call {ftype}")
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
......
......@@ -53,6 +53,7 @@ class Scope:
children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None
has_return: bool = False
class_: Optional["Scope"] = None
@staticmethod
def make_global():
......
......@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from itertools import zip_longest
from typing import Dict, Optional, List, ClassVar, Callable, Any
from typing import Dict, Optional, List, ClassVar, Callable
class IncompatibleTypesError(Exception):
......@@ -215,7 +215,6 @@ class TypeOperator(BaseType, ABC):
return [self, *self.args]
class FunctionType(TypeOperator):
def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args])
......@@ -374,3 +373,12 @@ class Future(Promise):
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FUTURE)
class UserType(TypeOperator):
def __init__(self, name: str):
super().__init__([], name=name)
def unify_internal(self, other: "BaseType"):
if type(self) != type(other):
raise IncompatibleTypesError()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment