Commit 1f9e68f2 authored by Tom Niget's avatar Tom Niget

Add basic support for generators (co_yield)

parent 776ac174
...@@ -33,6 +33,19 @@ concept PyLen = requires(const T &t) { ...@@ -33,6 +33,19 @@ concept PyLen = requires(const T &t) {
template <PyLen T> size_t len(const T &t) { return t.py_len(); } template <PyLen T> size_t len(const T &t) { return t.py_len(); }
template <typename T>
concept PyNext = requires(T t) {
t.py_next();
};
template <PyNext T> auto next(T &t) { return t.py_next(); }
template<typename T>
std::ostream& operator<<(std::ostream& os, std::optional<T> const& opt)
{
return opt ? os << opt.value() : os << "None";
}
bool is_cpp() { return true; } bool is_cpp() { return true; }
#include "builtins/bool.hpp" #include "builtins/bool.hpp"
...@@ -44,4 +57,6 @@ bool is_cpp() { return true; } ...@@ -44,4 +57,6 @@ bool is_cpp() { return true; }
#include "builtins/set.hpp" #include "builtins/set.hpp"
#include "builtins/str.hpp" #include "builtins/str.hpp"
#include "../typon/generator.hpp"
#endif // TYPON_BUILTINS_HPP #endif // TYPON_BUILTINS_HPP
//
// Created by Tom on 13/03/2023.
//
#ifndef TYPON_RANGE_HPP
#define TYPON_RANGE_HPP
#include <ranges>
// todo: proper range support
template <typename T> auto range(T stop) { return std::views::iota(0, stop); }
template <typename T> auto range(T start, T stop) {
return std::views::iota(start, stop);
}
#endif // TYPON_RANGE_HPP
...@@ -10,6 +10,9 @@ from transpiler.format import format_code ...@@ -10,6 +10,9 @@ from transpiler.format import format_code
def run_tests(): def run_tests():
for path in Path('tests').glob('*.py'): for path in Path('tests').glob('*.py'):
print(path.name) print(path.name)
if path.name.startswith('_'):
print("Skipping")
continue
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
res = format_code(transpile(f.read())) res = format_code(transpile(f.read()))
name_cpp = path.with_suffix('.cpp') name_cpp = path.with_suffix('.cpp')
......
# coding: utf-8
def fib():
a = 0
b = 1
while True:
yield a
a, b = b, a + b
if __name__ == "__main__":
f = fib()
for i in range(10):
print(next(f))
\ No newline at end of file
...@@ -108,6 +108,14 @@ class VarKind(Enum): ...@@ -108,6 +108,14 @@ class VarKind(Enum):
NONLOCAL = 3 NONLOCAL = 3
@dataclass
class UnsupportedNodeError(Exception):
node: ast.AST
def __str__(self) -> str:
return f"Unsupported node: {self.node.__class__.__mro__} {ast.dump(self.node)}"
class NodeVisitor: class NodeVisitor:
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
...@@ -119,7 +127,7 @@ class NodeVisitor: ...@@ -119,7 +127,7 @@ class NodeVisitor:
yield from visitor(node) yield from visitor(node)
break break
else: else:
raise NotImplementedError(node.__class__.__mro__, ast.dump(node)) raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str): def process_args(self, node: ast.arguments) -> (str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"): for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
...@@ -151,11 +159,17 @@ class PrecedenceContext: ...@@ -151,11 +159,17 @@ class PrecedenceContext:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop() self.visitor.precedence.pop()
class CoroutineMode(Enum):
NONE = 0
GENERATOR = 1
FAKE = 2
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass
class ExpressionVisitor(NodeVisitor): class ExpressionVisitor(NodeVisitor):
def __init__(self, precedence=None): precedence: List = field(default_factory=list)
self.precedence = precedence or [] generator: CoroutineMode = CoroutineMode.NONE
def prec_ctx(self, op: str) -> PrecedenceContext: def prec_ctx(self, op: str) -> PrecedenceContext:
""" """
...@@ -167,13 +181,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -167,13 +181,13 @@ class ExpressionVisitor(NodeVisitor):
""" """
Sets the precedence of the next expression. Sets the precedence of the next expression.
""" """
return ExpressionVisitor([op]) return ExpressionVisitor([op], generator=self.generator)
def reset(self) -> "ExpressionVisitor": def reset(self) -> "ExpressionVisitor":
""" """
Resets the precedence stack. Resets the precedence stack.
""" """
return ExpressionVisitor() return ExpressionVisitor(generator=self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]: def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple(" yield "std::make_tuple("
...@@ -234,6 +248,8 @@ class ExpressionVisitor(NodeVisitor): ...@@ -234,6 +248,8 @@ class ExpressionVisitor(NodeVisitor):
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]:
op = SYMBOLS[type(op)] op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op] prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
if prio: if prio:
yield "(" yield "("
...@@ -289,6 +305,15 @@ class ExpressionVisitor(NodeVisitor): ...@@ -289,6 +305,15 @@ class ExpressionVisitor(NodeVisitor):
yield " : " yield " : "
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if self.generator == CoroutineMode.NONE:
raise UnsupportedNodeError(node)
elif self.generator == CoroutineMode.GENERATOR:
yield "co_yield"
elif self.generator == CoroutineMode.FAKE:
yield "return"
yield from self.visit(node.value)
@dataclass @dataclass
class VarDecl: class VarDecl:
...@@ -370,19 +395,34 @@ class Scope: ...@@ -370,19 +395,34 @@ class Scope:
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass
class BlockVisitor(NodeVisitor): class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope): scope: Scope
self._scope = scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
try:
[*result] = self.visit_func(node, CoroutineMode.NONE)
return result
except UnsupportedNodeError as e:
if isinstance(e.node, ast.Yield):
return chain(self.visit_func(node, CoroutineMode.FAKE), self.visit_func(node, CoroutineMode.GENERATOR))
raise
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args = self.process_args(node.args) templ, args = self.process_args(node.args)
if templ: if templ:
yield "template" yield "template"
yield templ yield templ
yield f"auto {node.name}" faked = f"FAKED_{node.name}"
if generator == CoroutineMode.FAKE:
yield f"static auto {faked}"
elif generator == CoroutineMode.GENERATOR:
yield f"typon::Generator<decltype({faked}())> {node.name}"
else:
yield f"auto {node.name}"
yield args yield args
yield "{" yield "{"
inner = FunctionVisitor(self._scope.function()) inner_scope = self.scope.function()
for child in node.body: for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes # Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global # on their own. Variables are still accessible in the remainder of the parent function or in the global
...@@ -422,18 +462,18 @@ class BlockVisitor(NodeVisitor): ...@@ -422,18 +462,18 @@ class BlockVisitor(NodeVisitor):
# auto y = 2; # auto y = 2;
# } # }
# ``` # ```
child_visitor = FunctionVisitor(inner._scope.child()) child_visitor = FunctionVisitor(inner_scope.child(), generator)
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture. # We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough. # Fair enough.
[*child_code] = child_visitor.visit(child) [*child_code] = child_visitor.visit(child)
# Hoist inner variables to the root scope. # Hoist inner variables to the root scope.
for var, decl in child_visitor._scope.vars.items(): for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations. if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
yield f"decltype({decl.val}) {var};" yield f"decltype({decl.val}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is. elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner._scope.vars[var] = decl inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code. yield from child_code # Yeet back the child node code.
yield "}" yield "}"
...@@ -443,8 +483,8 @@ class BlockVisitor(NodeVisitor): ...@@ -443,8 +483,8 @@ class BlockVisitor(NodeVisitor):
elif isinstance(lvalue, ast.Name): elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id) name = self.fix_name(lvalue.id)
# if name not in self._scope.vars: # if name not in self._scope.vars:
if not self._scope.exists_local(name): if not self.scope.exists_local(name):
yield self._scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None) yield self.scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None)
yield name yield name
elif isinstance(lvalue, ast.Subscript): elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue) yield from ExpressionVisitor().visit(lvalue)
...@@ -473,7 +513,7 @@ class FileVisitor(BlockVisitor): ...@@ -473,7 +513,7 @@ class FileVisitor(BlockVisitor):
def visit_Module(self, node: ast.Module) -> Iterable[str]: def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST stmt: ast.AST
yield "#include <python/builtins.hpp>" yield "#include <python/builtins.hpp>"
visitor = ModuleVisitor(self._scope) visitor = ModuleVisitor(self.scope)
for stmt in node.body: for stmt in node.body:
yield from visitor.visit(stmt) yield from visitor.visit(stmt)
...@@ -509,29 +549,33 @@ class ModuleVisitor(BlockVisitor): ...@@ -509,29 +549,33 @@ class ModuleVisitor(BlockVisitor):
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`. # Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default. # 0 is returned by default.
yield "int main()" yield "int main()"
yield from FunctionVisitor(self._scope).emit_block(node.body) yield from FunctionVisitor(self.scope.function()).emit_block(node.body)
return return
raise NotImplementedError(node, "global scope if") raise NotImplementedError(node, "global scope if")
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass
class FunctionVisitor(BlockVisitor): class FunctionVisitor(BlockVisitor):
generator: CoroutineMode = CoroutineMode.NONE
def visit_Expr(self, node: ast.Expr) -> Iterable[str]: def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from ExpressionVisitor().visit(node.value) print(ast.dump(node))
yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]: def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target) yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "=" yield SYMBOLS[type(node.op)] + "="
yield from ExpressionVisitor().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]: def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name): if not isinstance(node.target, ast.Name):
raise NotImplementedError(node) raise NotImplementedError(node)
yield f"for (auto {node.target.id} : " yield f"for (auto {node.target.id} : "
yield from ExpressionVisitor().visit(node.iter) yield from self.expr().visit(node.iter)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.body)
if node.orelse: if node.orelse:
...@@ -539,7 +583,7 @@ class FunctionVisitor(BlockVisitor): ...@@ -539,7 +583,7 @@ class FunctionVisitor(BlockVisitor):
def visit_If(self, node: ast.If) -> Iterable[str]: def visit_If(self, node: ast.If) -> Iterable[str]:
yield "if (" yield "if ("
yield from ExpressionVisitor().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.body)
if node.orelse: if node.orelse:
...@@ -552,12 +596,12 @@ class FunctionVisitor(BlockVisitor): ...@@ -552,12 +596,12 @@ class FunctionVisitor(BlockVisitor):
def visit_Return(self, node: ast.Return) -> Iterable[str]: def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return " yield "return "
if node.value: if node.value:
yield from ExpressionVisitor().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]: def visit_While(self, node: ast.While) -> Iterable[str]:
yield "while (" yield "while ("
yield from ExpressionVisitor().visit(node.test) yield from self.expr().visit(node.test)
yield ")" yield ")"
yield from self.emit_block(node.body) yield from self.emit_block(node.body)
if node.orelse: if node.orelse:
...@@ -565,19 +609,22 @@ class FunctionVisitor(BlockVisitor): ...@@ -565,19 +609,22 @@ class FunctionVisitor(BlockVisitor):
def visit_Global(self, node: ast.Global) -> Iterable[str]: def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names): for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.GLOBAL, None) self.scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield "" yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]: def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names): for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.NONLOCAL, None) self.scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield "" yield ""
def block(self) -> "FunctionVisitor": def block(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef. # See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same # A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope. # variables as the parent scope.
return FunctionVisitor(self._scope.child_share()) return FunctionVisitor(self.scope.child_share(), self.generator)
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(generator=self.generator)
def emit_block(self, items: List[ast.stmt]) -> Iterable[str]: def emit_block(self, items: List[ast.stmt]) -> Iterable[str]:
yield "{" yield "{"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment