Commit 85f54768 authored by Tom Niget's avatar Tom Niget

Make first things work

parent 54759ed9
Subproject commit 285703a35af4ca1476ce30e2404b73af9d880ec5
Subproject commit 79677d125f915f7c61492d8d1d8cde9fc6a11875
......@@ -79,71 +79,79 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1:
yield from self.visit(node.values[0])
return
cpp_op = {
ast.And: "&&",
ast.Or: "||"
}[type(node.op)]
with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
raise NotImplementedError()
# if len(node.values) == 1:
# yield from self.visit(node.values[0])
# return
# cpp_op = {
# ast.And: "&&",
# ast.Or: "||"
# }[type(node.op)]
# with self.prec_ctx(cpp_op):
# yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
# for left, right in zip(node.values[1:], node.values[2:]):
# yield f" {cpp_op} "
# yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]:
yield "("
yield from self.visit(node.func)
yield ")("
yield from join(", ", map(self.visit, node.args))
yield ")"
#raise NotImplementedError()
# TODO
# if getattr(node, "keywords", None):
# raise NotImplementedError(node, "keywords")
if getattr(node, "starargs", None):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else:
yield from self.visit_unary_operation(sym, func.value)
return
for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = name
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield f"co_await typon::{name}("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
if compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
yield from ()
return
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "(" # TODO: temporary
yield "co_await "
node.in_await = True
elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield ")"
# if getattr(node, "starargs", None):
# raise NotImplementedError(node, "varargs")
# if getattr(node, "kwargs", None):
# raise NotImplementedError(node, "kwargs")
# func = node.func
# if isinstance(func, ast.Attribute):
# if sym := DUNDER_SYMBOLS.get(func.attr, None):
# if len(node.args) == 1:
# yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
# else:
# yield from self.visit_unary_operation(sym, func.value)
# return
# for name in ("fork", "future"):
# if compare_ast(func, ast.parse(name, mode="eval").body):
# assert len(node.args) == 1
# arg = node.args[0]
# assert isinstance(arg, ast.Lambda)
# node.is_future = name
# vis = self.reset()
# vis.generator = CoroutineMode.SYNC
# # todo: bad code
# if CoroutineMode.ASYNC in self.generator:
# yield f"co_await typon::{name}("
# yield from vis.visit(arg.body)
# yield ")"
# return
# elif CoroutineMode.FAKE in self.generator:
# yield from self.visit(arg.body)
# return
# if compare_ast(func, ast.parse('sync', mode="eval").body):
# if CoroutineMode.ASYNC in self.generator:
# yield "co_await typon::Sync()"
# elif CoroutineMode.FAKE in self.generator:
# yield from ()
# return
# # TODO: precedence needed?
# if CoroutineMode.ASYNC in self.generator and node.is_await:
# yield "(" # TODO: temporary
# yield "co_await "
# node.in_await = True
# elif CoroutineMode.FAKE in self.generator:
# func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
# yield from self.prec("()").visit(func)
# yield "("
# yield from join(", ", map(self.reset().visit, node.args))
# yield ")"
# if CoroutineMode.ASYNC in self.generator and node.is_await:
# yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]"
......@@ -157,12 +165,15 @@ class ExpressionVisitor(NodeVisitor):
yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
raise NotImplementedError()
# if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
# call.is_await = False
......
......@@ -6,7 +6,7 @@ from typing import Iterable
import transpiler.phases.typing.types as types
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType
from transpiler.utils import UnsupportedNodeError
from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor:
......@@ -48,7 +48,8 @@ class NodeVisitor(UniversalVisitor):
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
return name
#return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
......
......@@ -54,7 +54,7 @@ def transpile(source, name: str, path: Path):
# yield from code
# yield "}"
yield "#else"
yield "typon::Root root() const {"
yield "typon::Root root() {"
yield f"co_await dot(PROGRAMNS::{module.name()}, main)();"
yield "}"
yield "int main(int argc, char* argv[]) {"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment