Commit c5021696 authored by Tom Niget's avatar Tom Niget

Continue work on async types handling

parent a1965638
...@@ -11,6 +11,7 @@ class int: ...@@ -11,6 +11,7 @@ class int:
def __and__(self, other: Self) -> Self: ... def __and__(self, other: Self) -> Self: ...
assert int.__add__
U = TypeVar("U") U = TypeVar("U")
...@@ -22,5 +23,6 @@ class list(Generic[U]): ...@@ -22,5 +23,6 @@ class list(Generic[U]):
def first(self) -> U: ... def first(self) -> U: ...
assert list[int].first
def print(*args) -> None: ... def print(*args) -> None: ...
...@@ -2,20 +2,30 @@ from typing import Callable, TypeVar, Generic ...@@ -2,20 +2,30 @@ from typing import Callable, TypeVar, Generic
T = TypeVar("T") T = TypeVar("T")
class Fork(Generic[T]): class Forked(Generic[T]):
def get(self) -> T: ... def get(self) -> T: ...
class Task(Generic[T]):
pass
class Future(Generic[T]):
def get(self) -> Task[T]: ...
assert Forked[int].get
def fork(f: Callable[[], T]) -> Fork[T]: def fork(f: Callable[[], T]) -> Task[Forked[T]]:
# stub # stub
class Res: class Res:
get = f get = f
return Res return Res
def future(f: Callable[[], T]) -> T: def future(f: Callable[[], T]) -> Task[Future[T]]:
# stub # stub
return f() class Res:
get = f
return Res
def sync() -> None: def sync() -> None:
......
...@@ -24,7 +24,7 @@ def f(x: int): ...@@ -24,7 +24,7 @@ def f(x: int):
return x + 1 return x + 1
def fct(param): def fct(param: int):
loc = f(456) loc = f(456)
global glob global glob
loc = 789 loc = 789
......
...@@ -5,8 +5,8 @@ def fibo(n: int) -> int: ...@@ -5,8 +5,8 @@ def fibo(n: int) -> int:
return n return n
a = future(lambda: fibo(n - 1)) a = future(lambda: fibo(n - 1))
b = future(lambda: fibo(n - 2)) b = future(lambda: fibo(n - 2))
return a + b return a.get() + b.get()
if __name__ == "__main__": if __name__ == "__main__":
print(fibo(30)) # should display 832040 print(fibo(20)) # should display 832040
\ No newline at end of file \ No newline at end of file
...@@ -6,7 +6,7 @@ from typing import Iterable ...@@ -6,7 +6,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, ForkResult from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind
from transpiler.utils import UnsupportedNodeError from transpiler.utils import UnsupportedNodeError
class UniversalVisitor: class UniversalVisitor:
...@@ -68,8 +68,21 @@ class NodeVisitor(UniversalVisitor): ...@@ -68,8 +68,21 @@ class NodeVisitor(UniversalVisitor):
yield "int" yield "int"
elif node is TY_BOOL: elif node is TY_BOOL:
yield "bool" yield "bool"
elif isinstance(node, ForkResult): elif node is TY_NONE:
yield "Forked<" yield "void"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
yield "Task"
elif node.kind == PromiseKind.JOIN:
yield "Join"
elif node.kind == PromiseKind.FUTURE:
yield "Future"
elif node.kind == PromiseKind.FORKED:
yield "Forked"
else:
raise NotImplementedError(node)
yield "<"
yield from self.visit(node.return_type) yield from self.visit(node.return_type)
yield ">" yield ">"
elif isinstance(node, TypeVariable): elif isinstance(node, TypeVariable):
......
...@@ -42,6 +42,11 @@ class BlockVisitor(NodeVisitor): ...@@ -42,6 +42,11 @@ class BlockVisitor(NodeVisitor):
yield "int main() { root().call(); }" yield "int main() { root().call(); }"
return return
yield "struct {"
yield from self.visit_func_new(node)
yield f"}} {node.name};"
return
yield "struct {" yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE) yield from self.visit_func(node, CoroutineMode.FAKE)
...@@ -78,6 +83,34 @@ class BlockVisitor(NodeVisitor): ...@@ -78,6 +83,34 @@ class BlockVisitor(NodeVisitor):
yield "}" yield "}"
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit(node.type.return_type)
yield "operator()"
yield "("
for i, (arg, argty) in enumerate(zip(node.args.args, node.type.parameters)):
if i != 0:
yield ", "
yield from self.visit(argty)
yield arg.arg
yield ")"
inner_scope = node.inner_scope
yield "{"
for child in node.body:
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, CoroutineMode.ASYNC)
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
yield "}"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]: def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args) templ, args, names = self.process_args(node.args)
if templ: if templ:
......
...@@ -4,7 +4,7 @@ from pathlib import Path ...@@ -4,7 +4,7 @@ from pathlib import Path
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \ from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, TY_MODULE, CppType, PyList, TypeType, ForkResult TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future
PRELUDE.vars.update({ PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT), # "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...@@ -24,8 +24,10 @@ PRELUDE.vars.update({ ...@@ -24,8 +24,10 @@ PRELUDE.vars.update({
"Callable": VarDecl(VarKind.LOCAL, FunctionType), "Callable": VarDecl(VarKind.LOCAL, FunctionType),
"TypeVar": VarDecl(VarKind.LOCAL, TypeVariable), "TypeVar": VarDecl(VarKind.LOCAL, TypeVariable),
"CppType": VarDecl(VarKind.LOCAL, CppType), "CppType": VarDecl(VarKind.LOCAL, CppType),
"list": VarDecl(VarKind.LOCAL, PyList), "list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"Fork": VarDecl(VarKind.LOCAL, ForkResult), "Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
}) })
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib" typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
...@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor ...@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise Promise, TY_NONE, PromiseKind
@dataclass @dataclass
...@@ -75,13 +75,14 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -75,13 +75,14 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args] argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = Promise(self.visit_annotation(node.returns)) rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype) ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION) scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype scope.obj_type = ftype
scope.function = scope scope.function = scope
node.inner_scope = scope node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body: for b in node.body:
...@@ -89,6 +90,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -89,6 +90,8 @@ class ScoperBlockVisitor(ScoperVisitor):
visitor = ScoperBlockVisitor(scope, decls) visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b) visitor.visit(b)
b.decls = decls b.decls = decls
if not scope.has_return:
rtype.return_type.unify(TY_NONE)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
...@@ -107,8 +110,9 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -107,8 +110,9 @@ class ScoperBlockVisitor(ScoperVisitor):
raise IncompatibleTypesError("Return outside function") raise IncompatibleTypesError("Return outside function")
ftype = fct.obj_type ftype = fct.obj_type
assert isinstance(ftype, FunctionType) assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else None vtype = self.expr().visit(node.value) if node.value else TY_NONE
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type) vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
fct.has_return = True
def visit_Global(self, node: ast.Global): def visit_Global(self, node: ast.Global):
for name in node.names: for name in node.names:
......
import abc
import ast import ast
from typing import List from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind
DUNDER = { DUNDER = {
ast.Eq: "eq", ast.Eq: "eq",
...@@ -74,10 +75,28 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -74,10 +75,28 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func) ftype = self.visit(node.func)
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args]) rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
if isinstance(actual, Promise):
node.is_await = True
actual = actual.return_type.resolve()
if isinstance(actual, Promise) and actual.kind == PromiseKind.FORKED \
and isinstance(fty := self.scope.function.obj_type.return_type, Promise):
fty.kind = PromiseKind.JOIN
return actual
if isinstance(rtype, Promise): if isinstance(rtype, Promise):
node.is_await = True node.is_await = True
return rtype.return_type if rtype.kind == PromiseKind.FORKED \
node.is_await = False and isinstance(fty := self.scope.function.obj_type.return_type, Promise):
fty.kind = PromiseKind.JOIN
else:
return rtype.return_type
else:
node.is_await = False
return rtype return rtype
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]): def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
...@@ -165,6 +184,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -165,6 +184,12 @@ class ScoperExprVisitor(ScoperVisitor):
return PyDict(keys[0], values[0]) return PyDict(keys[0], values[0])
def visit_Subscript(self, node: ast.Subscript) -> BaseType: def visit_Subscript(self, node: ast.Subscript) -> BaseType:
left = self.visit(node.value)
args = node.slice if type(node.slice) == tuple else [node.slice]
if isinstance(left, TypeType) and isinstance(left.type_object, abc.ABCMeta):
# generic
return TypeType(left.type_object(*[self.visit(e).type_object for e in args]))
pass
raise NotImplementedError(node) raise NotImplementedError(node)
def visit_UnaryOp(self, node: ast.UnaryOp) -> BaseType: def visit_UnaryOp(self, node: ast.UnaryOp) -> BaseType:
......
...@@ -52,6 +52,7 @@ class Scope: ...@@ -52,6 +52,7 @@ class Scope:
vars: Dict[str, VarDecl] = field(default_factory=dict) vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list) children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None obj_type: Optional[BaseType] = None
has_return: bool = False
@staticmethod @staticmethod
def make_global(): def make_global():
......
import ast import ast
import dataclasses import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List, Dict from typing import Optional, List, Dict
...@@ -38,8 +39,8 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -38,8 +39,8 @@ class StdlibVisitor(NodeVisitorSeq):
typevars = [] typevars = []
for b in node.bases: for b in node.bases:
if isinstance(b, ast.Subscript) and isinstance(b.value, ast.Name) and b.value.id == "Generic": if isinstance(b, ast.Subscript) and isinstance(b.value, ast.Name) and b.value.id == "Generic":
if isinstance(b.slice, ast.Index): if isinstance(b.slice, ast.Name):
typevars = [b.slice.value.id] typevars = [b.slice.id]
elif isinstance(b.slice, ast.Tuple): elif isinstance(b.slice, ast.Tuple):
typevars = [n.id for n in b.slice.value.elts] typevars = [n.id for n in b.slice.value.elts]
if existing := self.scope.get(node.name): if existing := self.scope.get(node.name):
...@@ -54,6 +55,9 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -54,6 +55,9 @@ class StdlibVisitor(NodeVisitorSeq):
for stmt in node.body: for stmt in node.body:
visitor.visit(stmt) visitor.visit(stmt)
def visit_Pass(self, node: ast.Pass):
pass
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class) arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args] arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
...@@ -63,12 +67,12 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -63,12 +67,12 @@ class StdlibVisitor(NodeVisitorSeq):
ty.variadic = True ty.variadic = True
#arg_types.append(TY_VARARG) #arg_types.append(TY_VARARG)
if self.cur_class: if self.cur_class:
if isinstance(self.cur_class, TypeType): if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
# ty_inst = FunctionType(arg_types[1:], ret_type) # ty_inst = FunctionType(arg_types[1:], ret_type)
# self.cur_class.args[0].add_inst_member(node.name, ty_inst) # self.cur_class.args[0].add_inst_member(node.name, ty_inst)
self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars) self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars)
else:
self.cur_class.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert): def visit_Assert(self, node: ast.Assert):
......
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional, List, ClassVar, Callable from enum import Enum
from typing import Dict, Optional, List, ClassVar, Callable, Any
class IncompatibleTypesError(Exception): class IncompatibleTypesError(Exception):
pass pass
@dataclass @dataclass(eq=False)
class BaseType(ABC): class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False) members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False) methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
...@@ -36,27 +38,33 @@ class BaseType(ABC): ...@@ -36,27 +38,33 @@ class BaseType(ABC):
def gen_sub(self, this: "BaseType", typevars) -> "Self": def gen_sub(self, this: "BaseType", typevars) -> "Self":
return self return self
def __repr__(self):
return str(self)
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
return [self] return [self]
class MagicType(BaseType):
T = typing.TypeVar("T")
class MagicType(BaseType, typing.Generic[T]):
val: T
def __init__(self, val: T):
super().__init__()
self.val = val
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType"):
if type(self) is not type(other): if type(self) != type(other) or self.val != other.val:
raise IncompatibleTypesError() raise IncompatibleTypesError()
def contains_internal(self, other: "BaseType") -> bool: def contains_internal(self, other: "BaseType") -> bool:
return False return False
def __str__(self):
return str(self.val)
cur_var = 0 cur_var = 0
@dataclass @dataclass(eq=False)
class TypeVariable(BaseType): class TypeVariable(BaseType):
name: str = field(default_factory=lambda: chr(ord('a') + cur_var)) name: str = field(default_factory=lambda: chr(ord('a') + cur_var))
resolved: Optional[BaseType] = None resolved: Optional[BaseType] = None
...@@ -85,8 +93,10 @@ class TypeVariable(BaseType): ...@@ -85,8 +93,10 @@ class TypeVariable(BaseType):
return match return match
return self return self
GenMethodFactory = Callable[["BaseType"], "FunctionType"] GenMethodFactory = Callable[["BaseType"], "FunctionType"]
@dataclass @dataclass
class TypeOperator(BaseType, ABC): class TypeOperator(BaseType, ABC):
args: List[BaseType] args: List[BaseType]
...@@ -94,6 +104,10 @@ class TypeOperator(BaseType, ABC): ...@@ -94,6 +104,10 @@ class TypeOperator(BaseType, ABC):
variadic: bool = False variadic: bool = False
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {} gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.gen_methods = {}
def __post_init__(self): def __post_init__(self):
if self.name is None: if self.name is None:
self.name = self.__class__.__name__ self.name = self.__class__.__name__
...@@ -106,7 +120,11 @@ class TypeOperator(BaseType, ABC): ...@@ -106,7 +120,11 @@ class TypeOperator(BaseType, ABC):
if len(self.args) != len(other.args) and not (self.variadic or other.variadic): if len(self.args) != len(other.args) and not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments") raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
for a, b in zip(self.args, other.args): for a, b in zip(self.args, other.args):
a.unify(b) if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b)
else:
if a != b:
raise IncompatibleTypesError(f"Cannot unify {a} and {b}")
def contains_internal(self, other: "BaseType") -> bool: def contains_internal(self, other: "BaseType") -> bool:
return any(arg.contains(other) for arg in self.args) return any(arg.contains(other) for arg in self.args)
...@@ -158,6 +176,7 @@ class FunctionType(TypeOperator): ...@@ -158,6 +176,7 @@ class FunctionType(TypeOperator):
args = "()" args = "()"
return f"{args} -> {ret}" return f"{args} -> {ret}"
class CppType(TypeOperator): class CppType(TypeOperator):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__([name], name) super().__init__([name], name)
...@@ -165,10 +184,12 @@ class CppType(TypeOperator): ...@@ -165,10 +184,12 @@ class CppType(TypeOperator):
def __str__(self): def __str__(self):
return self.name return self.name
class Union(TypeOperator): class Union(TypeOperator):
def __init__(self, left: BaseType, right: BaseType): def __init__(self, left: BaseType, right: BaseType):
super().__init__([left, right], "Union") super().__init__([left, right], "Union")
class TypeType(TypeOperator): class TypeType(TypeOperator):
def __init__(self, arg: BaseType): def __init__(self, arg: BaseType):
super().__init__([arg], "Type") super().__init__([arg], "Type")
...@@ -189,6 +210,7 @@ TY_VARARG = TypeOperator([], "vararg") ...@@ -189,6 +210,7 @@ TY_VARARG = TypeOperator([], "vararg")
TY_SELF = TypeOperator([], "Self") TY_SELF = TypeOperator([], "Self")
TY_SELF.gen_sub = lambda this, typevars: this TY_SELF.gen_sub = lambda this, typevars: this
class PyList(TypeOperator): class PyList(TypeOperator):
def __init__(self, arg: BaseType): def __init__(self, arg: BaseType):
super().__init__([arg], "list") super().__init__([arg], "list")
...@@ -197,6 +219,7 @@ class PyList(TypeOperator): ...@@ -197,6 +219,7 @@ class PyList(TypeOperator):
def element_type(self): def element_type(self):
return self.args[0] return self.args[0]
class PySet(TypeOperator): class PySet(TypeOperator):
def __init__(self, arg: BaseType): def __init__(self, arg: BaseType):
super().__init__([arg], "set") super().__init__([arg], "set")
...@@ -205,6 +228,7 @@ class PySet(TypeOperator): ...@@ -205,6 +228,7 @@ class PySet(TypeOperator):
def element_type(self): def element_type(self):
return self.args[0] return self.args[0]
class PyDict(TypeOperator): class PyDict(TypeOperator):
def __init__(self, key: BaseType, value: BaseType): def __init__(self, key: BaseType, value: BaseType):
super().__init__([key, value], "dict") super().__init__([key, value], "dict")
...@@ -217,22 +241,49 @@ class PyDict(TypeOperator): ...@@ -217,22 +241,49 @@ class PyDict(TypeOperator):
def value_type(self): def value_type(self):
return self.args[1] return self.args[1]
class TupleType(TypeOperator): class TupleType(TypeOperator):
def __init__(self, args: List[BaseType]): def __init__(self, args: List[BaseType]):
super().__init__(args, "tuple") super().__init__(args, "tuple")
class ForkResult(TypeOperator):
def __init__(self, args: BaseType): class PromiseKind(Enum):
super().__init__([args], "ForkResult") TASK = 0
JOIN = 1
FUTURE = 2
FORKED = 3
class Promise(TypeOperator, ABC):
def __init__(self, ret: BaseType, kind: PromiseKind):
super().__init__([ret, MagicType(kind)])
@property @property
def return_type(self): def return_type(self) -> BaseType:
return self.args[0] return self.args[0]
class Promise(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "Promise")
@property @property
def return_type(self): def kind(self) -> PromiseKind:
return self.args[0] return self.args[1].val
\ No newline at end of file
@kind.setter
def kind(self, value: PromiseKind):
self.args[1].val = value
def __str__(self):
return f"{self.kind.name.lower()}<{self.return_type}>"
class Forked(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FORKED)
class Task(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.TASK)
class Future(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FUTURE)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment