Commit 6e8e1738 authored by Tom Niget's avatar Tom Niget

Handle recursive functions

parent 39b718cc
...@@ -386,9 +386,7 @@ using InterpGuard = py::scoped_interpreter; ...@@ -386,9 +386,7 @@ using InterpGuard = py::scoped_interpreter;
#endif #endif
template <typename T> template <typename T>
concept HasSync = requires(T t) { concept HasSync = requires(T t) { typename T::has_sync; };
{ t.sync() } -> std::same_as<T>;
};
/*auto call_sync(auto f, auto... args) { /*auto call_sync(auto f, auto... args) {
if constexpr (HasSync<decltype(f)>) { if constexpr (HasSync<decltype(f)>) {
......
...@@ -43,7 +43,9 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p= ...@@ -43,7 +43,9 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield ")>::type" yield ")>::type"
yield var yield var
yield ";" yield ";"
yield from BlockVisitor(func.block_data.scope, generator=mode).visit(func.block_data.node.body) vis = BlockVisitor(func.block_data.scope, generator=mode)
for stmt in func.block_data.node.body:
yield from vis.visit(stmt)
if not getattr(func.block_data.scope, "has_return", False): if not getattr(func.block_data.scope, "has_return", False):
if mode == CoroutineMode.SYNC: if mode == CoroutineMode.SYNC:
yield "return" yield "return"
...@@ -53,10 +55,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p= ...@@ -53,10 +55,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield "}" yield "}"
rty = func.return_type.generic_args[0] rty = func.return_type.generic_args[0]
has_sync = False
try: try:
rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type)) rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type))
except: except:
yield from emit_body("sync", CoroutineMode.SYNC, None) yield from emit_body("sync", CoroutineMode.SYNC, None)
has_sync = True
yield "using has_sync = std::true_type;"
def task_type(): def task_type():
yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent) yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent)
yield "<" yield "<"
...@@ -70,6 +75,8 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p= ...@@ -70,6 +75,8 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
yield from emit_body("operator()", CoroutineMode.TASK, rty_code) yield from emit_body("operator()", CoroutineMode.TASK, rty_code)
yield f"}} static constexpr {name} {{}};" yield f"}} static constexpr {name} {{}};"
if has_sync:
yield f"static_assert(HasSync<decltype({name})>);"
yield f"static_assert(sizeof {name} == 1);" yield f"static_assert(sizeof {name} == 1);"
......
...@@ -43,6 +43,7 @@ def emit_module(mod: ModuleType) -> Iterable[str]: ...@@ -43,6 +43,7 @@ def emit_module(mod: ModuleType) -> Iterable[str]:
x = 5 x = 5
match ty: match ty:
case CallableInstanceType(): case CallableInstanceType():
ty.generic_parent.instance_cache = []
ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters]) ScoperExprVisitor(ty.block_data.scope).visit_function_call(ty, [TypeVariable() for _ in ty.parameters])
yield from emit_function(name, ty, gen_p=gen_p) yield from emit_function(name, ty, gen_p=gen_p)
case GenericInstanceType() if isinstance(ty.generic_parent, UserGenericType): case GenericInstanceType() if isinstance(ty.generic_parent, UserGenericType):
......
...@@ -168,16 +168,20 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -168,16 +168,20 @@ class ScoperExprVisitor(ScoperVisitor):
ftype.block_data.scope.declare_local(pname, b) ftype.block_data.scope.declare_local(pname, b)
if not ftype.is_native: if not ftype.is_native:
from transpiler.phases.typing.block import ScoperBlockVisitor existing = ftype.generic_parent.find_cached_instance(ftype.generic_args)
scope = ftype.block_data.scope if not existing:
vis = ScoperBlockVisitor(scope) ftype.generic_parent.cache_instance(ftype.generic_args, ftype)
for stmt in ftype.block_data.node.body: from transpiler.phases.typing.block import ScoperBlockVisitor
vis.visit(stmt) scope = ftype.block_data.scope
if not getattr(scope.function, "has_return", False): vis = ScoperBlockVisitor(scope)
stmt = ast.Return() for stmt in ftype.block_data.node.body:
ftype.block_data.node.body.append(stmt) vis.visit(stmt)
vis.visit(stmt) if not getattr(scope.function, "has_return", False):
#ftype.generic_parent.cache_instance(ftype) stmt = ast.Return()
ftype.block_data.node.body.append(stmt)
vis.visit(stmt)
else:
return existing.return_type.resolve()
return ftype.return_type.resolve() return ftype.return_type.resolve()
# if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType): # if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
# init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self() # init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self()
......
...@@ -320,7 +320,7 @@ class GenericConstraint: ...@@ -320,7 +320,7 @@ class GenericConstraint:
@dataclass(eq=False) @dataclass(eq=False)
class GenericType(BaseType): class GenericType(BaseType):
parameters: list[GenericParameter] = field(default_factory=list, init=False) parameters: list[GenericParameter] = field(default_factory=list, init=False)
instance_cache: dict[object, GenericInstanceType] = field(default_factory=dict, init=False) instance_cache: list[(object, GenericInstanceType)] = field(default_factory=list, init=False)
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]: def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return [] return []
...@@ -348,11 +348,17 @@ class GenericType(BaseType): ...@@ -348,11 +348,17 @@ class GenericType(BaseType):
def deref(self): def deref(self):
return self.instantiate_default().deref() return self.instantiate_default().deref()
def cache_instance(self, instance): def find_cached_instance(self, args):
if not hasattr(self, "instance_cache"): for inst_args, inst in self.instance_cache:
self.instance_cache = {} if all(inst_arg.try_assign(arg) for inst_arg, arg in zip(inst_args, args)):
self.instance_cache[tuple(instance.generic_args)] = instance return inst
return None
def cache_instance(self, args, instance):
if not hasattr(self, "instance_cache"):
self.instance_cache = []
if not self.find_cached_instance(args):
self.instance_cache.append((tuple(args), instance))
@dataclass(eq=False, init=False) @dataclass(eq=False, init=False)
class BuiltinGenericType(UniqueTypeMixin, GenericType): class BuiltinGenericType(UniqueTypeMixin, GenericType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment