Commit 4b427fb6 authored by Tom Niget's avatar Tom Niget

Start work on exprs

parent 2f8bdf7b
......@@ -23,4 +23,4 @@
if __name__ == "__main__":
print(5)
\ No newline at end of file
print("abc")
\ No newline at end of file
# coding: utf-8
import ast
from dataclasses import dataclass, field
from typing import Iterable
from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join
from transpiler.phases.typing.scope import Scope
# noinspection PyPep8Naming
@dataclass
class ExpressionVisitor(NodeVisitor):
scope: Scope
generator: CoroutineMode
def visit(self, node):
if False and type(node) in SYMBOLS:
yield SYMBOLS[type(node)]
else:
yield from NodeVisitor.visit(self, node)
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield "std::make_tuple("
yield from join(", ", map(self.visit, node.elts))
yield ")"
def visit_Constant(self, node: ast.Constant) -> Iterable[str]:
if isinstance(node.value, str):
# TODO: escape sequences
yield f"\"{repr(node.value)[1:-1]}\"_ps"
elif isinstance(node.value, bool):
yield str(node.value).lower()
elif isinstance(node.value, int):
# TODO: bigints
yield str(node.value)
elif isinstance(node.value, float):
yield repr(node.value)
elif isinstance(node.value, complex):
yield f"PyComplex({node.value.real}, {node.value.imag})"
elif node.value is None:
yield "PyNone"
else:
raise NotImplementedError(node, type(node))
def visit_Slice(self, node: ast.Slice) -> Iterable[str]:
yield "PySlice("
yield from join(", ", (self.visit(x or ast.Constant(value=None)) for x in (node.lower, node.upper, node.step)))
yield ")"
def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id)
if self.scope.function and (decl := self.scope.get(res)) and decl.type is self.scope.function.obj_type:
if not self.scope.function.parent.function:
res = "(*this)"
#if decl.kind == VarKind.SELF:
# res = "(*this)"
#elif decl.future and CoroutineMode.ASYNC in self.generator:
# res = f"{res}.get()"
# if decl.future == "future":
# res = "co_await " + res
yield res
# def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# def make_lnd(op1, op2):
# return {
# "lineno": op1.lineno,
# "col_offset": op1.col_offset,
# "end_lineno": op2.end_lineno,
# "end_col_offset": op2.end_col_offset
# }
#
# operands = [node.left, *node.comparators]
# with self.prec_ctx("&&"):
# yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
# for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# # TODO: cleaner code
# yield " && "
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1:
yield from self.visit(node.values[0])
return
cpp_op = {
ast.And: "&&",
ast.Or: "||"
}[type(node.op)]
with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO
# if getattr(node, "keywords", None):
# raise NotImplementedError(node, "keywords")
if getattr(node, "starargs", None):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else:
yield from self.visit_unary_operation(sym, func.value)
return
for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = name
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield f"co_await typon::{name}("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
if compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
yield from ()
return
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "(" # TODO: temporary
yield "co_await "
node.in_await = True
elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]"
templ, args, _ = self.process_args(node.args)
yield templ
yield args
yield "{"
yield "return"
yield from self.reset().visit(node.body)
yield ";"
yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
# if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
# call.is_await = False
# yield from self.visit_Call(call)
# print(call.func.type)
# return
if type(op) != str:
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
if prio:
yield "("
with self.prec_ctx(op):
yield from self.visit(left)
yield op
yield from self.visit(right)
if prio:
yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield "dot"
yield "(("
yield from self.visit(node.value)
yield "), "
yield self.fix_name(node.attr)
yield ")"
def visit_List(self, node: ast.List) -> Iterable[str]:
if node.elts:
yield "typon::PyList{"
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
else:
yield from self.visit(node.type)
yield "{}"
def visit_Set(self, node: ast.Set) -> Iterable[str]:
if node.elts:
yield "typon::PySet{"
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
else:
yield from self.visit(node.type)
yield "{}"
def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value):
yield "std::pair {"
yield from self.reset().visit(key)
yield ", "
yield from self.reset().visit(value)
yield "}"
if node.keys:
yield from self.visit(node.type)
yield "{"
yield from join(", ", map(visit_item, node.keys, node.values))
yield "}"
else:
yield from self.visit(node.type)
yield "{}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
if isinstance(node.type, TypeType) and isinstance(node.type.type_object, MonomorphizedUserType):
yield node.type.type_object.name
return
yield from self.prec("[]").visit(node.value)
yield "["
yield from self.reset().visit(node.slice)
yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit_unary_operation(node.op, node.operand)
def visit_unary_operation(self, op, operand) -> Iterable[str]:
if type(op) != str:
op = SYMBOLS[type(op)]
yield op
yield from self.prec("unary").visit(operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"):
yield from self.visit(node.test)
yield " ? "
yield from self.visit(node.body)
yield " : "
yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
#if CoroutineMode.GENERATOR in self.generator:
# yield "co_yield"
# yield from self.prec("co_yield").visit(node.value)
#elif CoroutineMode.FAKE in self.generator:
# yield "return"
# yield from self.visit(node.value)
#else:
# raise NotImplementedError(node)
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
def visit_ListComp(self, node: ast.ListComp) -> Iterable[str]:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
yield "mapFilter([]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(node.elt)
yield "; }, "
yield from self.visit(gen.iter)
if gen.ifs:
yield ", "
yield "[]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(gen.ifs_node)
yield "; }"
yield ")"
# iter_type = get_iter(self.visit(gen.iter))
# next_type = get_next(iter_type)
# virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
# from transpiler import ScoperBlockVisitor
# visitor = ScoperBlockVisitor(virt_scope)
# visitor.visit_assign_target(gen.target, next_type)
# res_item_type = visitor.expr().visit(node.elt)
# for if_ in gen.ifs:
# visitor.expr().visit(if_)
# return PyList(res_item_type)
\ No newline at end of file
import ast
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Iterable, Optional
from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode
from transpiler.phases.typing.types import CallableInstanceType, BaseType
......@@ -23,167 +23,166 @@ def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
if False:
@dataclass
class BlockVisitor(NodeVisitor):
scope: Scope
#generator: CoroutineMode = field(default=CoroutineMode.SYNC, kw_only=True)
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator)
@dataclass
class BlockVisitor(NodeVisitor):
scope: Scope
generator: CoroutineMode = field(default=CoroutineMode.SYNC, kw_only=True)
def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator)
def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";"
# def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield from self.visit_free_func(node)
# def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]:
# if getattr(node, "is_main", False):
# if emission == FunctionEmissionKind.DECLARATION:
# return
# # Special case handling for Python's interesting way of defining an entry point.
# # I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# # language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# # read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# # declarations (functions, classes, modules, ...) and code.
# # Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# # 0 is returned by default.
# yield "typon::Root root() const"
#
# def block():
# yield from node.body
# yield ast.Return()
#
# from transpiler.phases.emit_cpp.function import FunctionVisitor
# yield "{"
# yield from self.visit_func_decls(block(), node.scope, CoroutineMode.TASK)
# yield "}"
# return
#
# if emission == FunctionEmissionKind.DECLARATION:
# yield f"struct {node.name}_inner {{"
# yield from self.visit_func_new(node, emission)
# if emission == FunctionEmissionKind.DECLARATION:
# yield f"}} {node.name};"
def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
for child in body:
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator=mode)
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]:
for i, (arg, argty, default) in enumerate(args):
if i != 0:
yield ", "
if emission == FunctionEmissionKind.METHOD and i == 0:
yield "Self"
else:
yield from self.visit(argty)
yield arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = "
yield from self.expr().visit(default)
def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.LAMBDA:
yield "[&]"
else:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
yield "("
padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults
args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
if skip_first_arg:
next(args_iter)
yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission)
yield ")"
if emission == FunctionEmissionKind.METHOD:
yield "const"
inner_scope = node.inner_scope
if emission == FunctionEmissionKind.DECLARATION:
yield ";"
return
# def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield from self.visit_free_func(node)
def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]:
if getattr(node, "is_main", False):
if emission == FunctionEmissionKind.DECLARATION:
return
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "typon::Root root() const"
def block():
yield from node.body
yield ast.Return()
from transpiler.phases.emit_cpp.function import FunctionVisitor
yield "{"
yield from self.visit_func_decls(block(), node.scope, CoroutineMode.TASK)
yield "}"
return
if emission == FunctionEmissionKind.LAMBDA:
yield "->"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DECLARATION:
yield f"struct {node.name}_inner {{"
yield from self.visit_func_new(node, emission)
if emission == FunctionEmissionKind.DECLARATION:
yield f"}} {node.name};"
def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
for child in body:
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator=mode)
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]:
for i, (arg, argty, default) in enumerate(args):
if i != 0:
yield ", "
if emission == FunctionEmissionKind.METHOD and i == 0:
yield "Self"
else:
yield from self.visit(argty)
yield arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = "
yield from self.expr().visit(default)
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.LAMBDA:
yield "[&]"
else:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
yield "("
padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults
args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
if skip_first_arg:
next(args_iter)
yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission)
yield ")"
yield "{"
if emission == FunctionEmissionKind.METHOD:
yield "const"
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
inner_scope = node.inner_scope
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
if emission == FunctionEmissionKind.DECLARATION:
yield ";"
return
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
if emission == FunctionEmissionKind.LAMBDA:
yield "->"
yield from self.visit(node.type.return_type)
yield "{"
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = ReturnVisitor().match(node.body)
yield from self.visit_func_decls(node.body, inner_scope)
# if not has_return and isinstance(node.type.return_type, Promise):
# yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
if decl:
yield from self.visit_lvalue(name, True)
yield ";"
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
if lvalue.id == "_":
if not declare:
yield "std::ignore"
return
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
# if not self.scope.exists_local(name):
# yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# getattr(val, "is_future", False))
if declare:
yield from self.visit(lvalue.type)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1:
raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.is_declare)
has_return = ReturnVisitor().match(node.body)
yield from self.visit_func_decls(node.body, inner_scope)
# if not has_return and isinstance(node.type.return_type, Promise):
# yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
if decl:
yield from self.visit_lvalue(name, True)
yield ";"
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
if lvalue.id == "_":
if not declare:
yield "std::ignore"
return
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
# if not self.scope.exists_local(name):
# yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# getattr(val, "is_future", False))
if declare:
yield from self.visit(lvalue.type)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1:
raise NotImplementedError(node)
yield from self.visit_lvalue(node.targets[0], node.is_declare)
yield " = "
yield from self.expr().visit(node.value)
yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target, node.is_declare)
if node.value:
yield " = "
yield from self.expr().visit(node.value)
yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target, node.is_declare)
if node.value:
yield " = "
yield from self.expr().visit(node.value)
yield ";"
yield ";"
import ast
from enum import Flag
from itertools import chain
from typing import Iterable
......@@ -89,3 +90,11 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def flatmap(f, items):
return chain.from_iterable(map(f, items))
class CoroutineMode(Flag):
SYNC = 1
FAKE = 2 | SYNC
ASYNC = 4
GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
JOIN = 32 | ASYNC
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment