Commit e3cc4548 authored by Tom Niget's avatar Tom Niget

Add basic coroutine+await support

parent 45014685
......@@ -10,6 +10,8 @@
#include <ostream>
#include <string>
#include <typon/typon.hpp>
#ifdef __cpp_lib_unreachable
#include <utility>
[[noreturn]] inline void TYPON_UNREACHABLE() { std::unreachable(); }
......@@ -58,7 +60,7 @@ std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
return opt ? os << opt.value() : os << "None";
}
bool is_cpp() { return true; }
typon::Task<bool> is_cpp() { co_return true; }
class NoneType {
public:
......
......@@ -8,6 +8,8 @@
#include <iostream>
#include <ostream>
#include <typon/typon.hpp>
template <typename T>
concept Streamable = requires(const T &x, std::ostream &s) {
{ s << x } -> std::same_as<std::ostream &>;
......@@ -39,13 +41,28 @@ template <typename T>
concept Printable = requires(const T &x, std::ostream &s) {
{ print_to(x, s) } -> std::same_as<void>;
};
/*
template <Printable T, Printable... Args>
void print(T const &head, Args const &...args) {
typon::Task<void> print(T const &head, Args const &...args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
std::cout << '\n'; co_return;
}*/
struct {
void sync() { std::cout << '\n'; }
template <Printable T, Printable... Args>
void sync(T const &head, Args const &...args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
void print() { std::cout << '\n'; }
template<Printable... Args>
typon::Task<void> operator()(Args const &...args) {
co_return sync(args...);
}
} print;
//typon::Task<void> print() { std::cout << '\n'; co_return; }
#endif // TYPON_PRINT_HPP
from typon import fork, sync
def fibo(n):
if n < 2:
return n
a = fibo(n - 1)
b = fibo(n - 2)
return a + b
#def fibo(n: int) -> int:
# if n < 2:
......@@ -11,4 +18,4 @@ from typon import fork, sync
if __name__ == "__main__":
print("res=", 5, ".")
\ No newline at end of file
print(fibo(30)) # should display 832040
\ No newline at end of file
......@@ -41,7 +41,6 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def transpile(source):
tree = ast.parse(source)
# print(ast.unparse(tree))
return "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(tree))))
......@@ -74,7 +73,7 @@ SYMBOLS = {
PRECEDENCE = [
("()", "[]", ".",),
("unary",),
("unary", "co_await"),
("*", "/", "%",),
("+", "-"),
("<<", ">>"),
......@@ -85,7 +84,7 @@ PRECEDENCE = [
("|",),
("&&",),
("||",),
("?:",),
("?:", "co_yield"),
(",",)
]
"""Precedence of C++ operators."""
......@@ -106,6 +105,7 @@ class VarKind(Enum):
LOCAL = 1
GLOBAL = 2
NONLOCAL = 3
SELF = 4
@dataclass
......@@ -119,15 +119,19 @@ class UnsupportedNodeError(Exception):
class NodeVisitor:
def visit(self, node):
"""Visit a node."""
if type(node) in SYMBOLS:
yield SYMBOLS[type(node)]
if type(node) == list:
for n in node:
yield from self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node)
break
else:
raise UnsupportedNodeError(node)
return self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
......@@ -149,6 +153,21 @@ class NodeVisitor:
return MAPPINGS.get(name, name)
class SearchVisitor(NodeVisitor):
def missing_impl(self, node):
if not hasattr(node, "__dict__"):
return
for val in node.__dict__.values():
if isinstance(val, list):
for item in val:
yield from self.visit(item)
elif isinstance(val, ast.AST):
yield from self.visit(val)
def default(self, node):
pass
class PrecedenceContext:
def __init__(self, visitor: "ExpressionVisitor", op: str):
self.visitor = visitor
......@@ -160,18 +179,27 @@ class PrecedenceContext:
def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop()
class CoroutineMode(Enum):
NONE = 0
GENERATOR = 1
FAKE = 2
TASK = 3
# noinspection PyPep8Naming
@dataclass
class ExpressionVisitor(NodeVisitor):
scope: "Scope"
precedence: List = field(default_factory=list)
generator: CoroutineMode = CoroutineMode.NONE
def visit(self, node):
if type(node) in SYMBOLS:
yield SYMBOLS[type(node)]
else:
yield from NodeVisitor.visit(self, node)
def prec_ctx(self, op: str) -> PrecedenceContext:
"""
Creates a context manager that sets the precedence of the next expression.
......@@ -182,13 +210,13 @@ class ExpressionVisitor(NodeVisitor):
"""
Sets the precedence of the next expression.
"""
return ExpressionVisitor([op], generator=self.generator)
return ExpressionVisitor(self.scope, [op], generator=self.generator)
def reset(self) -> "ExpressionVisitor":
"""
Resets the precedence stack.
"""
return ExpressionVisitor(generator=self.generator)
return ExpressionVisitor(self.scope, generator=self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple("
......@@ -212,7 +240,10 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> Iterable[str]:
yield self.fix_name(node.id)
res = self.fix_name(node.id)
if (decl := self.scope.get(res)) and decl.kind == VarKind.SELF:
res = "(*this)"
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
operands = [node.left, *node.comparators]
......@@ -230,10 +261,16 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
yield from self.prec("()").visit(node.func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
func = node.func
if self.generator == CoroutineMode.TASK:
yield "co_await "
elif self.generator == CoroutineMode.FAKE:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
with self.prec_ctx("co_await"):
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]"
......@@ -313,9 +350,12 @@ class ExpressionVisitor(NodeVisitor):
raise UnsupportedNodeError(node)
elif self.generator == CoroutineMode.GENERATOR:
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
elif self.generator == CoroutineMode.FAKE:
yield "return"
yield from self.visit(node.value)
yield from self.visit(node.value)
else:
raise NotImplementedError(node)
@dataclass
......@@ -342,6 +382,16 @@ class Scope:
"""
return name in self.vars or (self.parent is not None and self.parent.exists(name))
def get(self, name: str) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if res := self.vars.get(name):
return res
if self.parent is not None:
return self.parent.get(name)
return None
def exists_local(self, name: str) -> bool:
"""
Determines whether a variable exists in the current function or global scope.
......@@ -401,35 +451,58 @@ class Scope:
@dataclass
class BlockVisitor(NodeVisitor):
scope: Scope
generator: CoroutineMode = CoroutineMode.NONE
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, generator=self.generator)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
try:
[*result] = self.visit_func(node, CoroutineMode.NONE)
return result
except UnsupportedNodeError as e:
if isinstance(e.node, ast.Yield):
return self.visit_coroutine(node)
raise
yield from self.visit_coroutine(node)
# try:
# [*result] = self.visit_func(node, CoroutineMode.NONE)
# return result
# except UnsupportedNodeError as e:
# if isinstance(e.node, ast.Yield):
# return self.visit_coroutine(node)
# raise
def visit_coroutine(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE)
yield from self.visit_func(node, CoroutineMode.GENERATOR)
yield from self.visit_func(node, CoroutineMode.TASK)
yield f"}} {node.name};"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
faked = f"FAKED_{node.name}"
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = next(ReturnVisitor().visit(node.body), False)
if generator == CoroutineMode.FAKE:
yield f"static auto {faked}"
if has_return:
yield "auto"
else:
yield "void"
yield "sync"
else:
yield f"auto {node.name}"
yield f"auto operator()"
yield args
if generator == CoroutineMode.GENERATOR:
yield f"-> typon::Generator<decltype({faked}({', '.join(names)}))>"
if generator == CoroutineMode.TASK:
yield f"-> typon::Task<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = self.scope.function()
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
......@@ -484,19 +557,22 @@ class BlockVisitor(NodeVisitor):
yield from child_code # Yeet back the child node code.
if generator == CoroutineMode.FAKE:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif generator == CoroutineMode.TASK:
if not has_return:
yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})"
yield f"std::tie({', '.join(flatmap(self.expr().visit, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None)
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue)
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
......@@ -505,7 +581,7 @@ class BlockVisitor(NodeVisitor):
raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.value)
yield " = "
yield from ExpressionVisitor().visit(node.value)
yield from self.expr().visit(node.value)
yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
......@@ -513,7 +589,7 @@ class BlockVisitor(NodeVisitor):
raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target, node.value)
yield " = "
yield from ExpressionVisitor().visit(node.value)
yield from self.expr().visit(node.value)
yield ";"
......@@ -557,8 +633,14 @@ class ModuleVisitor(BlockVisitor):
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "int main()"
yield from FunctionVisitor(self.scope.function()).emit_block(node.body)
yield "typon::Root root()"
def block():
yield from node.body
yield ast.Return()
yield from FunctionVisitor(self.scope.function(), CoroutineMode.TASK).emit_block(block())
yield "int main() { root().call(); }"
return
raise NotImplementedError(node, "global scope if")
......@@ -567,10 +649,7 @@ class ModuleVisitor(BlockVisitor):
# noinspection PyPep8Naming
@dataclass
class FunctionVisitor(BlockVisitor):
generator: CoroutineMode = CoroutineMode.NONE
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
print(ast.dump(node))
yield from self.expr().visit(node.value)
yield ";"
......@@ -603,7 +682,10 @@ class FunctionVisitor(BlockVisitor):
yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return "
if self.generator == CoroutineMode.TASK:
yield "co_return "
else:
yield "return "
if node.value:
yield from self.expr().visit(node.value)
yield ";"
......@@ -632,10 +714,7 @@ class FunctionVisitor(BlockVisitor):
# variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), self.generator)
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(generator=self.generator)
def emit_block(self, items: List[ast.stmt]) -> Iterable[str]:
def emit_block(self, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{"
block = self.block()
for child in items:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment