Commit 1afe2cb5 authored by Tom Niget's avatar Tom Niget

Fork/Join basic support

parent c5dfd645
...@@ -16,6 +16,7 @@ class VarKind(Enum): ...@@ -16,6 +16,7 @@ class VarKind(Enum):
class VarDecl: class VarDecl:
kind: VarKind kind: VarKind
val: Optional[str] val: Optional[str]
future: bool = False
@dataclass @dataclass
...@@ -89,7 +90,7 @@ class Scope: ...@@ -89,7 +90,7 @@ class Scope:
return self.parent.vars return self.parent.vars
return None return None
def declare(self, name: str, val: Optional[str] = None) -> Optional[str]: def declare(self, name: str, val: Optional[str] = None, future: bool = False) -> Optional[str]:
if self.exists_local(name): if self.exists_local(name):
# If the variable already exists in the current function or global scope, we don't need to declare it again. # If the variable already exists in the current function or global scope, we don't need to declare it again.
# This is simply an assignment. # This is simply an assignment.
...@@ -97,5 +98,5 @@ class Scope: ...@@ -97,5 +98,5 @@ class Scope:
vdict, prefix = self.vars, "" vdict, prefix = self.vars, ""
if (root_vars := self.is_root()) is not None: if (root_vars := self.is_root()) is not None:
vdict, prefix = root_vars, "auto " # Root scope declarations can use `auto`. vdict, prefix = root_vars, "auto " # Root scope declarations can use `auto`.
vdict[name] = VarDecl(VarKind.LOCAL, val) vdict[name] = VarDecl(VarKind.LOCAL, val, future)
return prefix return prefix
...@@ -59,6 +59,7 @@ class CoroutineMode(Flag): ...@@ -59,6 +59,7 @@ class CoroutineMode(Flag):
ASYNC = 4 ASYNC = 4
GENERATOR = 8 | ASYNC GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC TASK = 16 | ASYNC
JOIN = 32 | ASYNC
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
......
...@@ -4,7 +4,7 @@ from dataclasses import dataclass ...@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.scope import VarDecl, VarKind, Scope from transpiler.scope import VarDecl, VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap from transpiler.visitors import CoroutineMode, NodeVisitor, flatmap, compare_ast
from transpiler.visitors.expr import ExpressionVisitor from transpiler.visitors.expr import ExpressionVisitor
from transpiler.visitors.search import SearchVisitor from transpiler.visitors.search import SearchVisitor
...@@ -24,7 +24,7 @@ class BlockVisitor(NodeVisitor): ...@@ -24,7 +24,7 @@ class BlockVisitor(NodeVisitor):
class YieldVisitor(SearchVisitor): class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool: def visit_Yield(self, node: ast.Yield) -> bool:
yield True yield CoroutineMode.GENERATOR
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
yield from () yield from ()
...@@ -32,9 +32,17 @@ class BlockVisitor(NodeVisitor): ...@@ -32,9 +32,17 @@ class BlockVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
yield from () yield from ()
has_yield = YieldVisitor().match(node.body) def visit_Call(self, node: ast.Call):
yield from self.visit_func(node, CoroutineMode.GENERATOR if has_yield else CoroutineMode.TASK) func = node.func
if has_yield: if compare_ast(func, ast.parse('fork', mode="eval").body):
yield CoroutineMode.JOIN
yield from ()
func_type = YieldVisitor().match(node.body)
if func_type is False:
func_type = CoroutineMode.TASK
yield from self.visit_func(node, func_type)
if func_type == CoroutineMode.GENERATOR:
templ, args, names = self.process_args(node.args) templ, args, names = self.process_args(node.args)
if templ: if templ:
yield "template" yield "template"
...@@ -86,6 +94,8 @@ class BlockVisitor(NodeVisitor): ...@@ -86,6 +94,8 @@ class BlockVisitor(NodeVisitor):
yield "Task" yield "Task"
elif CoroutineMode.GENERATOR in generator: elif CoroutineMode.GENERATOR in generator:
yield "Generator" yield "Generator"
elif CoroutineMode.JOIN in generator:
yield "Join"
yield f"<decltype(sync({', '.join(names)}))>" yield f"<decltype(sync({', '.join(names)}))>"
yield "{" yield "{"
inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)}) inner_scope = self.scope.function(vars={node.name: VarDecl(VarKind.SELF, None)})
...@@ -143,7 +153,7 @@ class BlockVisitor(NodeVisitor): ...@@ -143,7 +153,7 @@ class BlockVisitor(NodeVisitor):
yield from child_code # Yeet back the child node code. yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator: if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements. yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.TASK in generator: elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator:
if not has_return: if not has_return:
yield "co_return;" yield "co_return;"
yield "}" yield "}"
...@@ -155,7 +165,8 @@ class BlockVisitor(NodeVisitor): ...@@ -155,7 +165,8 @@ class BlockVisitor(NodeVisitor):
name = self.fix_name(lvalue.id) name = self.fix_name(lvalue.id)
# if name not in self._scope.vars: # if name not in self._scope.vars:
if not self.scope.exists_local(name): if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None) yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None,
getattr(val, "is_future", False))
yield name yield name
elif isinstance(lvalue, ast.Subscript): elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue) yield from self.expr().visit(lvalue)
......
...@@ -5,7 +5,7 @@ from typing import List, Iterable ...@@ -5,7 +5,7 @@ from typing import List, Iterable
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.scope import VarKind, Scope from transpiler.scope import VarKind, Scope
from transpiler.visitors import CoroutineMode, NodeVisitor, join from transpiler.visitors import CoroutineMode, NodeVisitor, join, compare_ast
class PrecedenceContext: class PrecedenceContext:
...@@ -74,8 +74,11 @@ class ExpressionVisitor(NodeVisitor): ...@@ -74,8 +74,11 @@ class ExpressionVisitor(NodeVisitor):
def visit_Name(self, node: ast.Name) -> Iterable[str]: def visit_Name(self, node: ast.Name) -> Iterable[str]:
res = self.fix_name(node.id) res = self.fix_name(node.id)
if (decl := self.scope.get(res)) and decl.kind == VarKind.SELF: if (decl := self.scope.get(res)):
if decl.kind == VarKind.SELF:
res = "(*this)" res = "(*this)"
elif decl.future and CoroutineMode.ASYNC in self.generator:
res = f"{res}.get()"
yield res yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
...@@ -95,6 +98,28 @@ class ExpressionVisitor(NodeVisitor): ...@@ -95,6 +98,28 @@ class ExpressionVisitor(NodeVisitor):
if getattr(node, "kwargs", None): if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") raise NotImplementedError(node, "kwargs")
func = node.func func = node.func
if compare_ast(func, ast.parse('fork', mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = True
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::fork("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
elif compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
yield from ()
return
# TODO: precedence needed? # TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator: if CoroutineMode.ASYNC in self.generator:
yield "co_await " yield "co_await "
......
...@@ -3,9 +3,9 @@ from typing import Callable, TypeVar ...@@ -3,9 +3,9 @@ from typing import Callable, TypeVar
T = TypeVar("T") T = TypeVar("T")
def fork(_f: Callable[[], T]) -> T: def fork(f: Callable[[], T]) -> T:
# stub # stub
pass return f()
def sync(): def sync():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment