Commit 39b718cc authored by Tom Niget's avatar Tom Niget

Generic class and method emission

parent 9b8da153
...@@ -4,9 +4,22 @@ def f[T](x: T): ...@@ -4,9 +4,22 @@ def f[T](x: T):
def g[X](a: X, b: X): def g[X](a: X, b: X):
return a + b return a + b
class H:
def h[X](self, x: X):
return x
class Box[T]:
val: T
def __init__(self, val: T):
self.val = val
if __name__ == "__main__": if __name__ == "__main__":
#a = 5
print(f("abc")) print(f("abc"))
print(f(6)) print(f(6))
print(g(6, 8)) print(g(6, 8))
print(g("abc", "def")) print(g("abc", "def"))
# print(g("abc", 213)) # expected error
print(H().h(6))
Box(6)
\ No newline at end of file
...@@ -48,7 +48,7 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -48,7 +48,7 @@ def exception_hook(exc_type, exc_value, tb):
if last_node is not None and last_file is not None: if last_node is not None and last_file is not None:
print() print()
if not hasattr(last_node, "lineno"): if not hasattr(last_node, "lineno"):
print(cf.red("Error: "), cf.white("No line number available")) print(cf.red("Error:"), cf.white("No line number available"))
last_node.lineno = 1 last_node.lineno = 1
print(ast.unparse(last_node)) print(ast.unparse(last_node))
return return
......
...@@ -2,7 +2,9 @@ import ast ...@@ -2,7 +2,9 @@ import ast
from typing import Iterable from typing import Iterable
from transpiler.phases.emit_cpp.function import emit_function from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.types import ConcreteType from transpiler.phases.emit_cpp.visitors import join, NodeVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.types import ConcreteType, TypeVariable, RuntimeValue
def emit_class(name: str, node: ConcreteType) -> Iterable[str]: def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
...@@ -14,8 +16,22 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]: ...@@ -14,8 +16,22 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
# for stmt in node.body: # for stmt in node.body:
# yield from inner.visit(stmt) # yield from inner.visit(stmt)
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj> {{" if node.generic_parent.parameters:
yield "template<"
yield from join(",", (f"typename {p.name}" for p in node.generic_parent.parameters))
yield ">"
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj"
if node.generic_parent.parameters:
yield "<"
yield from join(",", (p.name for p in node.generic_parent.parameters))
yield ">"
yield "> {"
for mname, mdef in node.fields.items():
if isinstance(mdef.val, RuntimeValue):
yield from NodeVisitor().visit_BaseType(mdef.type)
yield mname
yield ";"
# inner = ClassInnerVisitor4(node.inner_scope) # inner = ClassInnerVisitor4(node.inner_scope)
# for stmt in node.body: # for stmt in node.body:
...@@ -30,13 +46,19 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]: ...@@ -30,13 +46,19 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield "};" yield "};"
for name, mdef in node.fields.items(): for mname, mdef in node.fields.items():
if isinstance(mdef.val, ast.FunctionDef): if isinstance(mdef.val, ast.FunctionDef):
yield from emit_function(name, mdef.type.deref(), "method") gen_p = [TypeVariable(p.name, emit_as_is=True) for p in mdef.type.parameters]
ty = mdef.type.instantiate(gen_p)
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(mname, ty, "method")
yield "template <typename... T>" yield "template <typename... T>"
yield "auto operator() (T&&... args) const {" yield "auto operator() (T&&... args) const {"
yield "return referencemodel::rc(Obj{std::forward<T>(args)...});" yield "Obj obj;"
yield "dot(obj, __init__)(std::forward<T>(args)...);"
#yield "return referencemodel::rc(Obj(std::forward<T>(args)...));"
yield "return referencemodel::rc(obj);"
yield "}" yield "}"
yield f"}};" yield f"}};"
......
...@@ -10,18 +10,19 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType, TypeV ...@@ -10,18 +10,19 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType, TypeV
def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=None) -> Iterable[str]: def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=None) -> Iterable[str]:
__TB_NODE__ = func.block_data.node
yield f"struct : referencemodel::{base} {{" yield f"struct : referencemodel::{base} {{"
def emit_body(name: str, mode: CoroutineMode, rty): def emit_body(name: str, mode: CoroutineMode, rty):
#real_params = [p for p in func.generic_parent.parameters if not p.name.startswith("AutoVar$")] #real_params = [p for p in func.generic_parent.parameters if not p.name.startswith("AutoVar$")]
real_params = func.generic_parent.parameters if func.generic_parent.parameters:
if real_params:
yield "template<" yield "template<"
yield from join(",", (f"typename {p.name} = void" for p in real_params)) yield from join(",", (f"typename {p.name} = void" for p in func.generic_parent.parameters))
yield ">" yield ">"
yield "auto" yield "auto"
yield name yield name
yield "(" yield "("
def emit_arg(arg, ty): def emit_arg(arg, ty):
__TB_NODE__ = arg
if isinstance(ty, TypeVariable) and ty.emit_as_is: if isinstance(ty, TypeVariable) and ty.emit_as_is:
yield ty.var_name yield ty.var_name
else: else:
......
...@@ -4,10 +4,12 @@ from typing import Iterable ...@@ -4,10 +4,12 @@ from typing import Iterable
from transpiler.phases.emit_cpp.class_ import emit_class from transpiler.phases.emit_cpp.class_ import emit_class
from transpiler.phases.emit_cpp.function import emit_function from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.modules import ModuleType from transpiler.phases.typing.modules import ModuleType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType, GenericType, \
GenericInstanceType, UserGenericType
def emit_module(mod: ModuleType) -> Iterable[str]: def emit_module(mod: ModuleType) -> Iterable[str]:
__TB_NODE__ = mod.block_data.node
yield "#include <python/builtins.hpp>" yield "#include <python/builtins.hpp>"
yield "#include <python/sys.hpp>" yield "#include <python/sys.hpp>"
incl_vars = [] incl_vars = []
...@@ -29,17 +31,22 @@ def emit_module(mod: ModuleType) -> Iterable[str]: ...@@ -29,17 +31,22 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
for name, field in mod.fields.items(): for name, field in mod.fields.items():
if not field.in_class_def: if not field.in_class_def:
continue continue
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in field.type.parameters]
ty = field.type.instantiate(gen_p) if isinstance(field.type, ClassTypeType):
ty = field.type.inner_type
else:
ty = field.type
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in ty.parameters]
ty = ty.instantiate(gen_p)
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
x = 5 x = 5
match ty: match ty:
case CallableInstanceType(): case CallableInstanceType():
parameters_ = [TypeVariable() for _ in ty.parameters] ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, parameters_)
yield from emit_function(name, ty, gen_p=gen_p) yield from emit_function(name, ty, gen_p=gen_p)
case ClassTypeType(inner_type): case GenericInstanceType() if isinstance(ty.generic_parent, UserGenericType):
yield from emit_class(name, inner_type) yield from emit_class(name, ty)
case _: case _:
raise NotImplementedError(f"Unsupported module item type {ty}") raise NotImplementedError(f"Unsupported module item type {ty}")
......
...@@ -8,7 +8,7 @@ from typing import Optional, List, Dict, Callable ...@@ -8,7 +8,7 @@ from typing import Optional, List, Dict, Callable
from logging import debug from logging import debug
from transpiler.phases.typing.modules import parse_module from transpiler.phases.typing.modules import parse_module
from transpiler.utils import highlight from transpiler.utils import highlight, linenodata
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE, is_builtin from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
...@@ -129,11 +129,12 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -129,11 +129,12 @@ class StdlibVisitor(NodeVisitorSeq):
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, mod) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, mod)
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
force_generic = not self.is_native
if existing := self.scope.get(node.name): if existing := self.scope.get(node.name):
assert isinstance(existing.type, ClassTypeType) assert isinstance(existing.type, ClassTypeType)
NewType = existing.type.inner_type NewType = existing.type.inner_type
else: else:
if node.type_params: if node.type_params or force_generic:
base_class, base_type = create_builtin_generic_type, (BuiltinGenericType if self.is_native else UserGenericType) base_class, base_type = create_builtin_generic_type, (BuiltinGenericType if self.is_native else UserGenericType)
else: else:
base_class, base_type = create_builtin_type, (BuiltinType if self.is_native else UserType) base_class, base_type = create_builtin_type, (BuiltinType if self.is_native else UserType)
...@@ -158,8 +159,26 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -158,8 +159,26 @@ class StdlibVisitor(NodeVisitorSeq):
raise NotImplementedError("parents not handled yet: " + ", ".join(map(ast.unparse, node.bases))) raise NotImplementedError("parents not handled yet: " + ", ".join(map(ast.unparse, node.bases)))
for stmt in node.body: for stmt in node.body:
visitor.visit(stmt) visitor.visit(stmt)
if "__init__" not in output.fields:
visit_generic_item(visit_nongeneric, node, NewType, self.scope) visitor.visit(ast.FunctionDef(
name="__init__",
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg="self", annotation=None)],
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[]
),
body=[ast.Pass()],
decorator_list=[],
returns=None,
type_params=[],
**linenodata(node)
))
visit_generic_item(visit_nongeneric, node, NewType, self.scope, force_generic=force_generic)
def visit_Pass(self, node: ast.Pass): def visit_Pass(self, node: ast.Pass):
pass pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment