Commit ad38a7ab authored by Tom Niget's avatar Tom Niget

Implement basic scope management for variables

parent 82de5cfa
......@@ -2,9 +2,26 @@
from typon import is_cpp
glob = 5
def f(x):
return x + 1
def fct(param):
loc = 456
global glob
loc = 789
glob = 123
a = 5
b = 6
z = f(a + b) * 2
if __name__ == "__main__":
# todo: 0x55 & 7 == 5
print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
2+3j)
print()
\ No newline at end of file
2 + 3j)
print()
# coding: utf-8
import ast
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain, zip_longest
from typing import *
......@@ -40,7 +42,7 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def transpile(source):
tree = ast.parse(source)
# print(ast.unparse(tree))
return "\n".join(filter(None, map(str, TyponVisitor().visit(tree))))
return "\n".join(filter(None, map(str, BlockVisitor(Scope()).visit(tree))))
SYMBOLS = {
......@@ -95,6 +97,13 @@ MAPPINGS = {
"""Mapping of Python builtin constants to C++ equivalents."""
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
GLOBAL = 2
NONLOCAL = 3
class NodeVisitor:
def visit(self, node):
"""Visit a node."""
......@@ -115,15 +124,16 @@ class NodeVisitor:
if not node.args:
return "", "()"
f_args = [(arg, f"T{i + 1}") for i, arg in enumerate(node.args)]
f_args = [(arg.arg, f"T{i + 1}") for i, arg in enumerate(node.args)]
return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {next(self.visit(n))}" for n, t in f_args) + ")"
"(" + ", ".join(f"{t} {self.fix_name(n)}" for n, t in f_args) + ")"
)
def visit_arg(self, node: ast.arg) -> Iterable[str]:
# TODO: identifiers
yield node.arg
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
class ExpressionVisitor(NodeVisitor):
......@@ -148,9 +158,7 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> Iterable[str]:
if node.id.startswith("__") and node.id.endswith("__"):
return f"py_{node.id[2:-2]}"
yield MAPPINGS.get(node.id, node.id)
yield self.fix_name(node.id)
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# TODO: operator precedence
......@@ -245,7 +253,22 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(node.orelse)
class TyponVisitor(NodeVisitor):
@dataclass
class Scope:
parent: Optional["Scope"] = None
vars: Dict[str, VarKind] = field(default_factory=dict)
def is_global(self):
return self.parent is None
def exists(self, name: str) -> bool:
return name in self.vars or (self.parent is not None and self.parent.exists(name))
class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope):
self._scope = scope
def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST
yield "#include <python/builtins.hpp>"
......@@ -259,13 +282,13 @@ class TyponVisitor(NodeVisitor):
def visit_Import(self, node: ast.Import) -> Iterable[str]:
for name in node.names:
if name == "typon":
yield "// typon import"
yield ""
else:
raise NotImplementedError(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon":
yield "// typon import"
yield ""
else:
raise NotImplementedError(node)
......@@ -277,10 +300,21 @@ class TyponVisitor(NodeVisitor):
yield f"auto {node.name}"
yield args
yield "{"
inner = BlockVisitor(Scope(self._scope, vars={arg.arg: VarKind.LOCAL for arg in node.args.args}))
for child in node.body:
yield from self.visit(child)
yield from inner.visit(child)
yield "}"
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarKind.GLOBAL
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarKind.NONLOCAL
yield ""
def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
yield "int main() {"
......@@ -324,7 +358,13 @@ class TyponVisitor(NodeVisitor):
def visit_lvalue(self, lvalue: ast.expr) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})"
elif isinstance(lvalue, (ast.Name, ast.Subscript)):
elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
if name not in self._scope.vars:
self._scope.vars[name] = name
yield "auto "
yield name
elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue)
else:
raise NotImplementedError(lvalue)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment