Commit 8d6354b3 authored by Tom Niget's avatar Tom Niget

Add preliminary support for generators with new coroutine system

parent 081e3685
...@@ -48,12 +48,20 @@ concept PyNext = requires(T t) { ...@@ -48,12 +48,20 @@ concept PyNext = requires(T t) {
{ t.py_next() } -> std::same_as<std::optional<typename T::value_type>>; { t.py_next() } -> std::same_as<std::optional<typename T::value_type>>;
}; };
template <PyNext T> struct {
std::optional<typename T::value_type> template <PyNext T>
next(T &t, std::optional<typename T::value_type> def = std::nullopt) { std::optional<typename T::value_type>
auto opt = t.py_next(); sync(T &t, std::optional<typename T::value_type> def = std::nullopt) {
return opt ? opt : def; auto opt = t.py_next();
} return opt ? opt : def;
}
template <PyNext T>
auto operator()(T &t, std::optional<typename T::value_type> def = std::nullopt)
-> typon::Task<decltype(sync(t, def))> {
co_return sync(t, def);
}
} next;
template <typename T> template <typename T>
std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) { std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
......
...@@ -7,11 +7,25 @@ ...@@ -7,11 +7,25 @@
#include <ranges> #include <ranges>
#include <typon/typon.hpp>
// todo: proper range support // todo: proper range support
template <typename T> auto range(T stop) { return std::views::iota(0, stop); } struct {
template <typename T> auto sync(T stop) { return std::views::iota(0, stop); }
template <typename T> auto sync(T start, T stop) {
return std::views::iota(start, stop);
}
template <typename T>
auto operator()(T stop) -> typon::Task<decltype(sync(stop))> {
co_return sync(stop);
}
template <typename T> auto range(T start, T stop) { template <typename T>
return std::views::iota(start, stop); auto operator()(T start, T stop) -> typon::Task<decltype(sync(start, stop))> {
} co_return sync(start, stop);
}
} range;
#endif // TYPON_RANGE_HPP #endif // TYPON_RANGE_HPP
...@@ -12,17 +12,17 @@ namespace typon { ...@@ -12,17 +12,17 @@ namespace typon {
/** /**
* https://github.com/feabhas/coroutines-blog * https://github.com/feabhas/coroutines-blog
*/ */
template <typename T> template <typename T> class Generator {
class Generator class Promise {
{
class Promise
{
public: public:
using value_type = std::optional<T>; using value_type = std::optional<T>;
Promise() = default; Promise() = default;
std::suspend_always initial_suspend() { return {}; } std::suspend_always initial_suspend() { return {}; }
std::suspend_always final_suspend() noexcept { final = true; return {}; } std::suspend_always final_suspend() noexcept {
final = true;
return {};
}
void unhandled_exception() { void unhandled_exception() {
std::rethrow_exception(std::move(std::current_exception())); std::rethrow_exception(std::move(std::current_exception()));
} }
...@@ -36,18 +36,14 @@ class Generator ...@@ -36,18 +36,14 @@ class Generator
// this->value = std::move(value); // this->value = std::move(value);
// } // }
void return_void() { void return_void() { this->value = std::nullopt; }
this->value = std::nullopt;
}
inline Generator get_return_object(); inline Generator get_return_object();
value_type get_value() { value_type get_value() { return std::move(value); }
return std::move(value);
}
bool finished() { bool finished() {
//return !value.has_value(); // return !value.has_value();
return final; return final;
} }
...@@ -60,12 +56,12 @@ public: ...@@ -60,12 +56,12 @@ public:
using value_type = T; using value_type = T;
using promise_type = Promise; using promise_type = Promise;
explicit Generator(std::coroutine_handle<Promise> handle) explicit Generator(std::coroutine_handle<Promise> handle) : handle(handle) {}
: handle (handle)
{}
~Generator() { ~Generator() {
if (handle) { handle.destroy(); } if (handle) {
handle.destroy();
}
} }
Promise::value_type next() { Promise::value_type next() {
...@@ -74,24 +70,21 @@ public: ...@@ -74,24 +70,21 @@ public:
handle.resume(); handle.resume();
} }
return handle.promise().get_value(); return handle.promise().get_value();
} } else {
else {
return {}; return {};
} }
} }
struct end_iterator {}; struct end_iterator {};
class iterator class iterator {
{
public: public:
using value_type = Promise::value_type; using value_type = Promise::value_type;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
using iterator_category = std::input_iterator_tag; using iterator_category = std::input_iterator_tag;
iterator() = default; iterator() = default;
iterator(Generator& generator) : generator{&generator} iterator(Generator &generator) : generator{&generator} {}
{}
value_type operator*() const { value_type operator*() const {
if (generator) { if (generator) {
...@@ -107,26 +100,26 @@ public: ...@@ -107,26 +100,26 @@ public:
return {}; return {};
} }
iterator& operator++() { iterator &operator++() {
if (generator && generator->handle) { if (generator && generator->handle) {
generator->handle.resume(); generator->handle.resume();
} }
return *this; return *this;
} }
iterator& operator++(int) { iterator &operator++(int) {
if (generator && generator->handle) { if (generator && generator->handle) {
generator->handle.resume(); generator->handle.resume();
} }
return *this; return *this;
} }
bool operator== (const end_iterator&) const { bool operator==(const end_iterator &) const {
return generator ? generator->handle.promise().finished() : true; return generator ? generator->handle.promise().finished() : true;
} }
private: private:
Generator* generator{}; Generator *generator{};
}; };
iterator begin() { iterator begin() {
...@@ -134,24 +127,18 @@ public: ...@@ -134,24 +127,18 @@ public:
return ++it; return ++it;
} }
end_iterator end() { end_iterator end() { return end_sentinel; }
return end_sentinel;
}
std::optional<value_type> py_next() { std::optional<value_type> py_next() { return next(); }
return next();
}
private: private:
end_iterator end_sentinel{}; end_iterator end_sentinel{};
std::coroutine_handle<Promise> handle; std::coroutine_handle<Promise> handle;
}; };
template <typename T> template <typename T>
inline Generator<T> Generator<T>::Promise::get_return_object() inline Generator<T> Generator<T>::Promise::get_return_object() {
{ return Generator{std::coroutine_handle<Promise>::from_promise(*this)};
return Generator{ std::coroutine_handle<Promise>::from_promise(*this) };
} }
} // namespace typon } // namespace typon
......
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum, auto, Flag
from itertools import chain, zip_longest from itertools import chain, zip_longest
from typing import * from typing import *
...@@ -128,7 +128,7 @@ class NodeVisitor: ...@@ -128,7 +128,7 @@ class NodeVisitor:
yield from visitor(node) yield from visitor(node)
break break
else: else:
return self.missing_impl(node) yield from self.missing_impl(node)
def missing_impl(self, node): def missing_impl(self, node):
raise UnsupportedNodeError(node) raise UnsupportedNodeError(node)
...@@ -164,8 +164,8 @@ class SearchVisitor(NodeVisitor): ...@@ -164,8 +164,8 @@ class SearchVisitor(NodeVisitor):
elif isinstance(val, ast.AST): elif isinstance(val, ast.AST):
yield from self.visit(val) yield from self.visit(val)
def default(self, node): def match(self, node) -> bool:
pass return next(self.visit(node), False)
class PrecedenceContext: class PrecedenceContext:
...@@ -180,19 +180,20 @@ class PrecedenceContext: ...@@ -180,19 +180,20 @@ class PrecedenceContext:
self.visitor.precedence.pop() self.visitor.precedence.pop()
class CoroutineMode(Enum): class CoroutineMode(Flag):
NONE = 0 SYNC = 1
GENERATOR = 1 FAKE = 2 | SYNC
FAKE = 2 ASYNC = 4
TASK = 3 GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass @dataclass
class ExpressionVisitor(NodeVisitor): class ExpressionVisitor(NodeVisitor):
scope: "Scope" scope: "Scope"
generator: CoroutineMode
precedence: List = field(default_factory=list) precedence: List = field(default_factory=list)
generator: CoroutineMode = CoroutineMode.NONE
def visit(self, node): def visit(self, node):
if type(node) in SYMBOLS: if type(node) in SYMBOLS:
...@@ -210,13 +211,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -210,13 +211,13 @@ class ExpressionVisitor(NodeVisitor):
""" """
Sets the precedence of the next expression. Sets the precedence of the next expression.
""" """
return ExpressionVisitor(self.scope, [op], generator=self.generator) return ExpressionVisitor(self.scope, self.generator, [op])
def reset(self) -> "ExpressionVisitor": def reset(self) -> "ExpressionVisitor":
""" """
Resets the precedence stack. Resets the precedence stack.
""" """
return ExpressionVisitor(self.scope, generator=self.generator) return ExpressionVisitor(self.scope, self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]: def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple(" yield "std::make_tuple("
...@@ -262,15 +263,15 @@ class ExpressionVisitor(NodeVisitor): ...@@ -262,15 +263,15 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None): if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") raise NotImplementedError(node, "kwargs")
func = node.func func = node.func
if self.generator == CoroutineMode.TASK: # TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator:
yield "co_await " yield "co_await "
elif self.generator == CoroutineMode.FAKE: elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load()) func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
with self.prec_ctx("co_await"): yield from self.prec("()").visit(func)
yield from self.prec("()").visit(func) yield "("
yield "(" yield from join(", ", map(self.reset().visit, node.args))
yield from join(", ", map(self.reset().visit, node.args)) yield ")"
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]: def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]" yield "[]"
...@@ -346,12 +347,10 @@ class ExpressionVisitor(NodeVisitor): ...@@ -346,12 +347,10 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]: def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if self.generator == CoroutineMode.NONE: if CoroutineMode.GENERATOR in self.generator:
raise UnsupportedNodeError(node)
elif self.generator == CoroutineMode.GENERATOR:
yield "co_yield" yield "co_yield"
yield from self.prec("co_yield").visit(node.value) yield from self.prec("co_yield").visit(node.value)
elif self.generator == CoroutineMode.FAKE: elif CoroutineMode.FAKE in self.generator:
yield "return" yield "return"
yield from self.visit(node.value) yield from self.visit(node.value)
else: else:
...@@ -451,25 +450,38 @@ class Scope: ...@@ -451,25 +450,38 @@ class Scope:
@dataclass @dataclass
class BlockVisitor(NodeVisitor): class BlockVisitor(NodeVisitor):
scope: Scope scope: Scope
generator: CoroutineMode = CoroutineMode.NONE generator: CoroutineMode = CoroutineMode.SYNC
def expr(self) -> ExpressionVisitor: def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, generator=self.generator) return ExpressionVisitor(self.scope, self.generator)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit_coroutine(node)
# try:
# [*result] = self.visit_func(node, CoroutineMode.NONE)
# return result
# except UnsupportedNodeError as e:
# if isinstance(e.node, ast.Yield):
# return self.visit_coroutine(node)
# raise
def visit_coroutine(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {" yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE) yield from self.visit_func(node, CoroutineMode.FAKE)
yield from self.visit_func(node, CoroutineMode.TASK)
class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_yield = YieldVisitor().match(node.body)
yield from self.visit_func(node, CoroutineMode.GENERATOR if has_yield else CoroutineMode.TASK)
if has_yield:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
yield f"auto operator()"
yield args
yield f"-> typon::Task<decltype(gen({', '.join(names)}))>"
yield "{"
yield f"co_return gen({', '.join(names)});"
yield "}"
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]: def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
...@@ -482,25 +494,35 @@ class BlockVisitor(NodeVisitor): ...@@ -482,25 +494,35 @@ class BlockVisitor(NodeVisitor):
def visit_Return(self, node: ast.Return) -> bool: def visit_Return(self, node: ast.Return) -> bool:
yield True yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
yield from () yield from ()
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
yield from () yield from ()
has_return = next(ReturnVisitor().visit(node.body), False) has_return = ReturnVisitor().match(node.body)
if generator == CoroutineMode.FAKE: if CoroutineMode.SYNC in generator:
if has_return: if has_return:
yield "auto" yield "auto"
else: else:
yield "void" yield "void"
yield "sync" yield "sync"
elif CoroutineMode.GENERATOR in generator:
yield "auto gen"
else: else:
yield f"auto operator()" yield "auto operator()"
yield args yield args
if generator == CoroutineMode.TASK: if CoroutineMode.ASYNC in generator:
yield f"-> typon::Task<decltype(sync({', '.join(names)}))>" yield "-> typon::"
if CoroutineMode.TASK in generator:
yield "Task"
elif CoroutineMode.GENERATOR in generator:
yield "Generator"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)}) inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
for child in node.body: for child in node.body:
...@@ -555,9 +577,9 @@ class BlockVisitor(NodeVisitor): ...@@ -555,9 +577,9 @@ class BlockVisitor(NodeVisitor):
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is. elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code. yield from child_code # Yeet back the child node code.
if generator == CoroutineMode.FAKE: if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements. yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif generator == CoroutineMode.TASK: elif CoroutineMode.TASK in generator:
if not has_return: if not has_return:
yield "co_return;" yield "co_return;"
yield "}" yield "}"
...@@ -682,7 +704,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -682,7 +704,7 @@ class FunctionVisitor(BlockVisitor):
yield from self.emit_block(node.orelse) yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]: def visit_Return(self, node: ast.Return) -> Iterable[str]:
if self.generator == CoroutineMode.TASK: if CoroutineMode.ASYNC in self.generator:
yield "co_return " yield "co_return "
else: else:
yield "return " yield "return "
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment