Commit 8acf1377 authored by Tom Niget's avatar Tom Niget

Fix recursion and add proper discrimination of sync/async functions

parent 72920c17
...@@ -52,16 +52,10 @@ concept PyNext = requires(T t) { ...@@ -52,16 +52,10 @@ concept PyNext = requires(T t) {
struct { struct {
template <PyNext T> template <PyNext T>
std::optional<typename T::value_type> std::optional<typename T::value_type>
sync(T &t, std::optional<typename T::value_type> def = std::nullopt) { operator()(T &t, std::optional<typename T::value_type> def = std::nullopt) {
auto opt = t.py_next(); auto opt = t.py_next();
return opt ? opt : def; return opt ? opt : def;
} }
template <PyNext T>
auto operator()(T &t, std::optional<typename T::value_type> def = std::nullopt)
-> typon::Task<decltype(sync(t, def))> {
co_return sync(t, def);
}
} next; } next;
template <typename T> template <typename T>
...@@ -69,7 +63,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) { ...@@ -69,7 +63,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
return opt ? os << opt.value() : os << "None"; return opt ? os << opt.value() : os << "None";
} }
typon::Task<bool> is_cpp() { co_return true; } bool is_cpp() { return true; }
class NoneType { class NoneType {
public: public:
......
...@@ -50,19 +50,14 @@ typon::Task<void> print(T const &head, Args const &...args) { ...@@ -50,19 +50,14 @@ typon::Task<void> print(T const &head, Args const &...args) {
}*/ }*/
struct { struct {
void sync() { std::cout << '\n'; } void operator()() { std::cout << '\n'; }
template <Printable T, Printable... Args> template <Printable T, Printable... Args>
void sync(T const &head, Args const &...args) { void operator()(T const &head, Args const &...args) {
print_to(head, std::cout); print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...); (((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n'; std::cout << '\n';
} }
template<Printable... Args>
typon::Task<void> operator()(Args const &...args) {
co_return sync(args...);
}
} print; } print;
//typon::Task<void> print() { std::cout << '\n'; co_return; } //typon::Task<void> print() { std::cout << '\n'; co_return; }
#endif // TYPON_PRINT_HPP #endif // TYPON_PRINT_HPP
...@@ -28,6 +28,8 @@ def run_tests(): ...@@ -28,6 +28,8 @@ def run_tests():
] ]
if alt := environ.get("ALT_RUNNER"): if alt := environ.get("ALT_RUNNER"):
commands.append(alt.format(name_bin=name_bin, name_cpp_posix=name_cpp.as_posix())) commands.append(alt.format(name_bin=name_bin, name_cpp_posix=name_cpp.as_posix()))
else:
print("no ALT_RUNNER")
for cmd in commands: for cmd in commands:
if system(cmd) != 0: if system(cmd) != 0:
print(f"Error running command: {cmd}") print(f"Error running command: {cmd}")
......
...@@ -7,17 +7,17 @@ def fibo(n: int) -> int: ...@@ -7,17 +7,17 @@ def fibo(n: int) -> int:
b = fibo(n - 2) b = fibo(n - 2)
return a + b return a + b
def parallel_fibo(n: int) -> int: # def parallel_fibo(n: int) -> int:
if n < 2: # if n < 2:
return n # return n
if n < 25: # if n < 25:
a = fibo(n - 1) # a = fibo(n - 1)
b = fibo(n - 2) # b = fibo(n - 2)
return a + b # return a + b
x = fork(lambda: fibo(n - 1)) # x = fork(lambda: fibo(n - 1))
y = fork(lambda: fibo(n - 2)) # y = fork(lambda: fibo(n - 2))
sync() # sync()
return x.get() + y.get() # return x.get() + y.get()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -120,7 +120,7 @@ class BlockVisitor(NodeVisitor): ...@@ -120,7 +120,7 @@ class BlockVisitor(NodeVisitor):
yield "Join" yield "Join"
yield f"<decltype(sync({', '.join(names)}))>" yield f"<decltype(sync({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = node.scope inner_scope = node.inner_scope
for child in node.body: for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes # Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
......
...@@ -75,13 +75,14 @@ class ExpressionVisitor(NodeVisitor): ...@@ -75,13 +75,14 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]: def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id) res = self.fix_name(node.id)
if False and (decl := self.scope.get(res)): if self.scope.function and (decl := self.scope.get(res)) and decl.type is self.scope.function.obj_type:
if decl.kind == VarKind.SELF:
res = "(*this)" res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator: #if decl.kind == VarKind.SELF:
res = f"{res}.get()" # res = "(*this)"
if decl.future == "future": #elif decl.future and CoroutineMode.ASYNC in self.generator:
res = "co_await " + res # res = f"{res}.get()"
# if decl.future == "future":
# res = "co_await " + res
yield res yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
...@@ -125,7 +126,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -125,7 +126,7 @@ class ExpressionVisitor(NodeVisitor):
yield from () yield from ()
return return
# TODO: precedence needed? # TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator: if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "co_await " yield "co_await "
node.in_await = True node.in_await = True
elif CoroutineMode.FAKE in self.generator: elif CoroutineMode.FAKE in self.generator:
......
...@@ -28,7 +28,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -28,7 +28,7 @@ class FunctionVisitor(BlockVisitor):
yield f"for (auto {node.target.id} : " yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter) yield from self.expr().visit(node.iter)
yield ")" yield ")"
yield from self.emit_block(node.scope, node.body) yield from self.emit_block(node.inner_scope, node.body)
if node.orelse: if node.orelse:
raise NotImplementedError(node, "orelse") raise NotImplementedError(node, "orelse")
...@@ -36,13 +36,13 @@ class FunctionVisitor(BlockVisitor): ...@@ -36,13 +36,13 @@ class FunctionVisitor(BlockVisitor):
yield "if (" yield "if ("
yield from self.expr().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.scope, node.body) yield from self.emit_block(node.inner_scope, node.body)
if node.orelse: if node.orelse:
yield "else " yield "else "
if isinstance(node.orelse, ast.If): if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
else: else:
yield from self.emit_block(node.orelse.scope, node.orelse) yield from self.emit_block(node.orelse.inner_scope, node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]: def visit_Return(self, node: ast.Return) -> Iterable[str]:
if CoroutineMode.ASYNC in self.generator: if CoroutineMode.ASYNC in self.generator:
...@@ -57,7 +57,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -57,7 +57,7 @@ class FunctionVisitor(BlockVisitor):
yield "while (" yield "while ("
yield from self.expr().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.scope, node.body) yield from self.emit_block(node.inner_scope, node.body)
if node.orelse: if node.orelse:
raise NotImplementedError(node, "orelse") raise NotImplementedError(node, "orelse")
......
...@@ -6,7 +6,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor ...@@ -6,7 +6,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise
@dataclass @dataclass
...@@ -74,13 +75,13 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -74,13 +75,13 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args] argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = self.visit_annotation(node.returns) rtype = Promise(self.visit_annotation(node.returns))
ftype = FunctionType(argtypes, rtype) ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION) scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype scope.obj_type = ftype
scope.function = scope scope.function = scope
node.scope = scope node.inner_scope = scope
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body: for b in node.body:
...@@ -92,7 +93,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -92,7 +93,7 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
scope.function = self.scope.function scope.function = self.scope.function
node.scope = scope node.inner_scope = scope
visitor = ScoperBlockVisitor(scope, self.root_decls) visitor = ScoperBlockVisitor(scope, self.root_decls)
for b in node.body: for b in node.body:
visitor.visit(b) visitor.visit(b)
...@@ -107,7 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -107,7 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype = fct.obj_type ftype = fct.obj_type
assert isinstance(ftype, FunctionType) assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else None vtype = self.expr().visit(node.value) if node.value else None
vtype.unify(ftype.return_type) vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
def visit_Global(self, node: ast.Global): def visit_Global(self, node: ast.Global):
for name in node.names: for name in node.names:
......
...@@ -4,7 +4,7 @@ from typing import List ...@@ -4,7 +4,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise
DUNDER = { DUNDER = {
ast.Eq: "eq", ast.Eq: "eq",
...@@ -73,7 +73,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -73,7 +73,12 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func) ftype = self.visit(node.func)
return self.visit_function_call(ftype, [self.visit(arg) for arg in node.args]) rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
if isinstance(rtype, Promise):
node.is_await = True
return rtype.return_type
node.is_await = False
return rtype
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]): def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if not isinstance(ftype, FunctionType): if not isinstance(ftype, FunctionType):
...@@ -93,7 +98,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -93,7 +98,8 @@ class ScoperExprVisitor(ScoperVisitor):
scope = self.scope.child(ScopeKind.FUNCTION) scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype scope.obj_type = ftype
scope.function = scope scope.function = scope
node.scope = scope node.inner_scope = scope
node.body.scope = scope
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
decls = {} decls = {}
......
...@@ -42,6 +42,16 @@ class BaseType(ABC): ...@@ -42,6 +42,16 @@ class BaseType(ABC):
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
return [self] return [self]
class MagicType(BaseType):
def unify_internal(self, other: "BaseType"):
if type(self) is not type(other):
raise IncompatibleTypesError()
def contains_internal(self, other: "BaseType") -> bool:
return False
cur_var = 0 cur_var = 0
...@@ -68,7 +78,7 @@ class TypeVariable(BaseType): ...@@ -68,7 +78,7 @@ class TypeVariable(BaseType):
self.resolved = other self.resolved = other
def contains_internal(self, other: BaseType) -> bool: def contains_internal(self, other: BaseType) -> bool:
return self is other return self.resolve() is other.resolve()
def gen_sub(self, this: "BaseType", typevars) -> "Self": def gen_sub(self, this: "BaseType", typevars) -> "Self":
if match := typevars.get(self.name): if match := typevars.get(self.name):
...@@ -134,6 +144,10 @@ class FunctionType(TypeOperator): ...@@ -134,6 +144,10 @@ class FunctionType(TypeOperator):
def return_type(self): def return_type(self):
return self.args[0] return self.args[0]
@return_type.setter
def return_type(self, value):
self.args[0] = value
def __str__(self): def __str__(self):
ret, *args = map(str, self.args) ret, *args = map(str, self.args)
if self.variadic: if self.variadic:
...@@ -214,3 +228,11 @@ class ForkResult(TypeOperator): ...@@ -214,3 +228,11 @@ class ForkResult(TypeOperator):
@property @property
def return_type(self): def return_type(self):
return self.args[0] return self.args[0]
class Promise(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "Promise")
@property
def return_type(self):
return self.args[0]
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment