Commit 067451b0 authored by Tom Niget's avatar Tom Niget

Add initial support for reference-counted user-defined classes

parent 2c5aca4d
# coding: utf-8
class Person:
name: str
age: int
def __init__(self, name: str, age: int):
self.name = name
self.age = age
def afficher(self):
print(self.name, self.age)
def creer():
return Person("jean", 123)
if __name__ == "__main__":
x = creer()
print(x.name)
print(x.age)
x.afficher()
...@@ -6,7 +6,7 @@ from typing import Iterable ...@@ -6,7 +6,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType
from transpiler.utils import UnsupportedNodeError from transpiler.utils import UnsupportedNodeError
class UniversalVisitor: class UniversalVisitor:
...@@ -55,6 +55,10 @@ class NodeVisitor(UniversalVisitor): ...@@ -55,6 +55,10 @@ class NodeVisitor(UniversalVisitor):
yield "bool" yield "bool"
elif node is TY_NONE: elif node is TY_NONE:
yield "void" yield "void"
elif node is TY_STR:
yield "std::string"
elif isinstance(node, UserType):
yield f"std::shared_ptr<decltype({node.name})::type>"
elif isinstance(node, Promise): elif isinstance(node, Promise):
yield "typon::" yield "typon::"
if node.kind == PromiseKind.TASK: if node.kind == PromiseKind.TASK:
......
...@@ -22,6 +22,9 @@ class BlockVisitor(NodeVisitor): ...@@ -22,6 +22,9 @@ class BlockVisitor(NodeVisitor):
def expr(self) -> ExpressionVisitor: def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator) return ExpressionVisitor(self.scope, self.generator)
def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
if getattr(node, "is_main", False): if getattr(node, "is_main", False):
# Special case handling for Python's interesting way of defining an entry point. # Special case handling for Python's interesting way of defining an entry point.
...@@ -83,11 +86,14 @@ class BlockVisitor(NodeVisitor): ...@@ -83,11 +86,14 @@ class BlockVisitor(NodeVisitor):
yield "}" yield "}"
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef) -> Iterable[str]: def visit_func_new(self, node: ast.FunctionDef, skip_first_arg: bool = False) -> Iterable[str]:
yield from self.visit(node.type.return_type) yield from self.visit(node.type.return_type)
yield "operator()" yield "operator()"
yield "(" yield "("
for i, (arg, argty) in enumerate(zip(node.args.args, node.type.parameters)): args_iter = zip(node.args.args, node.type.parameters)
if skip_first_arg:
next(args_iter)
for i, (arg, argty) in enumerate(args_iter):
if i != 0: if i != 0:
yield ", " yield ", "
yield from self.visit(argty) yield from self.visit(argty)
...@@ -241,6 +247,8 @@ class BlockVisitor(NodeVisitor): ...@@ -241,6 +247,8 @@ class BlockVisitor(NodeVisitor):
yield name yield name
elif isinstance(lvalue, ast.Subscript): elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue) yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else: else:
raise NotImplementedError(lvalue) raise NotImplementedError(lvalue)
......
# coding: utf-8
import ast
from typing import Iterable
from dataclasses import dataclass
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import NodeVisitor
class ClassVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield "struct {"
yield "struct type {"
inner = ClassInnerVisitor(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield "template<typename... T> type(T&&... args) {"
yield "__init__(std::forward<T>(args)...);"
yield "}"
yield "type(const type&) = delete;"
yield "type(type&&) = delete;"
yield "};"
yield "template<typename... T> auto operator()(T&&... args) {"
yield "return std::make_shared<type>(std::forward<T>(args)...);"
yield "}"
outer = ClassOuterVisitor(node.inner_scope)
for stmt in node.body:
yield from outer.visit(stmt)
yield f"}} {node.name};"
@dataclass
class ClassInnerVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.members[node.target.id]
yield from self.visit(member)
yield node.target.id
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "type* self;"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, True)
yield f"}} {node.name} {{ this }};"
@dataclass
class ClassOuterVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "template<typename... T>"
yield "auto operator()(type& self, T&&... args) {"
yield f"return self.{node.name}(std::forward<T>(args)...);"
yield "}"
yield f"}} {node.name};"
...@@ -3,6 +3,7 @@ import ast ...@@ -3,6 +3,7 @@ import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Iterable from typing import List, Iterable
from transpiler.phases.typing.types import UserType
from transpiler.utils import compare_ast from transpiler.utils import compare_ast
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
...@@ -166,6 +167,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -166,6 +167,9 @@ class ExpressionVisitor(NodeVisitor):
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]: def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.prec(".").visit(node.value) yield from self.prec(".").visit(node.value)
if isinstance(node.value.type, UserType):
yield "->"
else:
yield "." yield "."
yield node.attr yield node.attr
......
...@@ -4,6 +4,7 @@ from typing import Iterable ...@@ -4,6 +4,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp import CoroutineMode from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast from transpiler.utils import compare_ast
...@@ -37,3 +38,6 @@ class ModuleVisitor(BlockVisitor): ...@@ -37,3 +38,6 @@ class ModuleVisitor(BlockVisitor):
yield f"//{node.value.s}" yield f"//{node.value.s}"
else: else:
raise NotImplementedError(node) raise NotImplementedError(node)
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield from ClassVisitor().visit(node)
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise, TY_NONE, PromiseKind, TupleType Promise, TY_NONE, PromiseKind, TupleType, UserType
@dataclass @dataclass
...@@ -17,6 +16,9 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -17,6 +16,9 @@ class ScoperBlockVisitor(ScoperVisitor):
def expr(self) -> ScoperExprVisitor: def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope, self.root_decls) return ScoperExprVisitor(self.scope, self.root_decls)
def visit_Pass(self, node: ast.Pass):
pass
def visit_Import(self, node: ast.Import): def visit_Import(self, node: ast.Import):
for alias in node.names: for alias in node.names:
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, None) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, None)
...@@ -51,11 +53,12 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -51,11 +53,12 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node) raise NotImplementedError(node)
target = node.targets[0] target = node.targets[0]
ty = self.get_type(node.value) ty = self.get_type(node.value)
target.type = ty
node.is_declare = self.visit_assign_target(target, ty) node.is_declare = self.visit_assign_target(target, ty)
target.type.unify(ty)
def visit_assign_target(self, target, decl_val: BaseType) -> bool: def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
target.type = decl_val
if vdecl := self.scope.get(target.id): if vdecl := self.scope.get(target.id):
vdecl.type.unify(decl_val) vdecl.type.unify(decl_val)
return False return False
...@@ -68,15 +71,13 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -68,15 +71,13 @@ class ScoperBlockVisitor(ScoperVisitor):
if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)): if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)):
raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}") raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}")
return any(self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args)) return any(self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args))
elif isinstance(target, ast.Attribute):
attr_type = self.expr().visit(target)
attr_type.unify(decl_val)
return False
else: else:
raise NotImplementedError(target) raise NotImplementedError(target)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args] argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK) rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK)
...@@ -97,6 +98,18 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -97,6 +98,18 @@ class ScoperBlockVisitor(ScoperVisitor):
if not scope.has_return: if not scope.has_return:
rtype.return_type.unify(TY_NONE) rtype.return_type.unify(TY_NONE)
def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ctype)
scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = ctype
scope.class_ = scope
node.inner_scope = scope
node.type = ctype
visitor = ScoperClassVisitor(scope)
for b in node.body:
visitor.visit(b)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope node.inner_scope = scope
......
# coding: utf-8
import ast
from dataclasses import dataclass
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor
@dataclass
class ScoperClassVisitor(ScoperVisitor):
def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name)
self.scope.obj_type.members[node.target.id] = self.visit_annotation(node.annotation)
def visit_FunctionDef(self, node: ast.FunctionDef):
from transpiler.phases.typing.block import ScoperBlockVisitor
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.unify(TY_NONE)
import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict from typing import Dict, Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
...@@ -10,3 +13,9 @@ PRELUDE = Scope.make_global() ...@@ -10,3 +13,9 @@ PRELUDE = Scope.make_global()
class ScoperVisitor(NodeVisitorSeq): class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL)) scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
root_decls: Dict[str, VarDecl] = field(default_factory=dict) root_decls: Dict[str, VarDecl] = field(default_factory=dict)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
\ No newline at end of file
...@@ -6,7 +6,7 @@ from typing import List ...@@ -6,7 +6,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType
DUNDER = { DUNDER = {
ast.Eq: "eq", ast.Eq: "eq",
...@@ -105,6 +105,10 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -105,6 +105,10 @@ class ScoperExprVisitor(ScoperVisitor):
return actual return actual
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]): def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if isinstance(ftype, UserType):
init: FunctionType = self.visit_getattr(ftype, "__init__")
ctor = FunctionType(init.args[1:], ftype)
return self.visit_function_call(ctor, arguments)
if not isinstance(ftype, FunctionType): if not isinstance(ftype, FunctionType):
raise IncompatibleTypesError(f"Cannot call {ftype}") raise IncompatibleTypesError(f"Cannot call {ftype}")
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list()) #is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
......
...@@ -53,6 +53,7 @@ class Scope: ...@@ -53,6 +53,7 @@ class Scope:
children: List["Scope"] = field(default_factory=list) children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None obj_type: Optional[BaseType] = None
has_return: bool = False has_return: bool = False
class_: Optional["Scope"] = None
@staticmethod @staticmethod
def make_global(): def make_global():
......
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from itertools import zip_longest from itertools import zip_longest
from typing import Dict, Optional, List, ClassVar, Callable, Any from typing import Dict, Optional, List, ClassVar, Callable
class IncompatibleTypesError(Exception): class IncompatibleTypesError(Exception):
...@@ -215,7 +215,6 @@ class TypeOperator(BaseType, ABC): ...@@ -215,7 +215,6 @@ class TypeOperator(BaseType, ABC):
return [self, *self.args] return [self, *self.args]
class FunctionType(TypeOperator): class FunctionType(TypeOperator):
def __init__(self, args: List[BaseType], ret: BaseType): def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args]) super().__init__([ret, *args])
...@@ -374,3 +373,12 @@ class Future(Promise): ...@@ -374,3 +373,12 @@ class Future(Promise):
def __init__(self, ret: BaseType): def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FUTURE) super().__init__(ret, PromiseKind.FUTURE)
class UserType(TypeOperator):
def __init__(self, name: str):
super().__init__([], name=name)
def unify_internal(self, other: "BaseType"):
if type(self) != type(other):
raise IncompatibleTypesError()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment