Commit e3cc4548 authored by Tom Niget's avatar Tom Niget

Add basic coroutine+await support

parent 45014685
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <ostream> #include <ostream>
#include <string> #include <string>
#include <typon/typon.hpp>
#ifdef __cpp_lib_unreachable #ifdef __cpp_lib_unreachable
#include <utility> #include <utility>
[[noreturn]] inline void TYPON_UNREACHABLE() { std::unreachable(); } [[noreturn]] inline void TYPON_UNREACHABLE() { std::unreachable(); }
...@@ -58,7 +60,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) { ...@@ -58,7 +60,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
return opt ? os << opt.value() : os << "None"; return opt ? os << opt.value() : os << "None";
} }
bool is_cpp() { return true; } typon::Task<bool> is_cpp() { co_return true; }
class NoneType { class NoneType {
public: public:
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <iostream> #include <iostream>
#include <ostream> #include <ostream>
#include <typon/typon.hpp>
template <typename T> template <typename T>
concept Streamable = requires(const T &x, std::ostream &s) { concept Streamable = requires(const T &x, std::ostream &s) {
{ s << x } -> std::same_as<std::ostream &>; { s << x } -> std::same_as<std::ostream &>;
...@@ -39,13 +41,28 @@ template <typename T> ...@@ -39,13 +41,28 @@ template <typename T>
concept Printable = requires(const T &x, std::ostream &s) { concept Printable = requires(const T &x, std::ostream &s) {
{ print_to(x, s) } -> std::same_as<void>; { print_to(x, s) } -> std::same_as<void>;
}; };
/*
template <Printable T, Printable... Args> template <Printable T, Printable... Args>
void print(T const &head, Args const &...args) { typon::Task<void> print(T const &head, Args const &...args) {
print_to(head, std::cout); print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...); (((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n'; std::cout << '\n'; co_return;
} }*/
struct {
void sync() { std::cout << '\n'; }
template <Printable T, Printable... Args>
void sync(T const &head, Args const &...args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
void print() { std::cout << '\n'; } template<Printable... Args>
typon::Task<void> operator()(Args const &...args) {
co_return sync(args...);
}
} print;
//typon::Task<void> print() { std::cout << '\n'; co_return; }
#endif // TYPON_PRINT_HPP #endif // TYPON_PRINT_HPP
from typon import fork, sync from typon import fork, sync
def fibo(n):
if n < 2:
return n
a = fibo(n - 1)
b = fibo(n - 2)
return a + b
#def fibo(n: int) -> int: #def fibo(n: int) -> int:
# if n < 2: # if n < 2:
...@@ -11,4 +18,4 @@ from typon import fork, sync ...@@ -11,4 +18,4 @@ from typon import fork, sync
if __name__ == "__main__": if __name__ == "__main__":
print("res=", 5, ".") print(fibo(30)) # should display 832040
\ No newline at end of file \ No newline at end of file
...@@ -41,7 +41,6 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: ...@@ -41,7 +41,6 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def transpile(source): def transpile(source):
tree = ast.parse(source) tree = ast.parse(source)
# print(ast.unparse(tree))
return "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(tree)))) return "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(tree))))
...@@ -74,7 +73,7 @@ SYMBOLS = { ...@@ -74,7 +73,7 @@ SYMBOLS = {
PRECEDENCE = [ PRECEDENCE = [
("()", "[]", ".",), ("()", "[]", ".",),
("unary",), ("unary", "co_await"),
("*", "/", "%",), ("*", "/", "%",),
("+", "-"), ("+", "-"),
("<<", ">>"), ("<<", ">>"),
...@@ -85,7 +84,7 @@ PRECEDENCE = [ ...@@ -85,7 +84,7 @@ PRECEDENCE = [
("|",), ("|",),
("&&",), ("&&",),
("||",), ("||",),
("?:",), ("?:", "co_yield"),
(",",) (",",)
] ]
"""Precedence of C++ operators.""" """Precedence of C++ operators."""
...@@ -106,6 +105,7 @@ class VarKind(Enum): ...@@ -106,6 +105,7 @@ class VarKind(Enum):
LOCAL = 1 LOCAL = 1
GLOBAL = 2 GLOBAL = 2
NONLOCAL = 3 NONLOCAL = 3
SELF = 4
@dataclass @dataclass
...@@ -119,15 +119,19 @@ class UnsupportedNodeError(Exception): ...@@ -119,15 +119,19 @@ class UnsupportedNodeError(Exception):
class NodeVisitor: class NodeVisitor:
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
if type(node) in SYMBOLS: if type(node) == list:
yield SYMBOLS[type(node)] for n in node:
yield from self.visit(n)
else: else:
for parent in node.__class__.__mro__: for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None): if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node) yield from visitor(node)
break break
else: else:
raise UnsupportedNodeError(node) return self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str, str): def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"): for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
...@@ -149,6 +153,21 @@ class NodeVisitor: ...@@ -149,6 +153,21 @@ class NodeVisitor:
return MAPPINGS.get(name, name) return MAPPINGS.get(name, name)
class SearchVisitor(NodeVisitor):
def missing_impl(self, node):
if not hasattr(node, "__dict__"):
return
for val in node.__dict__.values():
if isinstance(val, list):
for item in val:
yield from self.visit(item)
elif isinstance(val, ast.AST):
yield from self.visit(val)
def default(self, node):
pass
class PrecedenceContext: class PrecedenceContext:
def __init__(self, visitor: "ExpressionVisitor", op: str): def __init__(self, visitor: "ExpressionVisitor", op: str):
self.visitor = visitor self.visitor = visitor
...@@ -160,18 +179,27 @@ class PrecedenceContext: ...@@ -160,18 +179,27 @@ class PrecedenceContext:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop() self.visitor.precedence.pop()
class CoroutineMode(Enum): class CoroutineMode(Enum):
NONE = 0 NONE = 0
GENERATOR = 1 GENERATOR = 1
FAKE = 2 FAKE = 2
TASK = 3
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass @dataclass
class ExpressionVisitor(NodeVisitor): class ExpressionVisitor(NodeVisitor):
scope: "Scope"
precedence: List = field(default_factory=list) precedence: List = field(default_factory=list)
generator: CoroutineMode = CoroutineMode.NONE generator: CoroutineMode = CoroutineMode.NONE
def visit(self, node):
if type(node) in SYMBOLS:
yield SYMBOLS[type(node)]
else:
yield from NodeVisitor.visit(self, node)
def prec_ctx(self, op: str) -> PrecedenceContext: def prec_ctx(self, op: str) -> PrecedenceContext:
""" """
Creates a context manager that sets the precedence of the next expression. Creates a context manager that sets the precedence of the next expression.
...@@ -182,13 +210,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -182,13 +210,13 @@ class ExpressionVisitor(NodeVisitor):
""" """
Sets the precedence of the next expression. Sets the precedence of the next expression.
""" """
return ExpressionVisitor([op], generator=self.generator) return ExpressionVisitor(self.scope, [op], generator=self.generator)
def reset(self) -> "ExpressionVisitor": def reset(self) -> "ExpressionVisitor":
""" """
Resets the precedence stack. Resets the precedence stack.
""" """
return ExpressionVisitor(generator=self.generator) return ExpressionVisitor(self.scope, generator=self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]: def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple(" yield "std::make_tuple("
...@@ -212,7 +240,10 @@ class ExpressionVisitor(NodeVisitor): ...@@ -212,7 +240,10 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node)) raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> Iterable[str]: def visit_Name(self, node: ast.Name) -> Iterable[str]:
yield self.fix_name(node.id) res = self.fix_name(node.id)
if (decl := self.scope.get(res)) and decl.kind == VarKind.SELF:
res = "(*this)"
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
operands = [node.left, *node.comparators] operands = [node.left, *node.comparators]
...@@ -230,10 +261,16 @@ class ExpressionVisitor(NodeVisitor): ...@@ -230,10 +261,16 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, "varargs") raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None): if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") raise NotImplementedError(node, "kwargs")
yield from self.prec("()").visit(node.func) func = node.func
yield "(" if self.generator == CoroutineMode.TASK:
yield from join(", ", map(self.reset().visit, node.args)) yield "co_await "
yield ")" elif self.generator == CoroutineMode.FAKE:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
with self.prec_ctx("co_await"):
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]: def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]" yield "[]"
...@@ -313,9 +350,12 @@ class ExpressionVisitor(NodeVisitor): ...@@ -313,9 +350,12 @@ class ExpressionVisitor(NodeVisitor):
raise UnsupportedNodeError(node) raise UnsupportedNodeError(node)
elif self.generator == CoroutineMode.GENERATOR: elif self.generator == CoroutineMode.GENERATOR:
yield "co_yield" yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
elif self.generator == CoroutineMode.FAKE: elif self.generator == CoroutineMode.FAKE:
yield "return" yield "return"
yield from self.visit(node.value) yield from self.visit(node.value)
else:
raise NotImplementedError(node)
@dataclass @dataclass
...@@ -342,6 +382,16 @@ class Scope: ...@@ -342,6 +382,16 @@ class Scope:
""" """
return name in self.vars or (self.parent is not None and self.parent.exists(name)) return name in self.vars or (self.parent is not None and self.parent.exists(name))
def get(self, name: str) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if res := self.vars.get(name):
return res
if self.parent is not None:
return self.parent.get(name)
return None
def exists_local(self, name: str) -> bool: def exists_local(self, name: str) -> bool:
""" """
Determines whether a variable exists in the current function or global scope. Determines whether a variable exists in the current function or global scope.
...@@ -401,35 +451,58 @@ class Scope: ...@@ -401,35 +451,58 @@ class Scope:
@dataclass @dataclass
class BlockVisitor(NodeVisitor): class BlockVisitor(NodeVisitor):
scope: Scope scope: Scope
generator: CoroutineMode = CoroutineMode.NONE
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, generator=self.generator)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
try: yield from self.visit_coroutine(node)
[*result] = self.visit_func(node, CoroutineMode.NONE) # try:
return result # [*result] = self.visit_func(node, CoroutineMode.NONE)
except UnsupportedNodeError as e: # return result
if isinstance(e.node, ast.Yield): # except UnsupportedNodeError as e:
return self.visit_coroutine(node) # if isinstance(e.node, ast.Yield):
raise # return self.visit_coroutine(node)
# raise
def visit_coroutine(self, node: ast.FunctionDef) -> Iterable[str]: def visit_coroutine(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE) yield from self.visit_func(node, CoroutineMode.FAKE)
yield from self.visit_func(node, CoroutineMode.GENERATOR) yield from self.visit_func(node, CoroutineMode.TASK)
yield f"}} {node.name};"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]: def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args) templ, args, names = self.process_args(node.args)
if templ: if templ:
yield "template" yield "template"
yield templ yield templ
faked = f"FAKED_{node.name}"
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = next(ReturnVisitor().visit(node.body), False)
if generator == CoroutineMode.FAKE: if generator == CoroutineMode.FAKE:
yield f"static auto {faked}" if has_return:
yield "auto"
else:
yield "void"
yield "sync"
else: else:
yield f"auto {node.name}" yield f"auto operator()"
yield args yield args
if generator == CoroutineMode.GENERATOR: if generator == CoroutineMode.TASK:
yield f"-> typon::Generator<decltype({faked}({', '.join(names)}))>" yield f"-> typon::Task<decltype(sync({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = self.scope.function() inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
for child in node.body: for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes # Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global # on their own. Variables are still accessible in the remainder of the parent function or in the global
...@@ -484,19 +557,22 @@ class BlockVisitor(NodeVisitor): ...@@ -484,19 +557,22 @@ class BlockVisitor(NodeVisitor):
yield from child_code # Yeet back the child node code. yield from child_code # Yeet back the child node code.
if generator == CoroutineMode.FAKE: if generator == CoroutineMode.FAKE:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements. yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif generator == CoroutineMode.TASK:
if not has_return:
yield "co_return;"
yield "}" yield "}"
def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})" yield f"std::tie({', '.join(flatmap(self.expr().visit, lvalue.elts))})"
elif isinstance(lvalue, ast.Name): elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id) name = self.fix_name(lvalue.id)
# if name not in self._scope.vars: # if name not in self._scope.vars:
if not self.scope.exists_local(name): if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None) yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None)
yield name yield name
elif isinstance(lvalue, ast.Subscript): elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue) yield from self.expr().visit(lvalue)
else: else:
raise NotImplementedError(lvalue) raise NotImplementedError(lvalue)
...@@ -505,7 +581,7 @@ class BlockVisitor(NodeVisitor): ...@@ -505,7 +581,7 @@ class BlockVisitor(NodeVisitor):
raise NotImplementedError(node) raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.value) yield from self.visit_lvalue(node.targets[0], node.value)
yield " = " yield " = "
yield from ExpressionVisitor().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]: def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
...@@ -513,7 +589,7 @@ class BlockVisitor(NodeVisitor): ...@@ -513,7 +589,7 @@ class BlockVisitor(NodeVisitor):
raise NotImplementedError(node, "empty value") raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target, node.value) yield from self.visit_lvalue(node.target, node.value)
yield " = " yield " = "
yield from ExpressionVisitor().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
...@@ -557,8 +633,14 @@ class ModuleVisitor(BlockVisitor): ...@@ -557,8 +633,14 @@ class ModuleVisitor(BlockVisitor):
# declarations (functions, classes, modules, ...) and code. # declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`. # Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default. # 0 is returned by default.
yield "int main()" yield "typon::Root root()"
yield from FunctionVisitor(self.scope.function()).emit_block(node.body)
def block():
yield from node.body
yield ast.Return()
yield from FunctionVisitor(self.scope.function(), CoroutineMode.TASK).emit_block(block())
yield "int main() { root().call(); }"
return return
raise NotImplementedError(node, "global scope if") raise NotImplementedError(node, "global scope if")
...@@ -567,10 +649,7 @@ class ModuleVisitor(BlockVisitor): ...@@ -567,10 +649,7 @@ class ModuleVisitor(BlockVisitor):
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass @dataclass
class FunctionVisitor(BlockVisitor): class FunctionVisitor(BlockVisitor):
generator: CoroutineMode = CoroutineMode.NONE
def visit_Expr(self, node: ast.Expr) -> Iterable[str]: def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
print(ast.dump(node))
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
...@@ -603,7 +682,10 @@ class FunctionVisitor(BlockVisitor): ...@@ -603,7 +682,10 @@ class FunctionVisitor(BlockVisitor):
yield from self.emit_block(node.orelse) yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]: def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return " if self.generator == CoroutineMode.TASK:
yield "co_return "
else:
yield "return "
if node.value: if node.value:
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
...@@ -632,10 +714,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -632,10 +714,7 @@ class FunctionVisitor(BlockVisitor):
# variables as the parent scope. # variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), self.generator) return FunctionVisitor(self.scope.child_share(), self.generator)
def expr(self) -> ExpressionVisitor: def emit_block(self, items: Iterable[ast.stmt]) -> Iterable[str]:
return ExpressionVisitor(generator=self.generator)
def emit_block(self, items: List[ast.stmt]) -> Iterable[str]:
yield "{" yield "{"
block = self.block() block = self.block()
for child in items: for child in items:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment