Commit e8ce3013 authored by Tom Niget's avatar Tom Niget

Add extension code emission

parent 28dc8d44
...@@ -176,7 +176,7 @@ else: ...@@ -176,7 +176,7 @@ else:
pydevd.original_excepthook = sys.excepthook pydevd.original_excepthook = sys.excepthook
def transpile(source, name="<module>", path=None): def transpile(source, name: str, path=None):
TB = f"transpiling module {cf.white(name)}" TB = f"transpiling module {cf.white(name)}"
res = ast.parse(source, type_comments=True) res = ast.parse(source, type_comments=True)
...@@ -198,5 +198,5 @@ def transpile(source, name="<module>", path=None): ...@@ -198,5 +198,5 @@ def transpile(source, name="<module>", path=None):
# disp_scope(res.scope) # disp_scope(res.scope)
code = "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(res)))) code = "\n".join(filter(None, map(str, FileVisitor(Scope(), name).visit(res))))
return code return code
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
...@@ -17,7 +17,7 @@ from transpiler.phases.emit_cpp.search import SearchVisitor ...@@ -17,7 +17,7 @@ from transpiler.phases.emit_cpp.search import SearchVisitor
@dataclass @dataclass
class BlockVisitor(NodeVisitor): class BlockVisitor(NodeVisitor):
scope: Scope scope: Scope
generator: CoroutineMode = CoroutineMode.SYNC generator: CoroutineMode = field(default=CoroutineMode.SYNC, kw_only=True)
def expr(self) -> ExpressionVisitor: def expr(self) -> ExpressionVisitor:
return ExpressionVisitor(self.scope, self.generator) return ExpressionVisitor(self.scope, self.generator)
...@@ -60,7 +60,7 @@ class BlockVisitor(NodeVisitor): ...@@ -60,7 +60,7 @@ class BlockVisitor(NodeVisitor):
def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]: def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
for child in body: for child in body:
from transpiler.phases.emit_cpp.function import FunctionVisitor from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, mode) child_visitor = FunctionVisitor(inner_scope, generator=mode)
for name, decl in getattr(child, "decls", {}).items(): for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};" #yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
......
# coding: utf-8 # coding: utf-8
import ast import ast
from dataclasses import dataclass
from typing import Iterable from typing import Iterable
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2 from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2, ModuleVisitorExt
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass
class FileVisitor(BlockVisitor): class FileVisitor(BlockVisitor):
module_name: str
def visit_Module(self, node: ast.Module) -> Iterable[str]: def visit_Module(self, node: ast.Module) -> Iterable[str]:
TB = "emitting C++ code for Python module" TB = "emitting C++ code for Python module"
...@@ -23,7 +27,16 @@ class FileVisitor(BlockVisitor): ...@@ -23,7 +27,16 @@ class FileVisitor(BlockVisitor):
code = [line for stmt in node.body for line in visitor.visit(stmt)] code = [line for stmt in node.body for line in visitor.visit(stmt)]
yield from code yield from code
yield "}" yield "}"
yield "#ifdef TYPON_EXTENSION"
yield f"PYBIND11_MODULE({self.module_name}, m) {{"
yield f"m.doc() = \"Typon extension module '{self.module_name}'\";"
visitor = ModuleVisitorExt(self.scope)
code = [line for stmt in node.body for line in visitor.visit(stmt)]
yield from code
yield "}"
yield "#else"
yield "int main(int argc, char* argv[]) {" yield "int main(int argc, char* argv[]) {"
yield "py_sys::all.argv = typon::PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));" yield "py_sys::all.argv = typon::PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));"
yield "PROGRAMNS::root().call();" yield "PROGRAMNS::root().call();"
yield "}" yield "}"
yield "#endif"
...@@ -89,12 +89,12 @@ class FunctionVisitor(BlockVisitor): ...@@ -89,12 +89,12 @@ class FunctionVisitor(BlockVisitor):
# See the comments in visit_FunctionDef. # See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same # A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope. # variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), self.generator) return FunctionVisitor(self.scope.child_share(), generator=self.generator)
def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]: def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{" yield "{"
for child in items: for child in items:
yield from FunctionVisitor(scope, self.generator).visit(child) yield from FunctionVisitor(scope, generator=self.generator).visit(child)
yield "}" yield "}"
def visit_Break(self, node: ast.Break) -> Iterable[str]: def visit_Break(self, node: ast.Break) -> Iterable[str]:
......
...@@ -110,3 +110,18 @@ class ModuleVisitor2(NodeVisitor): ...@@ -110,3 +110,18 @@ class ModuleVisitor2(NodeVisitor):
def visit_AST(self, node: ast.AST) -> Iterable[str]: def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield "" yield ""
pass pass
@dataclass
class ModuleVisitorExt(NodeVisitor):
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
if getattr(node, "is_main", False):
yield from ()
return
#yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
#yield f'm.def("{node.name}", CoroWrapper(PROGRAMNS::{node.name}));'
yield f'm.def("{node.name}", PROGRAMNS::{node.name});'
def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield from ()
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment