Commit 39b718cc authored by Tom Niget's avatar Tom Niget

Generic class and method emission

parent 9b8da153
......@@ -4,9 +4,22 @@ def f[T](x: T):
def g[X](a: X, b: X):
return a + b
class H:
def h[X](self, x: X):
return x
class Box[T]:
val: T
def __init__(self, val: T):
self.val = val
if __name__ == "__main__":
#a = 5
print(f("abc"))
print(f(6))
print(g(6, 8))
print(g("abc", "def"))
# print(g("abc", 213)) # expected error
print(H().h(6))
Box(6)
\ No newline at end of file
......@@ -48,7 +48,7 @@ def exception_hook(exc_type, exc_value, tb):
if last_node is not None and last_file is not None:
print()
if not hasattr(last_node, "lineno"):
print(cf.red("Error: "), cf.white("No line number available"))
print(cf.red("Error:"), cf.white("No line number available"))
last_node.lineno = 1
print(ast.unparse(last_node))
return
......
......@@ -2,7 +2,9 @@ import ast
from typing import Iterable
from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.types import ConcreteType
from transpiler.phases.emit_cpp.visitors import join, NodeVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.types import ConcreteType, TypeVariable, RuntimeValue
def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
......@@ -14,8 +16,22 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
# for stmt in node.body:
# yield from inner.visit(stmt)
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj> {{"
if node.generic_parent.parameters:
yield "template<"
yield from join(",", (f"typename {p.name}" for p in node.generic_parent.parameters))
yield ">"
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj"
if node.generic_parent.parameters:
yield "<"
yield from join(",", (p.name for p in node.generic_parent.parameters))
yield ">"
yield "> {"
for mname, mdef in node.fields.items():
if isinstance(mdef.val, RuntimeValue):
yield from NodeVisitor().visit_BaseType(mdef.type)
yield mname
yield ";"
# inner = ClassInnerVisitor4(node.inner_scope)
# for stmt in node.body:
......@@ -30,13 +46,19 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield "};"
for name, mdef in node.fields.items():
for mname, mdef in node.fields.items():
if isinstance(mdef.val, ast.FunctionDef):
yield from emit_function(name, mdef.type.deref(), "method")
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in mdef.type.parameters]
ty = mdef.type.instantiate(gen_p)
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(mname, ty, "method")
yield "template <typename... T>"
yield "auto operator() (T&&... args) const {"
yield "return referencemodel::rc(Obj{std::forward<T>(args)...});"
yield "Obj obj;"
yield "dot(obj, __init__)(std::forward<T>(args)...);"
#yield "return referencemodel::rc(Obj(std::forward<T>(args)...));"
yield "return referencemodel::rc(obj);"
yield "}"
yield f"}};"
......
......@@ -10,18 +10,19 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType, TypeV
def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=None) -> Iterable[str]:
__TB_NODE__ = func.block_data.node
yield f"struct : referencemodel::{base} {{"
def emit_body(name: str, mode: CoroutineMode, rty):
#real_params = [p for p in func.generic_parent.parameters if not p.name.startswith("AutoVar$")]
real_params = func.generic_parent.parameters
if real_params:
if func.generic_parent.parameters:
yield "template<"
yield from join(",", (f"typename {p.name} = void" for p in real_params))
yield from join(",", (f"typename {p.name} = void" for p in func.generic_parent.parameters))
yield ">"
yield "auto"
yield name
yield "("
def emit_arg(arg, ty):
__TB_NODE__ = arg
if isinstance(ty, TypeVariable) and ty.emit_as_is:
yield ty.var_name
else:
......
......@@ -4,10 +4,12 @@ from typing import Iterable
from transpiler.phases.emit_cpp.class_ import emit_class
from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.modules import ModuleType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType, GenericType, \
GenericInstanceType, UserGenericType
def emit_module(mod: ModuleType) -> Iterable[str]:
__TB_NODE__ = mod.block_data.node
yield "#include <python/builtins.hpp>"
yield "#include <python/sys.hpp>"
incl_vars = []
......@@ -29,17 +31,22 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
for name, field in mod.fields.items():
if not field.in_class_def:
continue
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in field.type.parameters]
ty = field.type.instantiate(gen_p)
if isinstance(field.type, ClassTypeType):
ty = field.type.inner_type
else:
ty = field.type
gen_p = [TypeVariable(p.name, emit_as_is=True) for p in ty.parameters]
ty = ty.instantiate(gen_p)
from transpiler.phases.typing.expr import ScoperExprVisitor
x = 5
match ty:
case CallableInstanceType():
parameters_ = [TypeVariable() for _ in ty.parameters]
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, parameters_)
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(name, ty, gen_p=gen_p)
case ClassTypeType(inner_type):
yield from emit_class(name, inner_type)
case GenericInstanceType() if isinstance(ty.generic_parent, UserGenericType):
yield from emit_class(name, ty)
case _:
raise NotImplementedError(f"Unsupported module item type {ty}")
......
......@@ -8,7 +8,7 @@ from typing import Optional, List, Dict, Callable
from logging import debug
from transpiler.phases.typing.modules import parse_module
from transpiler.utils import highlight
from transpiler.utils import highlight, linenodata
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor
......@@ -129,11 +129,12 @@ class StdlibVisitor(NodeVisitorSeq):
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, mod)
def visit_ClassDef(self, node: ast.ClassDef):
force_generic = not self.is_native
if existing := self.scope.get(node.name):
assert isinstance(existing.type, ClassTypeType)
NewType = existing.type.inner_type
else:
if node.type_params:
if node.type_params or force_generic:
base_class, base_type = create_builtin_generic_type, (BuiltinGenericType if self.is_native else UserGenericType)
else:
base_class, base_type = create_builtin_type, (BuiltinType if self.is_native else UserType)
......@@ -158,8 +159,26 @@ class StdlibVisitor(NodeVisitorSeq):
raise NotImplementedError("parents not handled yet: " + ", ".join(map(ast.unparse, node.bases)))
for stmt in node.body:
visitor.visit(stmt)
visit_generic_item(visit_nongeneric, node, NewType, self.scope)
if "__init__" not in output.fields:
visitor.visit(ast.FunctionDef(
name="__init__",
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg="self", annotation=None)],
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[]
),
body=[ast.Pass()],
decorator_list=[],
returns=None,
type_params=[],
**linenodata(node)
))
visit_generic_item(visit_nongeneric, node, NewType, self.scope, force_generic=force_generic)
def visit_Pass(self, node: ast.Pass):
pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment