Commit a45dfe8e authored by Tom Niget's avatar Tom Niget

Make example accelerated_fork work (fix decltype for nested async var decls)

parent 02310e9a
from typon import fork, sync
def fibo(n: int) -> int:
if n < 2:
return n
a = fibo(n - 1)
b = fibo(n - 2)
return a + b
def parallel_fibo(n: int) -> int:
if n < 2:
return n
if n < 25:
a = fibo(n - 1)
b = fibo(n - 2)
return a + b
x = fork(lambda: fibo(n - 1))
y = fork(lambda: fibo(n - 2))
sync()
return x + y
if __name__ == "__main__":
print(fibo(30)) # should display 832040
\ No newline at end of file
# coding: utf-8
import ast
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Dict
from typing import Optional, Dict, Tuple
class VarKind(Enum):
......@@ -15,7 +16,7 @@ class VarKind(Enum):
@dataclass
class VarDecl:
kind: VarKind
val: Optional[str]
val: Optional[Tuple[str, ast.AST]]
future: bool = False
......@@ -90,7 +91,7 @@ class Scope:
return self.parent.vars
return None
def declare(self, name: str, val: Optional[str] = None, future: bool = False) -> Optional[str]:
def declare(self, name: str, val: Optional[Tuple[str, ast.AST]] = None, future: bool = False) -> Optional[str]:
if self.exists_local(name):
# If the variable already exists in the current function or global scope, we don't need to declare it again.
# This is simply an assignment.
......
......@@ -147,7 +147,11 @@ class BlockVisitor(NodeVisitor):
# Hoist inner variables to the root scope.
for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
yield f"decltype({decl.val}) {var};"
if getattr(decl.val[1], "in_await", False):
# TODO(zdimension): really?
yield f"decltype({decl.val[0][9:]}.operator co_await().await_resume()) {var};"
else:
yield f"decltype({decl.val[0]}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
......@@ -165,7 +169,7 @@ class BlockVisitor(NodeVisitor):
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
if not self.scope.exists_local(name):
yield self.scope.declare(name, " ".join(self.expr().visit(val)) if val else None,
yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
getattr(val, "is_future", False))
yield name
elif isinstance(lvalue, ast.Subscript):
......
......@@ -126,6 +126,7 @@ class ExpressionVisitor(NodeVisitor):
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator:
yield "co_await "
node.in_await = True
elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment