Commit 1f9e68f2 authored by Tom Niget's avatar Tom Niget

Add basic support for generators (co_yield)

parent 776ac174
......@@ -33,6 +33,19 @@ concept PyLen = requires(const T &t) {
template <PyLen T> size_t len(const T &t) { return t.py_len(); }
template <typename T>
concept PyNext = requires(T t) {
t.py_next();
};
template <PyNext T> auto next(T &t) { return t.py_next(); }
template<typename T>
std::ostream& operator<<(std::ostream& os, std::optional<T> const& opt)
{
return opt ? os << opt.value() : os << "None";
}
bool is_cpp() { return true; }
#include "builtins/bool.hpp"
......@@ -44,4 +57,6 @@ bool is_cpp() { return true; }
#include "builtins/set.hpp"
#include "builtins/str.hpp"
#include "../typon/generator.hpp"
#endif // TYPON_BUILTINS_HPP
//
// Created by Tom on 13/03/2023.
//
#ifndef TYPON_RANGE_HPP
#define TYPON_RANGE_HPP
#include <ranges>
// todo: proper range support
template <typename T> auto range(T stop) { return std::views::iota(0, stop); }
template <typename T> auto range(T start, T stop) {
return std::views::iota(start, stop);
}
#endif // TYPON_RANGE_HPP
......@@ -10,6 +10,9 @@ from transpiler.format import format_code
def run_tests():
for path in Path('tests').glob('*.py'):
print(path.name)
if path.name.startswith('_'):
print("Skipping")
continue
with open(path, "r", encoding="utf-8") as f:
res = format_code(transpile(f.read()))
name_cpp = path.with_suffix('.cpp')
......
# coding: utf-8
def fib():
a = 0
b = 1
while True:
yield a
a, b = b, a + b
if __name__ == "__main__":
f = fib()
for i in range(10):
print(next(f))
\ No newline at end of file
......@@ -108,6 +108,14 @@ class VarKind(Enum):
NONLOCAL = 3
@dataclass
class UnsupportedNodeError(Exception):
node: ast.AST
def __str__(self) -> str:
return f"Unsupported node: {self.node.__class__.__mro__} {ast.dump(self.node)}"
class NodeVisitor:
def visit(self, node):
"""Visit a node."""
......@@ -119,7 +127,7 @@ class NodeVisitor:
yield from visitor(node)
break
else:
raise NotImplementedError(node.__class__.__mro__, ast.dump(node))
raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
......@@ -151,11 +159,17 @@ class PrecedenceContext:
def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop()
class CoroutineMode(Enum):
NONE = 0
GENERATOR = 1
FAKE = 2
# noinspection PyPep8Naming
@dataclass
class ExpressionVisitor(NodeVisitor):
def __init__(self, precedence=None):
self.precedence = precedence or []
precedence: List = field(default_factory=list)
generator: CoroutineMode = CoroutineMode.NONE
def prec_ctx(self, op: str) -> PrecedenceContext:
"""
......@@ -167,13 +181,13 @@ class ExpressionVisitor(NodeVisitor):
"""
Sets the precedence of the next expression.
"""
return ExpressionVisitor([op])
return ExpressionVisitor([op], generator=self.generator)
def reset(self) -> "ExpressionVisitor":
"""
Resets the precedence stack.
"""
return ExpressionVisitor()
return ExpressionVisitor(generator=self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple("
......@@ -234,6 +248,8 @@ class ExpressionVisitor(NodeVisitor):
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]:
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
if prio:
yield "("
......@@ -289,6 +305,15 @@ class ExpressionVisitor(NodeVisitor):
yield " : "
yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if self.generator == CoroutineMode.NONE:
raise UnsupportedNodeError(node)
elif self.generator == CoroutineMode.GENERATOR:
yield "co_yield"
elif self.generator == CoroutineMode.FAKE:
yield "return"
yield from self.visit(node.value)
@dataclass
class VarDecl:
......@@ -370,19 +395,34 @@ class Scope:
# noinspection PyPep8Naming
@dataclass
class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope):
self._scope = scope
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
try:
[*result] = self.visit_func(node, CoroutineMode.NONE)
return result
except UnsupportedNodeError as e:
if isinstance(e.node, ast.Yield):
return chain(self.visit_func(node, CoroutineMode.FAKE), self.visit_func(node, CoroutineMode.GENERATOR))
raise
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args = self.process_args(node.args)
if templ:
yield "template"
yield templ
yield f"auto {node.name}"
faked = f"FAKED_{node.name}"
if generator == CoroutineMode.FAKE:
yield f"static auto {faked}"
elif generator == CoroutineMode.GENERATOR:
yield f"typon::Generator<decltype({faked}())> {node.name}"
else:
yield f"auto {node.name}"
yield args
yield "{"
inner = FunctionVisitor(self._scope.function())
inner_scope = self.scope.function()
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
......@@ -422,18 +462,18 @@ class BlockVisitor(NodeVisitor):
# auto y = 2;
# }
# ```
child_visitor = FunctionVisitor(inner._scope.child())
child_visitor = FunctionVisitor(inner_scope.child(), generator)
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough.
[*child_code] = child_visitor.visit(child)
# Hoist inner variables to the root scope.
for var, decl in child_visitor._scope.vars.items():
for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
yield f"decltype({decl.val}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner._scope.vars[var] = decl
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
yield "}"
......@@ -443,8 +483,8 @@ class BlockVisitor(NodeVisitor):
elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
if not self._scope.exists_local(name):
yield self._scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None)
if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue)
......@@ -473,7 +513,7 @@ class FileVisitor(BlockVisitor):
def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST
yield "#include <python/builtins.hpp>"
visitor = ModuleVisitor(self._scope)
visitor = ModuleVisitor(self.scope)
for stmt in node.body:
yield from visitor.visit(stmt)
......@@ -509,29 +549,33 @@ class ModuleVisitor(BlockVisitor):
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "int main()"
yield from FunctionVisitor(self._scope).emit_block(node.body)
yield from FunctionVisitor(self.scope.function()).emit_block(node.body)
return
raise NotImplementedError(node, "global scope if")
# noinspection PyPep8Naming
@dataclass
class FunctionVisitor(BlockVisitor):
generator: CoroutineMode = CoroutineMode.NONE
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from ExpressionVisitor().visit(node.value)
print(ast.dump(node))
yield from self.expr().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "="
yield from ExpressionVisitor().visit(node.value)
yield from self.expr().visit(node.value)
yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
yield f"for (auto {node.target.id} : "
yield from ExpressionVisitor().visit(node.iter)
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
......@@ -539,7 +583,7 @@ class FunctionVisitor(BlockVisitor):
def visit_If(self, node: ast.If) -> Iterable[str]:
yield "if ("
yield from ExpressionVisitor().visit(node.test)
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
......@@ -552,12 +596,12 @@ class FunctionVisitor(BlockVisitor):
def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return "
if node.value:
yield from ExpressionVisitor().visit(node.value)
yield from self.expr().visit(node.value)
yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]:
yield "while ("
yield from ExpressionVisitor().visit(node.test)
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
......@@ -565,19 +609,22 @@ class FunctionVisitor(BlockVisitor):
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
self.scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
self.scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield ""
def block(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope.
return FunctionVisitor(self._scope.child_share())
return FunctionVisitor(self.scope.child_share(), self.generator)
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(generator=self.generator)
def emit_block(self, items: List[ast.stmt]) -> Iterable[str]:
yield "{"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment