Commit 37e3c944 authored by Tom Niget's avatar Tom Niget

Add initial support for generators with parameters

parent 1f9e68f2
...@@ -35,19 +35,32 @@ template <PyLen T> size_t len(const T &t) { return t.py_len(); } ...@@ -35,19 +35,32 @@ template <PyLen T> size_t len(const T &t) { return t.py_len(); }
template <typename T> template <typename T>
concept PyNext = requires(T t) { concept PyNext = requires(T t) {
t.py_next(); { t.py_next() } -> std::same_as<std::optional<typename T::value_type>>;
}; };
template <PyNext T> auto next(T &t) { return t.py_next(); } template <PyNext T>
std::optional<typename T::value_type>
next(T &t, std::optional<typename T::value_type> def = std::nullopt) {
auto opt = t.py_next();
return opt ? opt : def;
}
template<typename T> template <typename T>
std::ostream& operator<<(std::ostream& os, std::optional<T> const& opt) std::ostream &operator<<(std::ostream &os, std::optional<T> const &opt) {
{
return opt ? os << opt.value() : os << "None"; return opt ? os << opt.value() : os << "None";
} }
bool is_cpp() { return true; } bool is_cpp() { return true; }
class NoneType {
public:
template <typename T> operator T *() const { return nullptr; }
template <typename T> operator std::optional<T>() const {
return std::nullopt;
}
} PyNone{};
#include "builtins/bool.hpp" #include "builtins/bool.hpp"
#include "builtins/complex.hpp" #include "builtins/complex.hpp"
#include "builtins/dict.hpp" #include "builtins/dict.hpp"
......
# coding: utf-8 # coding: utf-8
def fib(): def fib(upto):
a = 0 a = 0
b = 1 b = 1
while True: while b < upto:
yield a yield a
a, b = b, a + b a, b = b, a + b
if __name__ == "__main__": if __name__ == "__main__":
f = fib() f = fib(50)
for i in range(10): for i in range(15):
print(next(f)) print(next(f, None))
\ No newline at end of file \ No newline at end of file
...@@ -129,17 +129,18 @@ class NodeVisitor: ...@@ -129,17 +129,18 @@ class NodeVisitor:
else: else:
raise UnsupportedNodeError(node) raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str): def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"): for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None): if getattr(node, field, None):
raise NotImplementedError(node, field) raise NotImplementedError(node, field)
if not node.args: if not node.args:
return "", "()" return "", "()", []
f_args = [(arg.arg, f"T{i + 1}") for i, arg in enumerate(node.args)] f_args = [(self.fix_name(arg.arg), f"T{i + 1}") for i, arg in enumerate(node.args)]
return ( return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">", "<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {self.fix_name(n)}" for n, t in f_args) + ")" "(" + ", ".join(f"{t} {n}" for n, t in f_args) + ")",
[n for n, _ in f_args]
) )
def fix_name(self, name: str) -> str: def fix_name(self, name: str) -> str:
...@@ -205,6 +206,8 @@ class ExpressionVisitor(NodeVisitor): ...@@ -205,6 +206,8 @@ class ExpressionVisitor(NodeVisitor):
yield str(node.value) yield str(node.value)
elif isinstance(node.value, complex): elif isinstance(node.value, complex):
yield f"PyComplex({node.value.real}, {node.value.imag})" yield f"PyComplex({node.value.real}, {node.value.imag})"
elif node.value is None:
yield "PyNone"
else: else:
raise NotImplementedError(node, type(node)) raise NotImplementedError(node, type(node))
...@@ -234,7 +237,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -234,7 +237,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]: def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]" yield "[]"
templ, args = self.process_args(node.args) templ, args, _ = self.process_args(node.args)
yield templ yield templ
yield args yield args
yield "{" yield "{"
...@@ -409,18 +412,18 @@ class BlockVisitor(NodeVisitor): ...@@ -409,18 +412,18 @@ class BlockVisitor(NodeVisitor):
raise raise
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]: def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args = self.process_args(node.args) templ, args, names = self.process_args(node.args)
if templ: if templ:
yield "template" yield "template"
yield templ yield templ
faked = f"FAKED_{node.name}" faked = f"FAKED_{node.name}"
if generator == CoroutineMode.FAKE: if generator == CoroutineMode.FAKE:
yield f"static auto {faked}" yield f"static auto {faked}"
elif generator == CoroutineMode.GENERATOR:
yield f"typon::Generator<decltype({faked}())> {node.name}"
else: else:
yield f"auto {node.name}" yield f"auto {node.name}"
yield args yield args
if generator == CoroutineMode.GENERATOR:
yield f"-> typon::Generator<decltype({faked}({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = self.scope.function() inner_scope = self.scope.function()
for child in node.body: for child in node.body:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment