Commit 29793d0f authored by Tom Niget's avatar Tom Niget

Generic functions works

parent ed9984de
def f[T](x: T):
return x
if __name__ == "__main__":
#a = 5
print(f("abc"))
print(f(6))
import ast import ast
from typing import Iterable from typing import Iterable
from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.types import ConcreteType from transpiler.phases.typing.types import ConcreteType
...@@ -15,6 +16,7 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]: ...@@ -15,6 +16,7 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj> {{" yield f"struct Obj : referencemodel::instance<{name}__oo<>, Obj> {{"
# inner = ClassInnerVisitor4(node.inner_scope) # inner = ClassInnerVisitor4(node.inner_scope)
# for stmt in node.body: # for stmt in node.body:
# yield from inner.visit(stmt) # yield from inner.visit(stmt)
...@@ -27,6 +29,11 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]: ...@@ -27,6 +29,11 @@ def emit_class(name: str, node: ConcreteType) -> Iterable[str]:
yield "};" yield "};"
for name, mdef in node.fields.items():
if isinstance(mdef.val, ast.FunctionDef):
yield from emit_function(name, mdef.type.deref(), "method")
yield "template <typename... T>" yield "template <typename... T>"
yield "auto operator() (T&&... args) const {" yield "auto operator() (T&&... args) const {"
yield "return referencemodel::rc(Obj{std::forward<T>(args)...});" yield "return referencemodel::rc(Obj{std::forward<T>(args)...});"
......
...@@ -6,16 +6,26 @@ from transpiler.phases.emit_cpp.expr import ExpressionVisitor ...@@ -6,16 +6,26 @@ from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.typing.common import IsDeclare from transpiler.phases.typing.common import IsDeclare
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode, join from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode, join
from transpiler.phases.typing.types import CallableInstanceType, BaseType from transpiler.phases.typing.types import CallableInstanceType, BaseType, TypeVariable
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]: def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=None) -> Iterable[str]:
yield f"struct : referencemodel::function {{" yield f"struct : referencemodel::{base} {{"
def emit_body(name: str, mode: CoroutineMode, rty): def emit_body(name: str, mode: CoroutineMode, rty):
#real_params = [p for p in func.generic_parent.parameters if not p.name.startswith("AutoVar$")]
real_params = func.generic_parent.parameters
if real_params:
yield "template<"
yield from join(",", (f"typename {p.name} = void" for p in real_params))
yield ">"
yield "auto" yield "auto"
yield name yield name
yield "(" yield "("
def emit_arg(arg, ty): def emit_arg(arg, ty):
if isinstance(ty, TypeVariable) and ty.emit_as_is:
yield ty.var_name
else:
raise NotImplementedError("can this happen?")
yield "auto" yield "auto"
yield arg.arg yield arg.arg
......
...@@ -4,7 +4,7 @@ from typing import Iterable ...@@ -4,7 +4,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.class_ import emit_class from transpiler.phases.emit_cpp.class_ import emit_class
from transpiler.phases.emit_cpp.function import emit_function from transpiler.phases.emit_cpp.function import emit_function
from transpiler.phases.typing.modules import ModuleType from transpiler.phases.typing.modules import ModuleType
from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable from transpiler.phases.typing.types import CallableInstanceType, ClassTypeType, TypeVariable, BaseType
def emit_module(mod: ModuleType) -> Iterable[str]: def emit_module(mod: ModuleType) -> Iterable[str]:
...@@ -29,13 +29,15 @@ def emit_module(mod: ModuleType) -> Iterable[str]: ...@@ -29,13 +29,15 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
for name, field in mod.fields.items(): for name, field in mod.fields.items():
if not field.in_class_def: if not field.in_class_def:
continue continue
ty = field.type.deref() gen_p = [TypeVariable(p.name, emit_as_is=True) for p in field.type.parameters]
ty = field.type.instantiate(gen_p)
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
x = 5 x = 5
match ty: match ty:
case CallableInstanceType(): case CallableInstanceType():
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters]) parameters_ = [TypeVariable() for _ in ty.parameters]
yield from emit_function(name, ty) ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, parameters_)
yield from emit_function(name, ty, gen_p=gen_p)
case ClassTypeType(inner_type): case ClassTypeType(inner_type):
yield from emit_class(name, inner_type) yield from emit_class(name, inner_type)
case _: case _:
......
...@@ -71,7 +71,10 @@ class NodeVisitor(UniversalVisitor): ...@@ -71,7 +71,10 @@ class NodeVisitor(UniversalVisitor):
yield "typon::TyNone" yield "typon::TyNone"
case types.TY_STR: case types.TY_STR:
yield 'decltype(""_ps)' yield 'decltype(""_ps)'
case types.TypeVariable(name): case types.TypeVariable(name, emit_as_is=em):
if em:
yield name
else:
yield f"$VAR__{name}" yield f"$VAR__{name}"
#raise UnresolvedTypeVariableError(node) #raise UnresolvedTypeVariableError(node)
......
import ast import ast
import copy
import dataclasses import dataclasses
from abc import ABCMeta from abc import ABCMeta
from dataclasses import dataclass, field from dataclasses import dataclass, field
...@@ -145,7 +146,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -145,7 +146,7 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_nongeneric(scope: Scope, output: ResolvedConcreteType): def visit_nongeneric(scope: Scope, output: ResolvedConcreteType):
cl_scope = scope.child(ScopeKind.CLASS) cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type()) cl_scope.declare_local("Self", output.type_type())
output.block_data = BlockData(node, scope) output.block_data = BlockData(copy.deepcopy(node), scope)
visitor = StdlibVisitor(self.python_path, cl_scope, output, self.is_native) visitor = StdlibVisitor(self.python_path, cl_scope, output, self.is_native)
bases = [self.anno().visit(base) for base in node.bases] bases = [self.anno().visit(base) for base in node.bases]
match bases: match bases:
...@@ -169,8 +170,10 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -169,8 +170,10 @@ class StdlibVisitor(NodeVisitorSeq):
scope.function = scope scope.function = scope
scope.obj_type = output scope.obj_type = output
arg_visitor = TypeAnnotationVisitor(scope) arg_visitor = TypeAnnotationVisitor(scope)
output.block_data = BlockData(node, scope) output.block_data = BlockData(copy.deepcopy(node), scope)
output.parameters = [arg_visitor.visit(arg.annotation) for arg in node.args.args] output.parameters = [arg_visitor.visit(arg.annotation) for arg in node.args.args]
for arg, ty in zip(node.args.args, output.parameters):
scope.declare_local(arg.arg, ty)
output.return_type = arg_visitor.visit(node.returns) output.return_type = arg_visitor.visit(node.returns)
output.optional_at = len(node.args.args) - len(node.args.defaults) output.optional_at = len(node.args.args) - len(node.args.defaults)
output.is_variadic = args.vararg is not None output.is_variadic = args.vararg is not None
...@@ -199,7 +202,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -199,7 +202,7 @@ class StdlibVisitor(NodeVisitorSeq):
if i == 0 and self.cur_class is not None: if i == 0 and self.cur_class is not None:
arg_name = "Self" arg_name = "Self"
else: else:
arg_name = f"AutoVar${hash(arg.arg)}" arg_name = f"AutoVar${abs(hash(arg.arg))}"
node.type_params.append(ast.TypeVar(arg_name, None)) # todo: bounds node.type_params.append(ast.TypeVar(arg_name, None)) # todo: bounds
arg.annotation = ast.Name(arg_name, ast.Load()) arg.annotation = ast.Name(arg_name, ast.Load())
else: else:
...@@ -210,7 +213,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -210,7 +213,7 @@ class StdlibVisitor(NodeVisitorSeq):
# annotation is type variable so we keep it # annotation is type variable so we keep it
pass pass
else: else:
arg_name = f"AutoBoundedVar${hash(arg.arg)}" arg_name = f"AutoBoundedVar${abs(hash(arg.arg))}"
node.type_params.append(ast.TypeVar(arg_name, arg.annotation)) node.type_params.append(ast.TypeVar(arg_name, arg.annotation))
arg.annotation = ast.Name(arg_name, ast.Load()) arg.annotation = ast.Name(arg_name, ast.Load())
......
...@@ -118,6 +118,7 @@ class ConcreteType(BaseType): ...@@ -118,6 +118,7 @@ class ConcreteType(BaseType):
class TypeVariable(ConcreteType): class TypeVariable(ConcreteType):
var_name: str = field(default_factory=lambda: next_var_id()) var_name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[ConcreteType] = None resolved: Optional[ConcreteType] = None
emit_as_is: bool = False
def resolve(self) -> ConcreteType: def resolve(self) -> ConcreteType:
if self.resolved is None: if self.resolved is None:
...@@ -557,11 +558,13 @@ class CallableInstanceType(GenericInstanceType, MethodType): ...@@ -557,11 +558,13 @@ class CallableInstanceType(GenericInstanceType, MethodType):
def remove_self(self, self_type): def remove_self(self, self_type):
assert self.parameters[0].try_assign(self_type) assert self.parameters[0].try_assign(self_type)
return dataclasses.replace( res = dataclasses.replace(
self, self,
parameters=self.parameters[1:], parameters=self.parameters[1:],
optional_at=self.optional_at - 1, optional_at=self.optional_at - 1
) )
res.block_data = self.block_data
return res
def __str__(self): def __str__(self):
return f"({", ".join(map(str, self.parameters + (["*args"] if self.is_variadic else [])))}) -> {self.return_type}" return f"({", ".join(map(str, self.parameters + (["*args"] if self.is_variadic else [])))}) -> {self.return_type}"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment