Commit ad38a7ab authored by Tom Niget's avatar Tom Niget

Implement basic scope management for variables

parent 82de5cfa
...@@ -2,9 +2,26 @@ ...@@ -2,9 +2,26 @@
from typon import is_cpp from typon import is_cpp
glob = 5
def f(x):
return x + 1
def fct(param):
loc = 456
global glob
loc = 789
glob = 123
a = 5
b = 6
z = f(a + b) * 2
if __name__ == "__main__": if __name__ == "__main__":
# todo: 0x55 & 7 == 5 # todo: 0x55 & 7 == 5
print("C++ " if is_cpp() else "Python", print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5, "res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
2+3j) 2 + 3j)
print() print()
\ No newline at end of file
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain, zip_longest from itertools import chain, zip_longest
from typing import * from typing import *
...@@ -40,7 +42,7 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: ...@@ -40,7 +42,7 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def transpile(source): def transpile(source):
tree = ast.parse(source) tree = ast.parse(source)
# print(ast.unparse(tree)) # print(ast.unparse(tree))
return "\n".join(filter(None, map(str, TyponVisitor().visit(tree)))) return "\n".join(filter(None, map(str, BlockVisitor(Scope()).visit(tree))))
SYMBOLS = { SYMBOLS = {
...@@ -95,6 +97,13 @@ MAPPINGS = { ...@@ -95,6 +97,13 @@ MAPPINGS = {
"""Mapping of Python builtin constants to C++ equivalents.""" """Mapping of Python builtin constants to C++ equivalents."""
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
GLOBAL = 2
NONLOCAL = 3
class NodeVisitor: class NodeVisitor:
def visit(self, node): def visit(self, node):
"""Visit a node.""" """Visit a node."""
...@@ -115,15 +124,16 @@ class NodeVisitor: ...@@ -115,15 +124,16 @@ class NodeVisitor:
if not node.args: if not node.args:
return "", "()" return "", "()"
f_args = [(arg, f"T{i + 1}") for i, arg in enumerate(node.args)] f_args = [(arg.arg, f"T{i + 1}") for i, arg in enumerate(node.args)]
return ( return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">", "<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {next(self.visit(n))}" for n, t in f_args) + ")" "(" + ", ".join(f"{t} {self.fix_name(n)}" for n, t in f_args) + ")"
) )
def visit_arg(self, node: ast.arg) -> Iterable[str]: def fix_name(self, name: str) -> str:
# TODO: identifiers if name.startswith("__") and name.endswith("__"):
yield node.arg return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
class ExpressionVisitor(NodeVisitor): class ExpressionVisitor(NodeVisitor):
...@@ -148,9 +158,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -148,9 +158,7 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node)) raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> Iterable[str]: def visit_Name(self, node: ast.Name) -> Iterable[str]:
if node.id.startswith("__") and node.id.endswith("__"): yield self.fix_name(node.id)
return f"py_{node.id[2:-2]}"
yield MAPPINGS.get(node.id, node.id)
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# TODO: operator precedence # TODO: operator precedence
...@@ -245,7 +253,22 @@ class ExpressionVisitor(NodeVisitor): ...@@ -245,7 +253,22 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
class TyponVisitor(NodeVisitor): @dataclass
class Scope:
parent: Optional["Scope"] = None
vars: Dict[str, VarKind] = field(default_factory=dict)
def is_global(self):
return self.parent is None
def exists(self, name: str) -> bool:
return name in self.vars or (self.parent is not None and self.parent.exists(name))
class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope):
self._scope = scope
def visit_Module(self, node: ast.Module) -> Iterable[str]: def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST stmt: ast.AST
yield "#include <python/builtins.hpp>" yield "#include <python/builtins.hpp>"
...@@ -259,13 +282,13 @@ class TyponVisitor(NodeVisitor): ...@@ -259,13 +282,13 @@ class TyponVisitor(NodeVisitor):
def visit_Import(self, node: ast.Import) -> Iterable[str]: def visit_Import(self, node: ast.Import) -> Iterable[str]:
for name in node.names: for name in node.names:
if name == "typon": if name == "typon":
yield "// typon import" yield ""
else: else:
raise NotImplementedError(node) raise NotImplementedError(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]: def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon": if node.module == "typon":
yield "// typon import" yield ""
else: else:
raise NotImplementedError(node) raise NotImplementedError(node)
...@@ -277,10 +300,21 @@ class TyponVisitor(NodeVisitor): ...@@ -277,10 +300,21 @@ class TyponVisitor(NodeVisitor):
yield f"auto {node.name}" yield f"auto {node.name}"
yield args yield args
yield "{" yield "{"
inner = BlockVisitor(Scope(self._scope, vars={arg.arg: VarKind.LOCAL for arg in node.args.args}))
for child in node.body: for child in node.body:
yield from self.visit(child) yield from inner.visit(child)
yield "}" yield "}"
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarKind.GLOBAL
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarKind.NONLOCAL
yield ""
def visit_If(self, node: ast.If) -> Iterable[str]: def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body): if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
yield "int main() {" yield "int main() {"
...@@ -324,7 +358,13 @@ class TyponVisitor(NodeVisitor): ...@@ -324,7 +358,13 @@ class TyponVisitor(NodeVisitor):
def visit_lvalue(self, lvalue: ast.expr) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})" yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})"
elif isinstance(lvalue, (ast.Name, ast.Subscript)): elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
if name not in self._scope.vars:
self._scope.vars[name] = name
yield "auto "
yield name
elif isinstance(lvalue, ast.Subscript):
yield from ExpressionVisitor().visit(lvalue) yield from ExpressionVisitor().visit(lvalue)
else: else:
raise NotImplementedError(lvalue) raise NotImplementedError(lvalue)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment