Commit 13764aba authored by Tom Niget's avatar Tom Niget

Add proper implementation of Compare nodes

parent 6374abcf
# coding: utf-8
import ast
from transpiler.phases.typing.expr import DUNDER
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
class DesugarCompare(ast.NodeTransformer):
def visit_Compare(self, node: ast.Compare):
res = ast.BoolOp(ast.And(), [], **linenodata(node))
for left, op, right in zip([node.left] + node.comparators, node.ops, node.comparators):
lnd = make_lnd(left, right)
if type(op) in (ast.In, ast.NotIn):
left, right = right, left
call = ast.Call(
ast.Attribute(left, f"__{DUNDER[type(op)]}__", **lnd),
[right],
[],
**lnd
)
if type(op) == ast.NotIn:
call = ast.UnaryOp(ast.Not(), call, **lnd)
res.values.append(call)
if len(res.values) == 1:
return res.values[0]
return res
......@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType
from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
......@@ -91,22 +92,33 @@ class ExpressionVisitor(NodeVisitor):
# res = "co_await " + res
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
def make_lnd(op1, op2):
return {
"lineno": op1.lineno,
"col_offset": op1.col_offset,
"end_lineno": op2.end_lineno,
"end_col_offset": op2.end_col_offset
}
# def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# def make_lnd(op1, op2):
# return {
# "lineno": op1.lineno,
# "col_offset": op1.col_offset,
# "end_lineno": op2.end_lineno,
# "end_col_offset": op2.end_col_offset
# }
#
# operands = [node.left, *node.comparators]
# with self.prec_ctx("&&"):
# yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
# for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# # TODO: cleaner code
# yield " && "
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
operands = [node.left, *node.comparators]
with self.prec_ctx("&&"):
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code
yield " && "
yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
cpp_op = {
ast.And: "&&",
ast.Or: "||"
}[type(node.op)]
with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(node.op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(node.op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO
......@@ -173,6 +185,7 @@ class ExpressionVisitor(NodeVisitor):
call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
call.is_await = False
yield from self.visit_Call(call)
print(call.func.type)
return
op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it
......
......@@ -114,10 +114,10 @@ class ArgumentCountMismatchError(CompileError):
class ProtocolMismatchError(CompileError):
value: BaseType
protocol: BaseType
reason: Exception
reason: Exception | str
def __str__(self) -> str:
return f"Protocol mismatch: {highlight(self.value)} does not implement {highlight(self.protocol)}"
return f"Protocol mismatch: {str(self.value)} does not implement {str(self.protocol)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
......
......@@ -8,6 +8,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE
from transpiler.utils import linenodata
DUNDER = {
ast.Eq: "eq",
......@@ -30,6 +31,7 @@ DUNDER = {
ast.USub: "neg",
ast.UAdd: "pos",
ast.Invert: "invert",
ast.In: "contains",
}
class ScoperExprVisitor(ScoperVisitor):
......@@ -94,13 +96,10 @@ class ScoperExprVisitor(ScoperVisitor):
obj.python_func_used = True
return obj.type
def visit_Compare(self, node: ast.Compare) -> BaseType:
# todo:
self.visit(node.left)
for op, right in zip(node.ops, node.comparators):
self.visit(right)
def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
for value in node.values:
self.visit(value)
return TY_BOOL
#raise NotImplementedError(node)
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
......
......@@ -147,6 +147,7 @@ class TypeOperator(BaseType, ABC):
cls.gen_parents = []
def __post_init__(self):
assert all(x is not None for x in self.args)
if self.name is None:
self.name = self.__class__.__name__
for name, factory in self.gen_methods.items():
......@@ -157,19 +158,30 @@ class TypeOperator(BaseType, ABC):
self.parents.append(gp)
self.methods = {**gp.methods, **self.methods}
self.is_protocol = self.is_protocol or self.is_protocol_gen
self._add_default_eq()
def _add_default_eq(self):
if "__eq__" not in self.methods:
if "DEFAULT_EQ" in globals():
self.methods["__eq__"] = DEFAULT_EQ
def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache:
return
from transpiler.phases.typing.exceptions import ProtocolMismatchError, TypeMismatchError
try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items():
if name == "__eq__":
continue
if name not in self.methods:
raise ProtocolMismatchError(self, protocol, f"missing method {name}")
corresp = self.methods[name]
corresp.remove_self().unify(ty.remove_self())
except Exception as e:
self.match_cache.remove(hash(protocol))
from transpiler.phases.typing.exceptions import ProtocolMismatchError
except TypeMismatchError as e:
if hash(protocol) in self.match_cache:
self.match_cache.remove(hash(protocol))
raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType):
......@@ -331,18 +343,24 @@ class TypeType(TypeOperator):
self.args[0] = value
TY_SELF = TypeOperator.make_type("Self")
def self_gen_sub(this, typevars, _):
assert this is not None
return this
TY_SELF.gen_sub = self_gen_sub
TY_BOOL = TypeOperator.make_type("bool")
DEFAULT_EQ = FunctionType([TY_SELF, TY_SELF], TY_BOOL)
TY_BOOL._add_default_eq()
TY_TYPE = TypeOperator.make_type("type")
TY_INT = TypeOperator.make_type("int")
TY_FLOAT = TypeOperator.make_type("float")
TY_STR = TypeOperator.make_type("str")
TY_BYTES = TypeOperator.make_type("bytes")
TY_BOOL = TypeOperator.make_type("bool")
TY_COMPLEX = TypeOperator.make_type("complex")
TY_NONE = TypeOperator.make_type("NoneType")
#TY_MODULE = TypeOperator([], "module")
TY_VARARG = TypeOperator.make_type("vararg")
TY_SELF = TypeOperator.make_type("Self")
TY_SELF.gen_sub = lambda this, typevars, _: this
TY_SLICE = TypeOperator.make_type("slice")
......@@ -460,4 +478,4 @@ class UserType(TypeOperator):
class UnionType(TypeOperator):
def __init__(self, *args: List[BaseType]):
super().__init__(args, "Union")
self.parents.extend(args)
\ No newline at end of file
self.parents.extend(args)
......@@ -43,4 +43,12 @@ class AnnotationName:
def id(self):
return str(self.inner)
AnnotationName.__name__ = "Name"
\ No newline at end of file
AnnotationName.__name__ = "Name"
def make_lnd(op1, op2):
return {
"lineno": op1.lineno,
"col_offset": op1.col_offset,
"end_lineno": op2.end_lineno,
"end_col_offset": op2.end_col_offset
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment