Commit 13764aba authored by Tom Niget's avatar Tom Niget

Add proper implementation of Compare nodes

parent 6374abcf
# coding: utf-8
import ast
from transpiler.phases.typing.expr import DUNDER
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
class DesugarCompare(ast.NodeTransformer):
def visit_Compare(self, node: ast.Compare):
res = ast.BoolOp(ast.And(), [], **linenodata(node))
for left, op, right in zip([node.left] + node.comparators, node.ops, node.comparators):
lnd = make_lnd(left, right)
if type(op) in (ast.In, ast.NotIn):
left, right = right, left
call = ast.Call(
ast.Attribute(left, f"__{DUNDER[type(op)]}__", **lnd),
[right],
[],
**lnd
)
if type(op) == ast.NotIn:
call = ast.UnaryOp(ast.Not(), call, **lnd)
res.values.append(call)
if len(res.values) == 1:
return res.values[0]
return res
...@@ -4,6 +4,7 @@ from dataclasses import dataclass, field ...@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import List, Iterable from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType from transpiler.phases.typing.types import UserType, FunctionType
from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata from transpiler.utils import compare_ast, linenodata
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
...@@ -91,22 +92,33 @@ class ExpressionVisitor(NodeVisitor): ...@@ -91,22 +92,33 @@ class ExpressionVisitor(NodeVisitor):
# res = "co_await " + res # res = "co_await " + res
yield res yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: # def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
def make_lnd(op1, op2): # def make_lnd(op1, op2):
return { # return {
"lineno": op1.lineno, # "lineno": op1.lineno,
"col_offset": op1.col_offset, # "col_offset": op1.col_offset,
"end_lineno": op2.end_lineno, # "end_lineno": op2.end_lineno,
"end_col_offset": op2.end_col_offset # "end_col_offset": op2.end_col_offset
} # }
#
# operands = [node.left, *node.comparators]
# with self.prec_ctx("&&"):
# yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
# for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# # TODO: cleaner code
# yield " && "
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
operands = [node.left, *node.comparators] def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
with self.prec_ctx("&&"): cpp_op = {
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1])) ast.And: "&&",
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]): ast.Or: "||"
# TODO: cleaner code }[type(node.op)]
yield " && " with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(op, left, right, make_lnd(left, right)) yield from self.visit_binary_operation(node.op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(node.op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO # TODO
...@@ -173,6 +185,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -173,6 +185,7 @@ class ExpressionVisitor(NodeVisitor):
call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd) call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
call.is_await = False call.is_await = False
yield from self.visit_Call(call) yield from self.visit_Call(call)
print(call.func.type)
return return
op = SYMBOLS[type(op)] op = SYMBOLS[type(op)]
# TODO: handle precedence locally since only binops really need it # TODO: handle precedence locally since only binops really need it
......
...@@ -114,10 +114,10 @@ class ArgumentCountMismatchError(CompileError): ...@@ -114,10 +114,10 @@ class ArgumentCountMismatchError(CompileError):
class ProtocolMismatchError(CompileError): class ProtocolMismatchError(CompileError):
value: BaseType value: BaseType
protocol: BaseType protocol: BaseType
reason: Exception reason: Exception | str
def __str__(self) -> str: def __str__(self) -> str:
return f"Protocol mismatch: {highlight(self.value)} does not implement {highlight(self.protocol)}" return f"Protocol mismatch: {str(self.value)} does not implement {str(self.protocol)}"
def detail(self, last_node: ast.AST = None) -> str: def detail(self, last_node: ast.AST = None) -> str:
return f""" return f"""
......
...@@ -8,6 +8,7 @@ from transpiler.phases.typing.common import ScoperVisitor ...@@ -8,6 +8,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \ TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE TY_SLICE
from transpiler.utils import linenodata
DUNDER = { DUNDER = {
ast.Eq: "eq", ast.Eq: "eq",
...@@ -30,6 +31,7 @@ DUNDER = { ...@@ -30,6 +31,7 @@ DUNDER = {
ast.USub: "neg", ast.USub: "neg",
ast.UAdd: "pos", ast.UAdd: "pos",
ast.Invert: "invert", ast.Invert: "invert",
ast.In: "contains",
} }
class ScoperExprVisitor(ScoperVisitor): class ScoperExprVisitor(ScoperVisitor):
...@@ -94,13 +96,10 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -94,13 +96,10 @@ class ScoperExprVisitor(ScoperVisitor):
obj.python_func_used = True obj.python_func_used = True
return obj.type return obj.type
def visit_Compare(self, node: ast.Compare) -> BaseType: def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
# todo: for value in node.values:
self.visit(node.left) self.visit(value)
for op, right in zip(node.ops, node.comparators):
self.visit(right)
return TY_BOOL return TY_BOOL
#raise NotImplementedError(node)
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func) ftype = self.visit(node.func)
......
...@@ -147,6 +147,7 @@ class TypeOperator(BaseType, ABC): ...@@ -147,6 +147,7 @@ class TypeOperator(BaseType, ABC):
cls.gen_parents = [] cls.gen_parents = []
def __post_init__(self): def __post_init__(self):
assert all(x is not None for x in self.args)
if self.name is None: if self.name is None:
self.name = self.__class__.__name__ self.name = self.__class__.__name__
for name, factory in self.gen_methods.items(): for name, factory in self.gen_methods.items():
...@@ -157,19 +158,30 @@ class TypeOperator(BaseType, ABC): ...@@ -157,19 +158,30 @@ class TypeOperator(BaseType, ABC):
self.parents.append(gp) self.parents.append(gp)
self.methods = {**gp.methods, **self.methods} self.methods = {**gp.methods, **self.methods}
self.is_protocol = self.is_protocol or self.is_protocol_gen self.is_protocol = self.is_protocol or self.is_protocol_gen
self._add_default_eq()
def _add_default_eq(self):
if "__eq__" not in self.methods:
if "DEFAULT_EQ" in globals():
self.methods["__eq__"] = DEFAULT_EQ
def matches_protocol(self, protocol: "TypeOperator"): def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache: if hash(protocol) in self.match_cache:
return return
from transpiler.phases.typing.exceptions import ProtocolMismatchError, TypeMismatchError
try: try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args}) dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol)) self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items(): for name, ty in dupl.methods.items():
if name == "__eq__":
continue
if name not in self.methods:
raise ProtocolMismatchError(self, protocol, f"missing method {name}")
corresp = self.methods[name] corresp = self.methods[name]
corresp.remove_self().unify(ty.remove_self()) corresp.remove_self().unify(ty.remove_self())
except Exception as e: except TypeMismatchError as e:
self.match_cache.remove(hash(protocol)) if hash(protocol) in self.match_cache:
from transpiler.phases.typing.exceptions import ProtocolMismatchError self.match_cache.remove(hash(protocol))
raise ProtocolMismatchError(self, protocol, e) raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
...@@ -331,18 +343,24 @@ class TypeType(TypeOperator): ...@@ -331,18 +343,24 @@ class TypeType(TypeOperator):
self.args[0] = value self.args[0] = value
TY_SELF = TypeOperator.make_type("Self")
def self_gen_sub(this, typevars, _):
assert this is not None
return this
TY_SELF.gen_sub = self_gen_sub
TY_BOOL = TypeOperator.make_type("bool")
DEFAULT_EQ = FunctionType([TY_SELF, TY_SELF], TY_BOOL)
TY_BOOL._add_default_eq()
TY_TYPE = TypeOperator.make_type("type") TY_TYPE = TypeOperator.make_type("type")
TY_INT = TypeOperator.make_type("int") TY_INT = TypeOperator.make_type("int")
TY_FLOAT = TypeOperator.make_type("float") TY_FLOAT = TypeOperator.make_type("float")
TY_STR = TypeOperator.make_type("str") TY_STR = TypeOperator.make_type("str")
TY_BYTES = TypeOperator.make_type("bytes") TY_BYTES = TypeOperator.make_type("bytes")
TY_BOOL = TypeOperator.make_type("bool")
TY_COMPLEX = TypeOperator.make_type("complex") TY_COMPLEX = TypeOperator.make_type("complex")
TY_NONE = TypeOperator.make_type("NoneType") TY_NONE = TypeOperator.make_type("NoneType")
#TY_MODULE = TypeOperator([], "module") #TY_MODULE = TypeOperator([], "module")
TY_VARARG = TypeOperator.make_type("vararg") TY_VARARG = TypeOperator.make_type("vararg")
TY_SELF = TypeOperator.make_type("Self")
TY_SELF.gen_sub = lambda this, typevars, _: this
TY_SLICE = TypeOperator.make_type("slice") TY_SLICE = TypeOperator.make_type("slice")
...@@ -460,4 +478,4 @@ class UserType(TypeOperator): ...@@ -460,4 +478,4 @@ class UserType(TypeOperator):
class UnionType(TypeOperator): class UnionType(TypeOperator):
def __init__(self, *args: List[BaseType]): def __init__(self, *args: List[BaseType]):
super().__init__(args, "Union") super().__init__(args, "Union")
self.parents.extend(args) self.parents.extend(args)
\ No newline at end of file
...@@ -43,4 +43,12 @@ class AnnotationName: ...@@ -43,4 +43,12 @@ class AnnotationName:
def id(self): def id(self):
return str(self.inner) return str(self.inner)
AnnotationName.__name__ = "Name" AnnotationName.__name__ = "Name"
\ No newline at end of file
def make_lnd(op1, op2):
return {
"lineno": op1.lineno,
"col_offset": op1.col_offset,
"end_lineno": op2.end_lineno,
"end_col_offset": op2.end_col_offset
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment