Commit 1e590d39 authored by Tom Niget's avatar Tom Niget

Forbid use of raw type variable

parent b1ddc670
...@@ -50,9 +50,7 @@ def identity_2(x: U, y: V) -> Tuple[U, V]: ...@@ -50,9 +50,7 @@ def identity_2(x: U, y: V) -> Tuple[U, V]:
assert list.__add__ assert list.__add__
assert list.__add__([5], [[6][0]]) assert list.__add__([5], [[6][0]])
assert list[U].__add__ #assert list[U].__add__
assert list[U].__add__([1], [2])
assert list[U].__add__
assert list[int].__add__ assert list[int].__add__
assert identity_2(1, "a") assert identity_2(1, "a")
assert lambda x, y: identity_2(x, y) assert lambda x, y: identity_2(x, y)
......
...@@ -18,11 +18,11 @@ class TypeAnnotationVisitor(NodeVisitorSeq): ...@@ -18,11 +18,11 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
return TY_SELF return TY_SELF
if existing := self.scope.get(node): if existing := self.scope.get(node):
ty = existing.type ty = existing.type
if isinstance(ty, TypeVariable): if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
if existing is not self.scope.vars.get(node, None): if existing is not self.scope.vars.get(node, None):
# Type variable from outer scope, so we copy it # Type variable from outer scope, so we copy it
ty = TypeVariable(ty.name) ty = TypeVariable(ty.type_object.name)
self.scope.declare_local(node, ty) # todo: unneeded? self.scope.declare_local(node, TypeType(ty)) # todo: unneeded?
self.typevars.append(ty) self.typevars.append(ty)
if isinstance(ty, TypeType): if isinstance(ty, TypeType):
return ty.type_object return ty.type_object
......
...@@ -74,6 +74,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -74,6 +74,8 @@ class ScoperExprVisitor(ScoperVisitor):
obj = self.scope.get(node.id) obj = self.scope.get(node.id)
if not obj: if not obj:
raise NameError(f"Name {node.id} is not defined") raise NameError(f"Name {node.id} is not defined")
if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable):
raise NameError(f"Use of type variable")
return obj.type return obj.type
def visit_Compare(self, node: ast.Compare) -> BaseType: def visit_Compare(self, node: ast.Compare) -> BaseType:
......
...@@ -81,7 +81,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -81,7 +81,7 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func) ty_op = self.visit(node.func)
if isinstance(ty_op, TypeType): if isinstance(ty_op, TypeType):
return ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args]) return TypeType(ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args]))
raise NotImplementedError(ast.unparse(node)) raise NotImplementedError(ast.unparse(node))
def anno(self) -> "TypeAnnotationVisitor": def anno(self) -> "TypeAnnotationVisitor":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment