Commit 3487ebcb authored by Tom Niget's avatar Tom Niget

Restructure code

parent 2ca5dcf1
This diff is collapsed.
# coding: utf-8
import ast
SYMBOLS = {
ast.Eq: "==",
ast.NotEq: '!=',
ast.Pass: '/* pass */',
ast.Mult: '*',
ast.Add: '+',
ast.Sub: '-',
ast.Div: '/',
ast.FloorDiv: '/', # TODO
ast.Mod: '%',
ast.Lt: '<',
ast.Gt: '>',
ast.GtE: '>=',
ast.LtE: '<=',
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
ast.BitOr: '|',
ast.BitAnd: '&',
ast.Not: '!',
ast.IsNot: '!=',
ast.USub: '-',
ast.And: '&&',
ast.Or: '||'
}
"""Mapping of Python AST nodes to C++ symbols."""
PRECEDENCE = [
("()", "[]", ".",),
("unary", "co_await"),
("*", "/", "%",),
("+", "-"),
("<<", ">>"),
("<", "<=", ">", ">="),
("==", "!="),
("&",),
("^",),
("|",),
("&&",),
("||",),
("?:", "co_yield"),
(",",)
]
"""Precedence of C++ operators."""
PRECEDENCE_LEVELS = {op: i for i, ops in enumerate(PRECEDENCE) for op in ops}
"""Mapping of C++ operators to their precedence level."""
MAPPINGS = {
"True": "true",
"False": "false",
"None": "nullptr"
}
"""Mapping of Python builtin constants to C++ equivalents."""
# coding: utf-8
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Dict
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
GLOBAL = 2
NONLOCAL = 3
SELF = 4
@dataclass
class VarDecl:
kind: VarKind
val: Optional[str]
@dataclass
class Scope:
parent: Optional["Scope"] = None
is_function: bool = False
vars: Dict[str, VarDecl] = field(default_factory=dict)
def is_global(self) -> bool:
"""
Determines whether this scope is the global scope. The global scope is the only scope to have no parent.
"""
return self.parent is None
def exists(self, name: str) -> bool:
"""
Determines whether a variable exists in the current scope or any parent scope.
"""
return name in self.vars or (self.parent is not None and self.parent.exists(name))
def get(self, name: str) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if res := self.vars.get(name):
return res
if self.parent is not None:
return self.parent.get(name)
return None
def exists_local(self, name: str) -> bool:
"""
Determines whether a variable exists in the current function or global scope.
The check does not cross function boundaries; i.e. global variables are not taken into account from inside
functions.
"""
return name in self.vars or (
not self.is_function and self.parent is not None and self.parent.exists_local(name))
def child(self) -> "Scope":
"""
Creates a child scope with a new variable dictionary.
This is used for first-level elements of a function.
"""
return Scope(self, False, {})
def child_share(self) -> "Scope":
"""
Creates a child scope sharing the variable dictionary with the parent scope.
This is used for Python blocks, which share the variable scope with their parent block.
"""
return Scope(self, False, self.vars)
def function(self, **kwargs) -> "Scope":
"""
Creates a function scope.
"""
return Scope(self, True, **kwargs)
def is_root(self) -> Optional[Dict[str, VarDecl]]:
"""
Determines whether this scope is a root scope.
A root scope is either the global scope, or the first inner scope of a function.
Variable declarations in the generated code only ever appear in root scopes.
:return: `None` if this scope is not a root scope, otherwise the variable dictionary of the root scope.
"""
if self.parent is None:
return self.vars
if self.parent.is_function:
return self.parent.vars
return None
def declare(self, name: str, val: Optional[str] = None) -> Optional[str]:
if self.exists_local(name):
# If the variable already exists in the current function or global scope, we don't need to declare it again.
# This is simply an assignment.
return None
vdict, prefix = self.vars, ""
if (root_vars := self.is_root()) is not None:
vdict, prefix = root_vars, "auto " # Root scope declarations can use `auto`.
vdict[name] = VarDecl(VarKind.LOCAL, val)
return prefix
# coding: utf-8
import ast
from dataclasses import dataclass
from enum import Flag
from itertools import zip_longest, chain
from typing import Iterable, Union
from transpiler import MAPPINGS
class NodeVisitor:
def visit(self, node):
"""Visit a node."""
if type(node) == list:
for n in node:
yield from self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node)
break
else:
yield from self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None):
raise NotImplementedError(node, field)
if not node.args:
return "", "()", []
f_args = [(self.fix_name(arg.arg), f"T{i + 1}") for i, arg in enumerate(node.args)]
return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {n}" for n, t in f_args) + ")",
[n for n, _ in f_args]
)
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
@dataclass
class UnsupportedNodeError(Exception):
node: ast.AST
def __str__(self) -> str:
return f"Unsupported node: {self.node.__class__.__mro__} {ast.dump(self.node)}"
class CoroutineMode(Flag):
SYNC = 1
FAKE = 2 | SYNC
ASYNC = 4
GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items)
try:
yield from next(items)
for item in items:
yield sep
yield from item
except StopIteration:
return
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
if type(node1) is not type(node2):
return False
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
continue
if not compare_ast(v, getattr(node2, k)):
return False
return True
elif isinstance(node1, list) and isinstance(node2, list):
return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))
else:
return node1 == node2
def flatmap(f, items):
return chain.from_iterable(map(f, items))
# coding: utf-8
import ast
from dataclasses import dataclass
from typing import Iterable, Optional
from transpiler.scope import VarDecl, VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap
from transpiler.visitors.expr import ExpressionVisitor
from transpiler.visitors.search import SearchVisitor
# noinspection PyPep8Naming
@dataclass
class BlockVisitor(NodeVisitor):
scope: Scope
generator: CoroutineMode = CoroutineMode.SYNC
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE)
class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_yield = YieldVisitor().match(node.body)
yield from self.visit_func(node, CoroutineMode.GENERATOR if has_yield else CoroutineMode.TASK)
if has_yield:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
yield f"auto operator()"
yield args
yield f"-> typon::Task<decltype(gen({', '.join(names)}))>"
yield "{"
yield f"co_return std::move(gen({', '.join(names)}));"
yield "}"
yield f"}} {node.name};"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
from transpiler.visitors.function import FunctionVisitor
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = ReturnVisitor().match(node.body)
if CoroutineMode.SYNC in generator:
if has_return:
yield "auto"
else:
yield "void"
yield "sync"
elif CoroutineMode.GENERATOR in generator:
yield "auto gen"
else:
yield "auto operator()"
yield args
if CoroutineMode.ASYNC in generator:
yield "-> typon::"
if CoroutineMode.TASK in generator:
yield "Task"
elif CoroutineMode.GENERATOR in generator:
yield "Generator"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
# scope if outside a function.
# This is different from C++, where scope is tied to any code block. To emulate this behavior, we need to
# declare all variables in the first inner scope of a function.
# For example,
# ```py
# def f():
# if True:
# x = 1
# print(x)
# ```
# is translated to
# ```cpp
# auto f() {
# decltype(1) x;
# if (true) {
# x = 1;
# }
# print(x);
# }
# ```
# `decltype` allows for proper typing (`auto` can't be used for variables with values later assigned, since
# this would require real type inference, akin to what Rust does).
# This is only done, though, for *nested* blocks of a function. Root-level variables are declared with
# `auto`:
# ```py
# x = 1
# def f():
# y = 2
# ```
# is translated to
# ```cpp
# auto x = 1;
# auto f() {
# auto y = 2;
# }
# ```
child_visitor = FunctionVisitor(inner_scope.child(), generator)
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough.
[*child_code] = child_visitor.visit(child)
# Hoist inner variables to the root scope.
for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
yield f"decltype({decl.val}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.TASK in generator:
if not has_return:
yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(self.expr().visit, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1:
raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.value)
yield " = "
yield from self.expr().visit(node.value)
yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
if node.value is None:
raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target, node.value)
yield " = "
yield from self.expr().visit(node.value)
yield ";"
# coding: utf-8
import ast
from dataclasses import dataclass, field
from typing import List, Iterable
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.scope import VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, join
class PrecedenceContext:
def __init__(self, visitor: "ExpressionVisitor", op: str):
self.visitor = visitor
self.op = op
def __enter__(self):
self.visitor.precedence.append(self.op)
def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop()
# noinspection PyPep8Naming
@dataclass
class ExpressionVisitor(NodeVisitor):
scope: Scope
generator: CoroutineMode
precedence: List = field(default_factory=list)
def visit(self, node):
if type(node) in SYMBOLS:
yield SYMBOLS[type(node)]
else:
yield from NodeVisitor.visit(self, node)
def prec_ctx(self, op: str) -> PrecedenceContext:
"""
Creates a context manager that sets the precedence of the next expression.
"""
return PrecedenceContext(self, op)
def prec(self, op: str) -> "ExpressionVisitor":
"""
Sets the precedence of the next expression.
"""
return ExpressionVisitor(self.scope, self.generator, [op])
def reset(self) -> "ExpressionVisitor":
"""
Resets the precedence stack.
"""
return ExpressionVisitor(self.scope, self.generator)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple("
yield from join(", ", map(self.visit, node.elts))
yield ")"
def visit_Constant(self, node: ast.Constant) -> Iterable[str]:
if isinstance(node.value, str):
# TODO: escape sequences
yield f"\"{repr(node.value)[1:-1]}\"s"
elif isinstance(node.value, bool):
yield str(node.value).lower()
elif isinstance(node.value, int):
# TODO: bigints
yield str(node.value)
elif isinstance(node.value, complex):
yield f"PyComplex({node.value.real}, {node.value.imag})"
elif node.value is None:
yield "PyNone"
else:
raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id)
if (decl := self.scope.get(res)) and decl.kind == VarKind.SELF:
res = "(*this)"
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
operands = [node.left, *node.comparators]
with self.prec_ctx("&&"):
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1])
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code
yield " && "
yield from self.visit_binary_operation(op, left, right)
def visit_Call(self, node: ast.Call) -> Iterable[str]:
if getattr(node, "keywords", None):
raise NotImplementedError(node, "keywords")
if getattr(node, "starargs", None):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator:
yield "co_await "
elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]"
templ, args, _ = self.process_args(node.args)
yield templ
yield args
yield "{"
yield "return"
yield from self.reset().visit(node.body)
yield ";"
yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
yield from self.visit_binary_operation(node.op, node.left, node.right)
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]:
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
if prio:
yield "("
with self.prec_ctx(op):
yield from self.visit(left)
yield op
yield from self.visit(right)
if prio:
yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.prec(".").visit(node.value)
yield "."
yield node.attr
def visit_List(self, node: ast.List) -> Iterable[str]:
yield "PyList{"
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
def visit_Set(self, node: ast.Set) -> Iterable[str]:
yield "PySet{"
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value):
yield "std::pair {"
yield from self.reset().visit(key)
yield ", "
yield from self.reset().visit(value)
yield "}"
yield "PyDict{"
yield from join(", ", map(visit_item, node.keys, node.values))
yield "}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
yield from self.prec("[]").visit(node.value)
yield "["
yield from self.reset().visit(node.slice)
yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit(node.op)
yield from self.prec("unary").visit(node.operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"):
yield from self.visit(node.test)
yield " ? "
yield from self.visit(node.body)
yield " : "
yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if CoroutineMode.GENERATOR in self.generator:
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
elif CoroutineMode.FAKE in self.generator:
yield "return"
yield from self.visit(node.value)
else:
raise NotImplementedError(node)
# coding: utf-8
import ast
from typing import Iterable
from transpiler.visitors.block import BlockVisitor
from transpiler.visitors.module import ModuleVisitor
# noinspection PyPep8Naming
class FileVisitor(BlockVisitor):
def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST
yield "#include <python/builtins.hpp>"
visitor = ModuleVisitor(self.scope)
for stmt in node.body:
yield from visitor.visit(stmt)
# coding: utf-8
import ast
from dataclasses import dataclass
from typing import Iterable
from transpiler.consts import SYMBOLS
from transpiler.scope import VarDecl, VarKind
from transpiler.visitors import CoroutineMode
from transpiler.visitors.block import BlockVisitor
# noinspection PyPep8Naming
@dataclass
class FunctionVisitor(BlockVisitor):
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from self.expr().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "="
yield from self.expr().visit(node.value)
yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
def visit_If(self, node: ast.If) -> Iterable[str]:
yield "if ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
yield "else "
if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse)
else:
yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
if CoroutineMode.ASYNC in self.generator:
yield "co_return "
else:
yield "return "
if node.value:
yield from self.expr().visit(node.value)
yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]:
yield "while ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self.scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self.scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield ""
def block(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), self.generator)
def emit_block(self, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{"
block = self.block()
for child in items:
yield from block.visit(child)
yield "}"
# coding: utf-8
import ast
from typing import Iterable
from transpiler.visitors import CoroutineMode, compare_ast
from transpiler.visitors.block import BlockVisitor
from transpiler.visitors.function import FunctionVisitor
# noinspection PyPep8Naming
class ModuleVisitor(BlockVisitor):
def visit_Import(self, node: ast.Import) -> Iterable[str]:
for alias in node.names:
if alias.name == "typon":
yield ""
else:
yield from self.import_module(alias.name)
yield f'auto& {alias.asname or alias.name} = py_{alias.name}::all;'
def import_module(self, name: str) -> Iterable[str]:
yield f'#include "python/{name}.hpp"'
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon":
yield ""
else:
yield from self.import_module(node.module)
for alias in node.names:
yield f"auto& {alias.asname or alias.name} = py_{node.module}::all.{alias.name};"
def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "typon::Root root()"
def block():
yield from node.body
yield ast.Return()
yield from FunctionVisitor(self.scope.function(), CoroutineMode.TASK).emit_block(block())
yield "int main() { root().call(); }"
return
raise NotImplementedError(node, "global scope if")
# coding: utf-8
import ast
from transpiler.visitors import NodeVisitor
class SearchVisitor(NodeVisitor):
def missing_impl(self, node):
if not hasattr(node, "__dict__"):
return
for val in node.__dict__.values():
if isinstance(val, list):
for item in val:
yield from self.visit(item)
elif isinstance(val, ast.AST):
yield from self.visit(val)
def match(self, node) -> bool:
return next(self.visit(node), False)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment