Commit 8439961b authored by Tom Niget's avatar Tom Niget

Handle free generic functions

parent d2b145ff
...@@ -34,12 +34,15 @@ class Iterator(Generic[U]): ...@@ -34,12 +34,15 @@ class Iterator(Generic[U]):
def __next__(self) -> U: ... def __next__(self) -> U: ...
# type: TypeVar("U")
def next(it: Iterator[U], default: None) -> U: def next(it: Iterator[U], default: None) -> U:
... ...
# what happens with multiple functions # what happens with multiple functions
def identity(x: U) -> U:
...
assert identity(1)
assert identity("a")
def print(*args) -> None: ... def print(*args) -> None: ...
...@@ -47,5 +50,7 @@ def print(*args) -> None: ... ...@@ -47,5 +50,7 @@ def print(*args) -> None: ...
def range(*args) -> Iterator[int]: ... def range(*args) -> Iterator[int]: ...
def rangeb(*args) -> Iterator[bool]: ... def rangeb(*args) -> Iterator[bool]: ...
assert [6].__add__
assert [True].__add__
assert next(range(6), None) assert next(range(6), None)
assert next(rangeb(6), None) assert next(rangeb(6), None)
\ No newline at end of file
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional, List from typing import Optional, List
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
...@@ -11,12 +11,19 @@ from transpiler.phases.utils import NodeVisitorSeq ...@@ -11,12 +11,19 @@ from transpiler.phases.utils import NodeVisitorSeq
class TypeAnnotationVisitor(NodeVisitorSeq): class TypeAnnotationVisitor(NodeVisitorSeq):
scope: Scope scope: Scope
cur_class: Optional[TypeType] = None cur_class: Optional[TypeType] = None
typevars: List[TypeVariable] = field(default_factory=list)
def visit_str(self, node: str) -> BaseType: def visit_str(self, node: str) -> BaseType:
if node in ("Self", "self") and self.cur_class: if node in ("Self", "self") and self.cur_class:
return TY_SELF return TY_SELF
if existing := self.scope.get(node): if existing := self.scope.get(node):
ty = existing.type ty = existing.type
if isinstance(ty, TypeVariable):
if existing is not self.scope.vars.get(node, None):
# Type variable from outer scope, so we copy it
ty = TypeVariable(ty.name)
self.scope.declare_local(node, ty) # todo: unneeded?
self.typevars.append(ty)
if isinstance(ty, TypeType): if isinstance(ty, TypeType):
return ty.type_object return ty.type_object
return ty return ty
......
...@@ -85,6 +85,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -85,6 +85,8 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func) ftype = self.visit(node.func)
if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args]) rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype actual = rtype
node.is_await = False node.is_await = False
......
...@@ -64,6 +64,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -64,6 +64,7 @@ class StdlibVisitor(NodeVisitorSeq):
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args] arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
ret_type = arg_visitor.visit(node.returns) ret_type = arg_visitor.visit(node.returns)
ty = FunctionType(arg_types, ret_type) ty = FunctionType(arg_types, ret_type)
ty.typevars = arg_visitor.typevars
if node.args.vararg: if node.args.vararg:
ty.variadic = True ty.variadic = True
if self.cur_class: if self.cur_class:
......
...@@ -15,6 +15,7 @@ class BaseType(ABC): ...@@ -15,6 +15,7 @@ class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False) members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False) methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=list, init=False) parents: List["BaseType"] = field(default_factory=list, init=False)
typevars: List["TypeVariable"] = field(default_factory=list, init=False)
def get_parents(self) -> List["BaseType"]: def get_parents(self) -> List["BaseType"]:
return self.parents return self.parents
...@@ -40,7 +41,7 @@ class BaseType(ABC): ...@@ -40,7 +41,7 @@ class BaseType(ABC):
def contains_internal(self, other: "BaseType") -> bool: def contains_internal(self, other: "BaseType") -> bool:
pass pass
def gen_sub(self, this: "BaseType", typevars) -> "Self": def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"]) -> "Self":
return self return self
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
...@@ -181,11 +182,11 @@ class TypeOperator(BaseType, ABC): ...@@ -181,11 +182,11 @@ class TypeOperator(BaseType, ABC):
return hash((self.name, tuple(self.args))) return hash((self.name, tuple(self.args)))
def gen_sub(self, this: BaseType, typevars) -> "Self": def gen_sub(self, this: BaseType, typevars) -> "Self":
res = object.__new__(self.__class__) res = object.__new__(self.__class__) # todo: ugly... should make a clone()
if isinstance(this, TypeOperator): if isinstance(this, TypeOperator):
vardict = dict(zip(typevars.keys(), this.args)) vardict = dict(zip(typevars.keys(), this.args))
else: else:
vardict = {} vardict = typevars
res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args] res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args]
res.name = self.name res.name = self.name
res.variadic = self.variadic res.variadic = self.variadic
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment