Commit a89e5d44 authored by Tom Niget's avatar Tom Niget

Add support for generators

parent aacd2fb0
......@@ -25,4 +25,13 @@ class list(Generic[U]):
assert list[int].first
class Iterator(Generic[U]):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
def next(it: Iterator[U], default: None) -> U: ...
def print(*args) -> None: ...
def range(*args) -> Iterator[int]: ...
\ No newline at end of file
......@@ -65,6 +65,8 @@ class NodeVisitor(UniversalVisitor):
yield "Future"
elif node.kind == PromiseKind.FORKED:
yield "Forked"
elif node.kind == PromiseKind.GENERATOR:
yield "Generator"
else:
raise NotImplementedError(node)
yield "<"
......
......@@ -210,11 +210,13 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if CoroutineMode.GENERATOR in self.generator:
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
elif CoroutineMode.FAKE in self.generator:
yield "return"
yield from self.visit(node.value)
else:
raise NotImplementedError(node)
#if CoroutineMode.GENERATOR in self.generator:
# yield "co_yield"
# yield from self.prec("co_yield").visit(node.value)
#elif CoroutineMode.FAKE in self.generator:
# yield "return"
# yield from self.visit(node.value)
#else:
# raise NotImplementedError(node)
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
......@@ -4,7 +4,7 @@ from pathlib import Path
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future
TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future, PyIterator
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
......@@ -28,6 +28,7 @@ PRELUDE.vars.update({
"Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
"Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator))
})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
......@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise, TY_NONE, PromiseKind
Promise, TY_NONE, PromiseKind, TupleType
@dataclass
......@@ -64,6 +64,10 @@ class ScoperBlockVisitor(ScoperVisitor):
if self.scope.kind == ScopeKind.FUNCTION_INNER:
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return True
elif isinstance(target, ast.Tuple):
if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)):
raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}")
return any(self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args))
else:
raise NotImplementedError(target)
......@@ -118,6 +122,19 @@ class ScoperBlockVisitor(ScoperVisitor):
if node.orelse:
raise NotImplementedError(node.orelse)
def visit_For(self, node: ast.For):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
assert isinstance(node.target, ast.Name)
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, TypeVariable())
self.expr().visit(node.iter)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body:
body_visitor.visit(b)
if node.orelse:
raise NotImplementedError(node.orelse)
def visit_Expr(self, node: ast.Expr):
self.expr().visit(node.value)
......
......@@ -44,6 +44,17 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Tuple(self, node: ast.Tuple) -> BaseType:
return TupleType([self.visit(e) for e in node.elts])
def visit_Yield(self, node: ast.Yield) -> BaseType:
ytype = self.visit(node.value)
ftype = self.scope.function.obj_type.return_type
assert isinstance(ftype, Promise)
assert ftype.kind == PromiseKind.TASK
ftype.kind = PromiseKind.GENERATOR
ftype.return_type.unify(ytype)
return TY_NONE
def visit_Constant(self, node: ast.Constant) -> BaseType:
if isinstance(node.value, str):
return TY_STR
......@@ -77,7 +88,7 @@ class ScoperExprVisitor(ScoperVisitor):
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
if isinstance(actual, Promise):
if isinstance(actual, Promise) and actual.kind != PromiseKind.GENERATOR:
node.is_await = True
actual = actual.return_type.resolve()
......@@ -202,6 +213,3 @@ class ScoperExprVisitor(ScoperVisitor):
if then != else_:
raise NotImplementedError("IfExp with different types not handled yet")
return then
def visit_Yield(self, node: ast.Yield) -> BaseType:
raise NotImplementedError(node)
......@@ -13,6 +13,10 @@ class IncompatibleTypesError(Exception):
class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=list, init=False)
def get_parents(self) -> List["BaseType"]:
return self.parents
def resolve(self) -> "BaseType":
return self
......@@ -117,6 +121,22 @@ class TypeOperator(BaseType, ABC):
def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator):
raise IncompatibleTypesError()
if type(self) != type(other):
for parent in other.get_parents():
try:
self.unify(parent)
except IncompatibleTypesError:
pass
else:
return
for parent in self.get_parents():
try:
parent.unify(other)
except IncompatibleTypesError:
pass
else:
return
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents")
if len(self.args) != len(other.args) and not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
for a, b in zip(self.args, other.args):
......@@ -241,6 +261,15 @@ class PyDict(TypeOperator):
def value_type(self):
return self.args[1]
class PyIterator(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "iter")
@property
def element_type(self):
return self.args[0]
class TupleType(TypeOperator):
def __init__(self, args: List[BaseType]):
......@@ -252,6 +281,7 @@ class PromiseKind(Enum):
JOIN = 1
FUTURE = 2
FORKED = 3
GENERATOR = 4
class Promise(TypeOperator, ABC):
......@@ -273,6 +303,11 @@ class Promise(TypeOperator, ABC):
def __str__(self):
return f"{self.kind.name.lower()}<{self.return_type}>"
def get_parents(self) -> List["BaseType"]:
if self.kind == PromiseKind.GENERATOR:
return [PyIterator(self.return_type), *super().get_parents()]
return super().get_parents()
class Forked(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment