Commit 8acf1377 authored by Tom Niget's avatar Tom Niget

Fix recursion and add proper discrimination of sync/async functions

parent 72920c17
......@@ -52,16 +52,10 @@ concept PyNext = requires(T t) {
struct {
template <PyNext T>
std::optional<typename T::value_type>
sync(T &t, std::optional<typename T::value_type> def = std::nullopt) {
operator()(T &t, std::optional<typename T::value_type> def = std::nullopt) {
auto opt = t.py_next();
return opt ? opt : def;
}
template <PyNext T>
auto operator()(T &t, std::optional<typename T::value_type> def = std::nullopt)
-> typon::Task<decltype(sync(t, def))> {
co_return sync(t, def);
}
} next;
template <typename T>
......@@ -69,7 +63,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
return opt ? os << opt.value() : os << "None";
}
typon::Task<bool> is_cpp() { co_return true; }
bool is_cpp() { return true; }
class NoneType {
public:
......
......@@ -50,19 +50,14 @@ typon::Task<void> print(T const &head, Args const &...args) {
}*/
struct {
void sync() { std::cout << '\n'; }
void operator()() { std::cout << '\n'; }
template <Printable T, Printable... Args>
void sync(T const &head, Args const &...args) {
void operator()(T const &head, Args const &...args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
template<Printable... Args>
typon::Task<void> operator()(Args const &...args) {
co_return sync(args...);
}
} print;
//typon::Task<void> print() { std::cout << '\n'; co_return; }
#endif // TYPON_PRINT_HPP
......@@ -28,6 +28,8 @@ def run_tests():
]
if alt := environ.get("ALT_RUNNER"):
commands.append(alt.format(name_bin=name_bin, name_cpp_posix=name_cpp.as_posix()))
else:
print("no ALT_RUNNER")
for cmd in commands:
if system(cmd) != 0:
print(f"Error running command: {cmd}")
......
......@@ -7,17 +7,17 @@ def fibo(n: int) -> int:
b = fibo(n - 2)
return a + b
def parallel_fibo(n: int) -> int:
if n < 2:
return n
if n < 25:
a = fibo(n - 1)
b = fibo(n - 2)
return a + b
x = fork(lambda: fibo(n - 1))
y = fork(lambda: fibo(n - 2))
sync()
return x.get() + y.get()
# def parallel_fibo(n: int) -> int:
# if n < 2:
# return n
# if n < 25:
# a = fibo(n - 1)
# b = fibo(n - 2)
# return a + b
# x = fork(lambda: fibo(n - 1))
# y = fork(lambda: fibo(n - 2))
# sync()
# return x.get() + y.get()
if __name__ == "__main__":
......
......@@ -120,7 +120,7 @@ class BlockVisitor(NodeVisitor):
yield "Join"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = node.scope
inner_scope = node.inner_scope
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
......
......@@ -75,13 +75,14 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id)
if False and (decl := self.scope.get(res)):
if decl.kind == VarKind.SELF:
res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator:
res = f"{res}.get()"
if decl.future == "future":
res = "co_await " + res
if self.scope.function and (decl := self.scope.get(res)) and decl.type is self.scope.function.obj_type:
res = "(*this)"
#if decl.kind == VarKind.SELF:
# res = "(*this)"
#elif decl.future and CoroutineMode.ASYNC in self.generator:
# res = f"{res}.get()"
# if decl.future == "future":
# res = "co_await " + res
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
......@@ -125,7 +126,7 @@ class ExpressionVisitor(NodeVisitor):
yield from ()
return
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator:
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "co_await "
node.in_await = True
elif CoroutineMode.FAKE in self.generator:
......
......@@ -28,7 +28,7 @@ class FunctionVisitor(BlockVisitor):
yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.scope, node.body)
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
......@@ -36,13 +36,13 @@ class FunctionVisitor(BlockVisitor):
yield "if ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.scope, node.body)
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
yield "else "
if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse)
else:
yield from self.emit_block(node.orelse.scope, node.orelse)
yield from self.emit_block(node.orelse.inner_scope, node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
if CoroutineMode.ASYNC in self.generator:
......@@ -57,7 +57,7 @@ class FunctionVisitor(BlockVisitor):
yield "while ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.scope, node.body)
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
......
......@@ -6,7 +6,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise
@dataclass
......@@ -74,13 +75,13 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = self.visit_annotation(node.returns)
rtype = Promise(self.visit_annotation(node.returns))
ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.scope = scope
node.inner_scope = scope
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
......@@ -92,7 +93,7 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
scope.function = self.scope.function
node.scope = scope
node.inner_scope = scope
visitor = ScoperBlockVisitor(scope, self.root_decls)
for b in node.body:
visitor.visit(b)
......@@ -107,7 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype = fct.obj_type
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else None
vtype.unify(ftype.return_type)
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
def visit_Global(self, node: ast.Global):
for name in node.names:
......
......@@ -4,7 +4,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise
DUNDER = {
ast.Eq: "eq",
......@@ -73,7 +73,12 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
return self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
if isinstance(rtype, Promise):
node.is_await = True
return rtype.return_type
node.is_await = False
return rtype
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if not isinstance(ftype, FunctionType):
......@@ -93,7 +98,8 @@ class ScoperExprVisitor(ScoperVisitor):
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.scope = scope
node.inner_scope = scope
node.body.scope = scope
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
decls = {}
......
......@@ -42,6 +42,16 @@ class BaseType(ABC):
def to_list(self) -> List["BaseType"]:
return [self]
class MagicType(BaseType):
def unify_internal(self, other: "BaseType"):
if type(self) is not type(other):
raise IncompatibleTypesError()
def contains_internal(self, other: "BaseType") -> bool:
return False
cur_var = 0
......@@ -68,7 +78,7 @@ class TypeVariable(BaseType):
self.resolved = other
def contains_internal(self, other: BaseType) -> bool:
return self is other
return self.resolve() is other.resolve()
def gen_sub(self, this: "BaseType", typevars) -> "Self":
if match := typevars.get(self.name):
......@@ -134,6 +144,10 @@ class FunctionType(TypeOperator):
def return_type(self):
return self.args[0]
@return_type.setter
def return_type(self, value):
self.args[0] = value
def __str__(self):
ret, *args = map(str, self.args)
if self.variadic:
......@@ -214,3 +228,11 @@ class ForkResult(TypeOperator):
@property
def return_type(self):
return self.args[0]
class Promise(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "Promise")
@property
def return_type(self):
return self.args[0]
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment