Commit 566f4bb0 authored by Tom Niget's avatar Tom Niget

Fix optional parameter unification

parent ae470644
......@@ -93,3 +93,12 @@ class file:
def close(self) -> Task[None]: ...
def open(filename: str, mode: str) -> Task[file]: ...
def __test_opt(x: int, y: int = 5) -> int:
...
assert __test_opt
assert __test_opt(5)
assert __test_opt(5, 6)
assert not __test_opt(5, 6, 7)
assert not __test_opt()
\ No newline at end of file
......@@ -9,13 +9,8 @@ class IfMainVisitor(ast.NodeVisitor):
for i, stmt in enumerate(node.body):
if isinstance(stmt, ast.If):
if not stmt.orelse and compare_ast(stmt.test, NAME_MAIN):
new_node = ast.FunctionDef(
name="main",
args=ast.arguments(args=[]),
body=stmt.body,
decorator_list=[],
returns=None
)
new_node = ast.parse("def main(): pass").body[0]
new_node.body = stmt.body
new_node.is_main = True
node.body[i] = new_node
return
\ No newline at end of file
......@@ -55,7 +55,10 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node)
target = node.targets[0]
ty = self.get_type(node.value)
try:
node.is_declare = self.visit_assign_target(target, ty)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}")
def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name):
......@@ -92,6 +95,7 @@ class ScoperBlockVisitor(ScoperVisitor):
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
......
......@@ -77,6 +77,7 @@ class StdlibVisitor(NodeVisitorSeq):
ty.typevars = arg_visitor.typevars
if node.args.vararg:
ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class:
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
......@@ -86,6 +87,15 @@ class StdlibVisitor(NodeVisitorSeq):
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
if isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not):
oper = node.test.operand
try:
res = self.expr().visit(oper)
except:
print("Type of", ast.unparse(oper), ":=", "INVALID")
else:
raise AssertionError(f"Assertion should fail, got {res} for {ast.unparse(oper)}")
else:
print("Type of", ast.unparse(node.test), ":=", self.expr().visit(node.test))
def visit_Call(self, node: ast.Call) -> BaseType:
......
......@@ -144,6 +144,8 @@ class TypeOperator(BaseType, ABC):
def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator):
raise IncompatibleTypesError()
if len(self.args) < len(other.args):
return other.unify_internal(self)
if type(self) != type(other):
for parent in other.get_parents():
try:
......@@ -161,28 +163,34 @@ class TypeOperator(BaseType, ABC):
return
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents")
if len(self.args) != len(other.args):
a, b = self, other
a_opt = a.optional_at is not None
b_opt = b.optional_at is not None
if a_opt and b_opt:
raise IncompatibleTypesError(f"This really should never happen")
if b_opt:
a, b = b, a
if a_opt:
# a = f(A, B; C=?, D=?)
# b = g(A, B, ... ?)
# either
# |a| < |b| => b has more args => invalid
# |a| ≥ |b| => b has less args => valid, up to |b|, so normal course of events
x = True
# c'est pété => utiliser le truc de la boucle en bas
# TODO: pas implémenté
if not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
pass
# # a, b = self, other
# # if a.optio
# # a_opt = a.optional_at is not None
# # b_opt = b.optional_at is not None
# # if a_opt and b_opt:
# # raise IncompatibleTypesError(f"This really should never happen")
# # if b_opt:
# # other.unify_internal(self)
# # return
# if a_opt:
# if len(a.args) < len(b.args):
# raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
# # a = f(A, B; C=?, D=?)
# # b = g(A, B, ... ?)
# # either
# # |a| < |b| => b has more args => invalid
# # |a| ≥ |b| => b has less args => valid, up to |b|, so normal course of events
#
# x = True
#
# # c'est pété => utiliser le truc de la boucle en bas
#
# # TODO: pas implémenté
#
# # if not (self.variadic or other.variadic):
# # raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
if len(self.args) == 0:
if self.name != other.name:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}")
......@@ -190,6 +198,12 @@ class TypeOperator(BaseType, ABC):
if a is None and self.variadic or b is None and other.variadic:
continue
if a is not None and b is None:
if i >= self.optional_at:
continue
else:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}, not enough arguments")
if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b)
else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment