Commit a32ef1be authored by Tom Niget's avatar Tom Niget

Add initial (WIP) support for protocols

parent 8afba292
from typing import Self, TypeVar, Generic
from typing import Self, TypeVar, Generic, Protocol
class int:
def __add__(self, other: Self) -> Self: ...
......@@ -17,29 +17,31 @@ V = TypeVar("V")
# TODO: really, these should work as interfaces, on a duck-typing basis. it's gonna be a hell of a ride to implement
# unification for this
class HasLen:
class HasLen(Protocol):
def __len__(self) -> int: ...
def len(x: HasLen) -> int:
...
class Iterator(Generic[U]):
class Iterator(Protocol[U]):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
class Iterable(Generic[U]):
class Iterable(Protocol[U]):
def __iter__(self) -> Iterator[U]: ...
class str(HasLen):
class str:
def find(self, sub: Self) -> int: ...
def format(self, *args) -> Self: ...
def encode(self, encoding: Self) -> bytes: ...
def __len__(self) -> int: ...
class bytes(HasLen):
class bytes:
def decode(self, encoding: str) -> str: ...
def __len__(self) -> int: ...
class list(Generic[U], HasLen, Iterable[U]):
class list(Generic[U]):
def __add__(self, other: Self) -> Self: ...
......@@ -48,6 +50,8 @@ class list(Generic[U], HasLen, Iterable[U]):
def __getitem__(self, index: int) -> U: ...
def pop(self, index: int = -1) -> U: ...
def __iter__(self) -> Iterator[U]: ...
def __len__(self) -> int: ...
assert [1, 2].__iter__()
assert list[int].__iter__
......
......@@ -56,6 +56,9 @@ class StdlibVisitor(NodeVisitorSeq):
sliceval = [n.id for n in b.slice.value.elts]
if isinstance(b.value, ast.Name) and b.value.id == "Generic":
typevars = sliceval
elif isinstance(b.value, ast.Name) and b.value.id == "Protocol":
typevars = sliceval
ty.type_object.is_protocol_gen = True
else:
idxs = [typevars.index(v) for v in sliceval]
parent = self.visit(b.value)
......@@ -63,12 +66,15 @@ class StdlibVisitor(NodeVisitorSeq):
assert isinstance(ty.type_object, ABCMeta)
ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs]))
else:
parent = self.visit(b)
assert isinstance(parent, TypeType)
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
if isinstance(b, ast.Name) and b.id == "Protocol":
ty.type_object.is_protocol_gen = True
else:
ty.type_object.parents.append(parent.type_object)
parent = self.visit(b)
assert isinstance(parent, TypeType)
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
else:
ty.type_object.parents.append(parent.type_object)
if not typevars and not existing:
ty.type_object = ty.type_object()
cl_scope = self.scope.child(ScopeKind.CLASS)
......
......@@ -121,6 +121,9 @@ class TypeOperator(BaseType, ABC):
optional_at: Optional[int] = None
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
gen_parents: ClassVar[List[BaseType]] = []
is_protocol: bool = False
is_protocol_gen: ClassVar[bool] = False
match_cache: set["TypeOperator"] = field(default_factory=set, init=False)
@staticmethod
def make_type(name: str):
......@@ -144,12 +147,31 @@ class TypeOperator(BaseType, ABC):
gp = gp(self.args)
self.parents.append(gp)
self.methods = {**gp.methods, **self.methods}
self.is_protocol = self.is_protocol or self.is_protocol_gen
def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache:
return
try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items():
corresp = self.methods[name]
corresp.remove_self().unify(ty.remove_self())
except Exception as e:
self.match_cache.remove(hash(protocol))
raise IncompatibleTypesError(f"Type {self} doesn't implement protocol {protocol}: {e}")
def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator):
raise IncompatibleTypesError()
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self)
if len(self.args) < len(other.args):
return other.unify_internal(self)
assert self.is_protocol == other.is_protocol
if type(self) != type(other):
for parent in other.get_parents():
try:
......@@ -211,6 +233,7 @@ class TypeOperator(BaseType, ABC):
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
res.is_protocol = self.is_protocol
return res
def to_list(self) -> List["BaseType"]:
......@@ -360,6 +383,9 @@ class Promise(TypeOperator, ABC):
@kind.setter
def kind(self, value: PromiseKind):
if value == PromiseKind.GENERATOR:
self.methods["__iter__"] = FunctionType([], self)
self.methods["__next__"] = FunctionType([], self.return_type)
self.args[1].val = value
def __str__(self):
......@@ -367,7 +393,7 @@ class Promise(TypeOperator, ABC):
def get_parents(self) -> List["BaseType"]:
if self.kind == PromiseKind.GENERATOR:
return [PyIterator(self.return_type), *super().get_parents()]
return [*super().get_parents()]
return super().get_parents()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment