Commit ae6126f2 authored by Tom Niget's avatar Tom Niget

Fix desugaring and internal handling of bin/unops

parent 2cb61dd6
...@@ -11,6 +11,7 @@ import traceback ...@@ -11,6 +11,7 @@ import traceback
import colorama import colorama
from transpiler.phases.desugar_compare import DesugarCompare from transpiler.phases.desugar_compare import DesugarCompare
from transpiler.phases.desugar_op import DesugarOp
colorama.init() colorama.init()
...@@ -177,6 +178,7 @@ def transpile(source, name="<module>", path=None): ...@@ -177,6 +178,7 @@ def transpile(source, name="<module>", path=None):
IfMainVisitor().visit(res) IfMainVisitor().visit(res)
res = DesugarWith().visit(res) res = DesugarWith().visit(res)
res = DesugarCompare().visit(res) res = DesugarCompare().visit(res)
res = DesugarOp().visit(res)
ScoperBlockVisitor().visit(res) ScoperBlockVisitor().visit(res)
# print(res.scope) # print(res.scope)
......
...@@ -28,6 +28,28 @@ SYMBOLS = { ...@@ -28,6 +28,28 @@ SYMBOLS = {
} }
"""Mapping of Python AST nodes to C++ symbols.""" """Mapping of Python AST nodes to C++ symbols."""
DUNDER_SYMBOLS = {
"__eq__": "==",
"__ne__": "!=",
"__lt__": "<",
"__gt__": ">",
"__ge__": ">=",
"__le__": "<=",
"__add__": "+",
"__sub__": "-",
"__mul__": "*",
"__div__": "/",
"__mod__": "%",
"__lshift__": "<<",
"__rshift__": ">>",
"__xor__": "^",
"__or__": "|",
"__and__": "&",
"__invert__": "~",
"__neg__": "-",
"__pos__": "+",
}
PRECEDENCE = [ PRECEDENCE = [
("()", "[]", ".",), ("()", "[]", ".",),
("unary", "co_await"), ("unary", "co_await"),
......
# coding: utf-8 # coding: utf-8
import ast import ast
from transpiler.phases.typing.expr import DUNDER
from transpiler.phases.utils import make_lnd from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata from transpiler.utils import linenodata
DUNDER = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Lt: "lt",
ast.Gt: "gt",
ast.GtE: "ge",
ast.LtE: "le",
ast.In: "contains",
ast.NotIn: "contains",
}
class DesugarCompare(ast.NodeTransformer): class DesugarCompare(ast.NodeTransformer):
def visit_Compare(self, node: ast.Compare): def visit_Compare(self, node: ast.Compare):
res = ast.BoolOp(ast.And(), [], **linenodata(node)) res = ast.BoolOp(ast.And(), [], **linenodata(node))
for left, op, right in zip([node.left] + node.comparators, node.ops, node.comparators): operands = list(map(self.visit, [node.left, *node.comparators]))
for left, op, right in zip(operands, node.ops, operands[1:]):
lnd = make_lnd(left, right) lnd = make_lnd(left, right)
if type(op) in (ast.In, ast.NotIn): if type(op) in (ast.In, ast.NotIn):
left, right = right, left left, right = right, left
...@@ -25,3 +35,21 @@ class DesugarCompare(ast.NodeTransformer): ...@@ -25,3 +35,21 @@ class DesugarCompare(ast.NodeTransformer):
if len(res.values) == 1: if len(res.values) == 1:
return res.values[0] return res.values[0]
return res return res
# def visit_Compare(self, node: ast.Compare):
# res = ast.BoolOp(ast.And(), [], **linenodata(node))
# operands = list(map(self.visit, [node.left, *node.comparators]))
# for left, op, right in zip(operands, node.ops, operands[1:]):
# lnd = make_lnd(left, right)
# call = ast.Compare(
# left,
# [op],
# [right],
# **lnd
# )
# if type(op) == ast.NotIn:
# call = ast.UnaryOp(ast.Not(), call, **lnd)
# res.values.append(call)
# if len(res.values) == 1:
# return res.values[0]
# return res
# coding: utf-8
import ast
from transpiler.utils import linenodata
DUNDER = {
ast.Mult: "mul",
ast.Add: "add",
ast.Sub: "sub",
ast.Div: "truediv",
ast.FloorDiv: "floordiv",
ast.Mod: "mod",
ast.LShift: "lshift",
ast.RShift: "rshift",
ast.BitXor: "xor",
ast.BitOr: "or",
ast.BitAnd: "and",
ast.USub: "neg",
ast.UAdd: "pos",
ast.Invert: "invert",
}
class DesugarOp(ast.NodeTransformer):
def visit_BinOp(self, node: ast.BinOp):
lnd = linenodata(node)
return ast.Call(
func=ast.Attribute(
value=self.visit(node.left),
attr=f"__{DUNDER[type(node.op)]}__",
ctx=ast.Load(),
**lnd
),
args=[self.visit(node.right)],
keywords={},
**lnd
)
def visit_UnaryOp(self, node: ast.UnaryOp):
lnd = linenodata(node)
if type(node.op) == ast.Not:
return ast.UnaryOp(
operand=self.visit(node.operand),
op=node.op,
**lnd
)
return ast.Call(
func=ast.Attribute(
value=self.visit(node.operand),
attr=f"__{DUNDER[type(node.op)]}__",
ctx=ast.Load(),
**lnd
),
args=[],
keywords={},
**lnd
)
# def visit_AugAssign(self, node: ast.AugAssign):
# return
...@@ -31,4 +31,7 @@ def process(items: list[ast.withitem], body: list[ast.stmt]) -> PlainBlock: ...@@ -31,4 +31,7 @@ def process(items: list[ast.withitem], body: list[ast.stmt]) -> PlainBlock:
class DesugarWith(ast.NodeTransformer): class DesugarWith(ast.NodeTransformer):
def visit_With(self, node: ast.With): def visit_With(self, node: ast.With):
return process(node.items, node.body) return process(
list(map(self.visit, node.items)),
list(map(self.visit, node.body))
)
...@@ -6,7 +6,7 @@ from typing import List, Iterable ...@@ -6,7 +6,7 @@ from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType from transpiler.phases.typing.types import UserType, FunctionType
from transpiler.phases.utils import make_lnd from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata from transpiler.utils import compare_ast, linenodata
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
from transpiler.phases.typing.scope import Scope, VarKind from transpiler.phases.typing.scope import Scope, VarKind
...@@ -115,10 +115,10 @@ class ExpressionVisitor(NodeVisitor): ...@@ -115,10 +115,10 @@ class ExpressionVisitor(NodeVisitor):
ast.Or: "||" ast.Or: "||"
}[type(node.op)] }[type(node.op)]
with self.prec_ctx(cpp_op): with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(node.op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1])) yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]): for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} " yield f" {cpp_op} "
yield from self.visit_binary_operation(node.op, left, right, make_lnd(left, right)) yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO # TODO
...@@ -129,6 +129,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -129,6 +129,13 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None): if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") raise NotImplementedError(node, "kwargs")
func = node.func func = node.func
if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else:
yield from self.visit_unary_operation(sym, func.value)
return
for name in ("fork", "future"): for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body): if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1 assert len(node.args) == 1
...@@ -180,14 +187,18 @@ class ExpressionVisitor(NodeVisitor): ...@@ -180,14 +187,18 @@ class ExpressionVisitor(NodeVisitor):
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]: def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node)) yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
if type(op) == ast.In: # if type(op) == ast.In:
call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd) # call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
call.is_await = False # call.is_await = False
yield from self.visit_Call(call) # yield from self.visit_Call(call)
print(call.func.type) # print(call.func.type)
return # return
op = SYMBOLS[type(op)] if type(op) != str:
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it # TODO: handle precedence locally since only binops really need it
# we could just store the history of traversed nodes and check if the last one was a binop # we could just store the history of traversed nodes and check if the last one was a binop
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op] prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
...@@ -206,9 +217,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -206,9 +217,9 @@ class ExpressionVisitor(NodeVisitor):
yield "dotp" yield "dotp"
else: else:
yield "dot" yield "dot"
yield "(" yield "(("
yield from self.visit(node.value) yield from self.visit(node.value)
yield ", " yield "), "
yield self.fix_name(node.attr) yield self.fix_name(node.attr)
yield ")" yield ")"
else: else:
...@@ -261,8 +272,11 @@ class ExpressionVisitor(NodeVisitor): ...@@ -261,8 +272,11 @@ class ExpressionVisitor(NodeVisitor):
yield "]" yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]: def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit(node.op) yield from self.visit_unary_operation(node.op, node.operand)
yield from self.prec("unary").visit(node.operand)
def visit_unary_operation(self, op, operand) -> Iterable[str]:
yield op
yield from self.prec("unary").visit(operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]: def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"): with self.prec_ctx("?:"):
......
...@@ -3,10 +3,11 @@ import dataclasses ...@@ -3,10 +3,11 @@ import dataclasses
import importlib import importlib
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
...@@ -249,11 +250,16 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -249,11 +250,16 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None) self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_AugAssign(self, node: ast.AugAssign): def visit_AugAssign(self, node: ast.AugAssign):
equivalent = ast.Assign( target, value = map(self.get_type, (node.target, node.value))
targets=[node.target], try:
value=ast.BinOp(left=node.target, op=node.op, right=node.value, **linenodata(node)), self.expr().make_dunder([target, value], "i" + DUNDER[type(node.op)])
**linenodata(node)) except CompileError as e:
self.visit(equivalent) self.visit_assign_target(node.target, self.expr().make_dunder([target, value], DUNDER[type(node.op)]))
# equivalent = ast.Assign(
# targets=[node.target],
# value=ast.BinOp(left=node.target, op=node.op, right=node.value, **linenodata(node)),
# **linenodata(node))
# self.visit(equivalent)
def visit(self, node: ast.AST): def visit(self, node: ast.AST):
if isinstance(node, ast.AST): if isinstance(node, ast.AST):
......
...@@ -149,9 +149,16 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -149,9 +149,16 @@ class ScoperExprVisitor(ScoperVisitor):
node.body.decls = decls node.body.decls = decls
return ftype return ftype
def visit_BinOp(self, node: ast.BinOp) -> BaseType: # def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right)) # left, right = map(self.visit, (node.left, node.right))
return self.make_dunder([left, right], DUNDER[type(node.op)]) # return self.make_dunder([left, right], DUNDER[type(node.op)])
# def visit_Compare(self, node: ast.Compare) -> BaseType:
# left, right = map(self.visit, (node.left, node.comparators[0]))
# op = node.ops[0]
# if type(op) == ast.In:
# left, right = right, left
# return self.make_dunder([left, right], DUNDER[type(op)])
def visit_Attribute(self, node: ast.Attribute) -> BaseType: def visit_Attribute(self, node: ast.Attribute) -> BaseType:
ltype = self.visit(node.value) ltype = self.visit(node.value)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment