Commit 1afe2cb5 authored by Tom Niget's avatar Tom Niget

Fork/Join basic support

parent c5dfd645
......@@ -16,6 +16,7 @@ class VarKind(Enum):
class VarDecl:
kind: VarKind
val: Optional[str]
future: bool = False
@dataclass
......@@ -89,7 +90,7 @@ class Scope:
return self.parent.vars
return None
def declare(self, name: str, val: Optional[str] = None) -> Optional[str]:
def declare(self, name: str, val: Optional[str] = None, future: bool = False) -> Optional[str]:
if self.exists_local(name):
# If the variable already exists in the current function or global scope, we don't need to declare it again.
# This is simply an assignment.
......@@ -97,5 +98,5 @@ class Scope:
vdict, prefix = self.vars, ""
if (root_vars := self.is_root()) is not None:
vdict, prefix = root_vars, "auto " # Root scope declarations can use `auto`.
vdict[name] = VarDecl(VarKind.LOCAL, val)
vdict[name] = VarDecl(VarKind.LOCAL, val, future)
return prefix
......@@ -59,6 +59,7 @@ class CoroutineMode(Flag):
ASYNC = 4
GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
JOIN = 32 | ASYNC
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
......
......@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Iterable, Optional
from transpiler.scope import VarDecl, VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap
from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap, compare_ast
from transpiler.visitors.expr import ExpressionVisitor
from transpiler.visitors.search import SearchVisitor
......@@ -24,7 +24,7 @@ class BlockVisitor(NodeVisitor):
class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
yield CoroutineMode.GENERATOR
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
......@@ -32,9 +32,17 @@ class BlockVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_yield = YieldVisitor().match(node.body)
yield from self.visit_func(node, CoroutineMode.GENERATOR if has_yield else CoroutineMode.TASK)
if has_yield:
def visit_Call(self, node: ast.Call):
func = node.func
if compare_ast(func, ast.parse('fork', mode="eval").body):
yield CoroutineMode.JOIN
yield from ()
func_type = YieldVisitor().match(node.body)
if func_type is False:
func_type = CoroutineMode.TASK
yield from self.visit_func(node, func_type)
if func_type == CoroutineMode.GENERATOR:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
......@@ -86,6 +94,8 @@ class BlockVisitor(NodeVisitor):
yield "Task"
elif CoroutineMode.GENERATOR in generator:
yield "Generator"
elif CoroutineMode.JOIN in generator:
yield "Join"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
......@@ -143,7 +153,7 @@ class BlockVisitor(NodeVisitor):
yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.TASK in generator:
elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator:
if not has_return:
yield "co_return;"
yield "}"
......@@ -155,7 +165,8 @@ class BlockVisitor(NodeVisitor):
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None)
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None,
getattr(val, "is_future", False))
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
......
......@@ -5,7 +5,7 @@ from typing import List, Iterable
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.scope import VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, join
from transpiler.visitors import CoroutineMode, NodeVisitor, join, compare_ast
class PrecedenceContext:
......@@ -74,8 +74,11 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id)
if (decl := self.scope.get(res)) and decl.kind == VarKind.SELF:
res = "(*this)"
if (decl := self.scope.get(res)):
if decl.kind == VarKind.SELF:
res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator:
res = f"{res}.get()"
yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
......@@ -95,6 +98,28 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if compare_ast(func, ast.parse('fork', mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = True
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::fork("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
elif compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
yield from ()
return
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator:
yield "co_await "
......
......@@ -3,9 +3,9 @@ from typing import Callable, TypeVar
T = TypeVar("T")
def fork(_f: Callable[[], T]) -> T:
def fork(f: Callable[[], T]) -> T:
# stub
pass
return f()
def sync():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment