Commit c9cb938b authored by Tom Niget's avatar Tom Niget

Fork Sync works

parent 19d31953
...@@ -399,7 +399,7 @@ concept HasSync = requires(T t) { typename T::has_sync; }; ...@@ -399,7 +399,7 @@ concept HasSync = requires(T t) { typename T::has_sync; };
auto call_sync(auto f) { auto call_sync(auto f) {
if constexpr (HasSync<decltype(f)>) { if constexpr (HasSync<decltype(f)>) {
return [f](auto... args) { return [f](auto... args) {
return f.sync(std::forward<decltype(args)>(args)...); return f.typon$$sync(std::forward<decltype(args)>(args)...);
}; };
} else { } else {
return f; return f;
......
from typon import fork, sync
def fibo(n): def fibo(n):
if n < 2: if n < 2:
...@@ -8,46 +7,47 @@ def fibo(n): ...@@ -8,46 +7,47 @@ def fibo(n):
sync() sync()
return a.get() + b.get() return a.get() + b.get()
"""
def fibo(n: int) -> int:
if n < 2:
return n
with sync(): # {
a = fork(lambda: fibo(n - 1))
b = fork(lambda: fibo(n - 2))
# }
return a + b
"""
"""
Task<int> fibo(int n) {
if (n < 2) {
return n;
}
Forked<int> a;
Forked<int> b;
{
a = fork(fibo(n - 1));
// cvcvc
b = fork(fibo(n - 2));
co_await sync();
}
co_return a.get() + b.get();
"""
""" # """
Task<int> fibo(int n) { # def fibo(n: int) -> int:
int a, b; # if n < 2:
co_return []() -> Join<int> { # return n
if (n < 2) { # with sync(): # {
return n; # a = fork(lambda: fibo(n - 1))
} # b = fork(lambda: fibo(n - 2))
co_await fork(fibo(n - 1), a); # # }
co_await fork(fibo(n - 2), b); # return a + b
co_await Sync(); # """
co_return a + b; # """
}(); # Task<int> fibo(int n) {
} # if (n < 2) {
""" # return n;
# }
# Forked<int> a;
# Forked<int> b;
# {
# a = fork(fibo(n - 1));
# // cvcvc
# b = fork(fibo(n - 2));
# co_await sync();
# }
# co_return a.get() + b.get();
# """
#
# """
# Task<int> fibo(int n) {
# int a, b;
# co_return []() -> Join<int> {
# if (n < 2) {
# return n;
# }
# co_await fork(fibo(n - 1), a);
# co_await fork(fibo(n - 2), b);
# co_await Sync();
# co_return a + b;
# }();
# }
# """
......
...@@ -5,7 +5,7 @@ from typing import Iterable ...@@ -5,7 +5,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import ClassTypeType, TupleInstanceType, TY_FUTURE, ResolvedConcreteType from transpiler.phases.typing.types import ClassTypeType, TupleInstanceType, TY_FUTURE, ResolvedConcreteType, TY_FORKED
from transpiler.phases.utils import make_lnd from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata from transpiler.utils import linenodata
...@@ -138,31 +138,46 @@ class ExpressionVisitor(NodeVisitor): ...@@ -138,31 +138,46 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(arg.body) yield from self.visit(arg.body)
return return
if isinstance(node.func, ast.Name) and node.func.id == "sync":
if self.generator != CoroutineMode.SYNC:
yield "co_await typon::Sync()"
else:
yield "(void)0"
return
is_get = isinstance(node.func, ast.Attribute) and node.func.attr == "get"
# async : co_await f(args) # async : co_await f(args)
# sync : call_sync(f, args) # sync : call_sync(f, args)
if self.generator != CoroutineMode.SYNC: if self.generator != CoroutineMode.SYNC:
nty = node.type.resolve() nty = node.type.resolve()
if not (isinstance(nty, ResolvedConcreteType) and nty.inherits(TY_FUTURE)): if isinstance(nty, ResolvedConcreteType) and (
yield "co_await" nty.inherits(TY_FUTURE) or (
is_get and nty.inherits(TY_FORKED)
)
):
pass
else: else:
yield "call_sync" yield "co_await"
if isinstance(node.func, ast.Attribute) and node.func.attr == "get" and node.func.value.type.inherits(TY_FUTURE):
yield "("
if self.generator == CoroutineMode.SYNC:
yield from self.visit(node.func.value)
else: else:
yield "(" if is_get and node.func.value.type.inherits(TY_FUTURE, TY_FORKED):
yield from self.visit(node.func.value) yield from self.visit(node.func.value)
yield ").get()"
yield ")"
return return
yield "call_sync"
yield "(" yield "("
if is_get and node.func.value.type.inherits(TY_FUTURE, TY_FORKED):
yield "("
yield from self.visit(node.func.value)
yield ").get"
else:
yield from self.visit(node.func) yield from self.visit(node.func)
yield ")(" yield ")("
yield from join(", ", map(self.visit, node.args)) yield from join(", ", map(self.visit, node.args))
yield ")" yield ")"
......
...@@ -59,13 +59,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p= ...@@ -59,13 +59,13 @@ def emit_function(name: str, func: CallableInstanceType, base="function", gen_p=
try: try:
rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type)) rty_code = " ".join(NodeVisitor().visit_BaseType(func.return_type))
except: except:
yield from emit_body("sync", CoroutineMode.SYNC, None) yield from emit_body("typon$$sync", CoroutineMode.SYNC, None)
has_sync = True has_sync = True
yield "using has_sync = std::true_type;" yield "using has_sync = std::true_type;"
def task_type(): def task_type():
yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent) yield from NodeVisitor().visit_BaseType(func.return_type.generic_parent)
yield "<" yield "<"
yield"decltype(sync(" yield"decltype(typon$$sync("
yield from join(",", (arg.arg for arg in func.block_data.node.args.args)) yield from join(",", (arg.arg for arg in func.block_data.node.args.args))
yield "))" yield "))"
yield ">" yield ">"
......
...@@ -197,8 +197,8 @@ class ResolvedConcreteType(ConcreteType): ...@@ -197,8 +197,8 @@ class ResolvedConcreteType(ConcreteType):
return [self] + merge(*[p.get_mro() for p in self.parents], self.parents) return [self] + merge(*[p.get_mro() for p in self.parents], self.parents)
def inherits(self, parent: BaseType): def inherits(self, *parent: BaseType):
return self == parent or any(p.inherits(parent) for p in self.parents) return self in parent or any(p.inherits(*parent) for p in self.parents)
def try_assign_internal(self, other: BaseType) -> bool: def try_assign_internal(self, other: BaseType) -> bool:
if self == other: if self == other:
...@@ -264,8 +264,8 @@ class GenericInstanceType(ResolvedConcreteType): ...@@ -264,8 +264,8 @@ class GenericInstanceType(ResolvedConcreteType):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def inherits(self, parent: BaseType): def inherits(self, *parent: BaseType):
return self.generic_parent == parent or super().inherits(parent) return self.generic_parent in parent or super().inherits(*parent)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, GenericInstanceType): if isinstance(other, GenericInstanceType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment