Commit 85f54768 authored by Tom Niget's avatar Tom Niget

Make first things work

parent 54759ed9
Subproject commit 285703a35af4ca1476ce30e2404b73af9d880ec5 Subproject commit 79677d125f915f7c61492d8d1d8cde9fc6a11875
...@@ -79,71 +79,79 @@ class ExpressionVisitor(NodeVisitor): ...@@ -79,71 +79,79 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right)) # yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]: def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1: raise NotImplementedError()
yield from self.visit(node.values[0])
return # if len(node.values) == 1:
cpp_op = { # yield from self.visit(node.values[0])
ast.And: "&&", # return
ast.Or: "||" # cpp_op = {
}[type(node.op)] # ast.And: "&&",
with self.prec_ctx(cpp_op): # ast.Or: "||"
yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1])) # }[type(node.op)]
for left, right in zip(node.values[1:], node.values[2:]): # with self.prec_ctx(cpp_op):
yield f" {cpp_op} " # yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right)) # for left, right in zip(node.values[1:], node.values[2:]):
# yield f" {cpp_op} "
# yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
yield "("
yield from self.visit(node.func)
yield ")("
yield from join(", ", map(self.visit, node.args))
yield ")"
#raise NotImplementedError()
# TODO # TODO
# if getattr(node, "keywords", None): # if getattr(node, "keywords", None):
# raise NotImplementedError(node, "keywords") # raise NotImplementedError(node, "keywords")
if getattr(node, "starargs", None): # if getattr(node, "starargs", None):
raise NotImplementedError(node, "varargs") # raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None): # if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") # raise NotImplementedError(node, "kwargs")
func = node.func # func = node.func
if isinstance(func, ast.Attribute): # if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None): # if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1: # if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node)) # yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else: # else:
yield from self.visit_unary_operation(sym, func.value) # yield from self.visit_unary_operation(sym, func.value)
return # return
for name in ("fork", "future"): # for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body): # if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1 # assert len(node.args) == 1
arg = node.args[0] # arg = node.args[0]
assert isinstance(arg, ast.Lambda) # assert isinstance(arg, ast.Lambda)
node.is_future = name # node.is_future = name
vis = self.reset() # vis = self.reset()
vis.generator = CoroutineMode.SYNC # vis.generator = CoroutineMode.SYNC
# todo: bad code # # todo: bad code
if CoroutineMode.ASYNC in self.generator: # if CoroutineMode.ASYNC in self.generator:
yield f"co_await typon::{name}(" # yield f"co_await typon::{name}("
yield from vis.visit(arg.body) # yield from vis.visit(arg.body)
yield ")" # yield ")"
return # return
elif CoroutineMode.FAKE in self.generator: # elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body) # yield from self.visit(arg.body)
return # return
if compare_ast(func, ast.parse('sync', mode="eval").body): # if compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator: # if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()" # yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator: # elif CoroutineMode.FAKE in self.generator:
yield from () # yield from ()
return # return
# TODO: precedence needed? # # TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator and node.is_await: # if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "(" # TODO: temporary # yield "(" # TODO: temporary
yield "co_await " # yield "co_await "
node.in_await = True # node.in_await = True
elif CoroutineMode.FAKE in self.generator: # elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load()) # func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func) # yield from self.prec("()").visit(func)
yield "(" # yield "("
yield from join(", ", map(self.reset().visit, node.args)) # yield from join(", ", map(self.reset().visit, node.args))
yield ")" # yield ")"
if CoroutineMode.ASYNC in self.generator and node.is_await: # if CoroutineMode.ASYNC in self.generator and node.is_await:
yield ")" # yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]: def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]" yield "[]"
...@@ -157,12 +165,15 @@ class ExpressionVisitor(NodeVisitor): ...@@ -157,12 +165,15 @@ class ExpressionVisitor(NodeVisitor):
yield "}" yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]: def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node)) yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node)) yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
raise NotImplementedError()
# if type(op) == ast.In: # if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd) # call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
# call.is_await = False # call.is_await = False
......
...@@ -9,7 +9,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType ...@@ -9,7 +9,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]: def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield f"struct : function {{" yield f"struct : referencemodel::function {{"
yield "typon::Task<void> operator()(" yield "typon::Task<void> operator()("
for arg, ty in zip(func.block_data.node.args.args, func.parameters): for arg, ty in zip(func.block_data.node.args.args, func.parameters):
...@@ -17,6 +17,8 @@ def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]: ...@@ -17,6 +17,8 @@ def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield arg yield arg
yield ") const {" yield ") const {"
yield from BlockVisitor(func.block_data.scope, generator=CoroutineMode.TASK).visit(func.block_data.node.body)
yield "co_return {};"
yield "}" yield "}"
yield f"}} static constexpr {name} {{}};" yield f"}} static constexpr {name} {{}};"
yield f"static_assert(sizeof {name} == 1);" yield f"static_assert(sizeof {name} == 1);"
...@@ -34,6 +36,10 @@ class BlockVisitor(NodeVisitor): ...@@ -34,6 +36,10 @@ class BlockVisitor(NodeVisitor):
def visit_Pass(self, node: ast.Pass) -> Iterable[str]: def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";" yield ";"
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from self.expr().visit(node.value)
yield ";"
# def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: # def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield from self.visit_free_func(node) # yield from self.visit_free_func(node)
...@@ -66,111 +72,111 @@ class BlockVisitor(NodeVisitor): ...@@ -66,111 +72,111 @@ class BlockVisitor(NodeVisitor):
# if emission == FunctionEmissionKind.DECLARATION: # if emission == FunctionEmissionKind.DECLARATION:
# yield f"}} {node.name};" # yield f"}} {node.name};"
def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]: # def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
for child in body: # for child in body:
from transpiler.phases.emit_cpp.function import FunctionVisitor # from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator=mode) # child_visitor = FunctionVisitor(inner_scope, generator=mode)
#
for name, decl in getattr(child, "decls", {}).items(): # for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};" # #yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type) # yield from self.visit(decl.type)
yield f" {name};" # yield f" {name};"
yield from child_visitor.visit(child) # yield from child_visitor.visit(child)
#
def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]: # def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]:
for i, (arg, argty, default) in enumerate(args): # for i, (arg, argty, default) in enumerate(args):
if i != 0: # if i != 0:
yield ", " # yield ", "
if emission == FunctionEmissionKind.METHOD and i == 0: # if emission == FunctionEmissionKind.METHOD and i == 0:
yield "Self" # yield "Self"
else: # else:
yield from self.visit(argty) # yield from self.visit(argty)
yield arg # yield arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default: # if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = " # yield " = "
yield from self.expr().visit(default) # yield from self.expr().visit(default)
#
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]: # def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.LAMBDA: # if emission == FunctionEmissionKind.LAMBDA:
yield "[&]" # yield "[&]"
else: # else:
if emission == FunctionEmissionKind.METHOD: # if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>" # yield "template <typename Self>"
yield from self.visit(node.type.return_type) # yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION: # if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::" # yield f"{node.name}_inner::"
yield "operator()" # yield "operator()"
yield "(" # yield "("
padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults # padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults
args_iter = zip(node.args.args, node.type.parameters, padded_defaults) # args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
if skip_first_arg: # if skip_first_arg:
next(args_iter) # next(args_iter)
yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission) # yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission)
yield ")" # yield ")"
#
if emission == FunctionEmissionKind.METHOD: # if emission == FunctionEmissionKind.METHOD:
yield "const" # yield "const"
#
inner_scope = node.inner_scope # inner_scope = node.inner_scope
#
if emission == FunctionEmissionKind.DECLARATION: # if emission == FunctionEmissionKind.DECLARATION:
yield ";" # yield ";"
return # return
#
if emission == FunctionEmissionKind.LAMBDA: # if emission == FunctionEmissionKind.LAMBDA:
yield "->" # yield "->"
yield from self.visit(node.type.return_type) # yield from self.visit(node.type.return_type)
#
yield "{" # yield "{"
#
class ReturnVisitor(SearchVisitor): # class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool: # def visit_Return(self, node: ast.Return) -> bool:
yield True # yield True
#
def visit_Yield(self, node: ast.Yield) -> bool: # def visit_Yield(self, node: ast.Yield) -> bool:
yield True # yield True
#
def visit_FunctionDef(self, node: ast.FunctionDef): # def visit_FunctionDef(self, node: ast.FunctionDef):
yield from () # yield from ()
#
def visit_ClassDef(self, node: ast.ClassDef): # def visit_ClassDef(self, node: ast.ClassDef):
yield from () # yield from ()
#
has_return = ReturnVisitor().match(node.body) # has_return = ReturnVisitor().match(node.body)
#
yield from self.visit_func_decls(node.body, inner_scope) # yield from self.visit_func_decls(node.body, inner_scope)
#
# if not has_return and isinstance(node.type.return_type, Promise): # # if not has_return and isinstance(node.type.return_type, Promise):
# yield "co_return;" # # yield "co_return;"
#
yield "}" # yield "}"
#
def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]: # def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple): # if isinstance(lvalue, ast.Tuple):
for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args): # for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
if decl: # if decl:
yield from self.visit_lvalue(name, True) # yield from self.visit_lvalue(name, True)
yield ";" # yield ";"
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})" # yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
elif isinstance(lvalue, ast.Name): # elif isinstance(lvalue, ast.Name):
if lvalue.id == "_": # if lvalue.id == "_":
if not declare: # if not declare:
yield "std::ignore" # yield "std::ignore"
return # return
name = self.fix_name(lvalue.id) # name = self.fix_name(lvalue.id)
# if name not in self._scope.vars: # # if name not in self._scope.vars:
# if not self.scope.exists_local(name): # # if not self.scope.exists_local(name):
# yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None, # # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# getattr(val, "is_future", False)) # # getattr(val, "is_future", False))
if declare: # if declare:
yield from self.visit(lvalue.type) # yield from self.visit(lvalue.type)
yield name # yield name
elif isinstance(lvalue, ast.Subscript): # elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue) # yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute): # elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue) # yield from self.expr().visit(lvalue)
else: # else:
raise NotImplementedError(lvalue) # raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
......
...@@ -6,7 +6,7 @@ from typing import Iterable ...@@ -6,7 +6,7 @@ from typing import Iterable
import transpiler.phases.typing.types as types import transpiler.phases.typing.types as types
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType from transpiler.phases.typing.types import BaseType
from transpiler.utils import UnsupportedNodeError from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor: class UniversalVisitor:
...@@ -48,7 +48,8 @@ class NodeVisitor(UniversalVisitor): ...@@ -48,7 +48,8 @@ class NodeVisitor(UniversalVisitor):
def fix_name(self, name: str) -> str: def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"): if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}" return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name) return name
#return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]: def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve() node = node.resolve()
......
...@@ -54,7 +54,7 @@ def transpile(source, name: str, path: Path): ...@@ -54,7 +54,7 @@ def transpile(source, name: str, path: Path):
# yield from code # yield from code
# yield "}" # yield "}"
yield "#else" yield "#else"
yield "typon::Root root() const {" yield "typon::Root root() {"
yield f"co_await dot(PROGRAMNS::{module.name()}, main)();" yield f"co_await dot(PROGRAMNS::{module.name()}, main)();"
yield "}" yield "}"
yield "int main(int argc, char* argv[]) {" yield "int main(int argc, char* argv[]) {"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment