Commit 90409113 authored by Tom Niget's avatar Tom Niget

work

parent 8fd5e9fb
...@@ -64,3 +64,5 @@ fonction : type concret qui a un type caché qui correspond à ce qui sera gén ...@@ -64,3 +64,5 @@ fonction : type concret qui a un type caché qui correspond à ce qui sera gén
pareil pour les classes pareil pour les classes
donc on peut passer des types génériques puisqu'ils sont wrappés par un type concret donc on peut passer des types génériques puisqu'ils sont wrappés par un type concret
----------------
faire des tags v0.1 v0.2 pour proto 1 et version qui marche
\ No newline at end of file
...@@ -20,3 +20,7 @@ ...@@ -20,3 +20,7 @@
# b = Thing[str]("abc") # b = Thing[str]("abc")
# print(a) # print(a)
# print(b) # print(b)
if __name__ == "__main__":
x = 5
\ No newline at end of file
...@@ -13,8 +13,10 @@ import colorama ...@@ -13,8 +13,10 @@ import colorama
from transpiler.phases.desugar_compare import DesugarCompare from transpiler.phases.desugar_compare import DesugarCompare
from transpiler.phases.desugar_op import DesugarOp from transpiler.phases.desugar_op import DesugarOp
from transpiler.phases.emit_cpp.module import emit_module
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.modules import parse_module from transpiler.phases.typing.modules import parse_module
from transpiler.phases.typing.stdlib import StdlibVisitor
colorama.init() colorama.init()
...@@ -182,19 +184,18 @@ typon_std = Path(__file__).parent.parent / "stdlib" ...@@ -182,19 +184,18 @@ typon_std = Path(__file__).parent.parent / "stdlib"
parse_module("builtins", typon_std, PRELUDE) parse_module("builtins", typon_std, PRELUDE)
def transpile(source, name: str, path=None): def transpile(source, name: str, path: Path):
__TB__ = f"transpiling module {cf.white(name)}" __TB__ = f"transpiling module {cf.white(name)}"
res = ast.parse(source, type_comments=True)
exit()
IfMainVisitor().visit(res)
res = DesugarWith().visit(res)
res = DesugarCompare().visit(res)
res = DesugarOp().visit(res)
#ScoperBlockVisitor().visit(res)
# print(res.scope) def preprocess(node):
IfMainVisitor().visit(node)
node = DesugarWith().visit(node)
node = DesugarCompare().visit(node)
node = DesugarOp().visit(node)
return node
module = parse_module(path.stem, path.parent, preprocess=preprocess)
# display each scope
def disp_scope(scope, indent=0): def disp_scope(scope, indent=0):
debug(" " * indent, scope.kind) debug(" " * indent, scope.kind)
for child in scope.children: for child in scope.children:
...@@ -202,7 +203,29 @@ def transpile(source, name: str, path=None): ...@@ -202,7 +203,29 @@ def transpile(source, name: str, path=None):
for var in scope.vars.items(): for var in scope.vars.items():
debug(" " * (indent + 1), var) debug(" " * (indent + 1), var)
disp_scope(res.scope) def main_module():
yield from emit_module(module)
yield "#ifdef TYPON_EXTENSION"
# yield f"PYBIND11_MODULE({self.module_name}, m) {{"
# yield f"m.doc() = \"Typon extension module '{self.module_name}'\";"
# visitor = ModuleVisitorExt(self.scope)
# code = [line for stmt in node.body for line in visitor.visit(stmt)]
# yield from code
# yield "}"
yield "#else"
yield "typon::Root root() const {"
yield f"co_await dot(PROGRAMNS::{module.name()}, main)();"
yield "}"
yield "int main(int argc, char* argv[]) {"
yield "py_sys::all.argv = typon::PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));"
yield f"root().call();"
yield "}"
yield "#endif"
code = "\n".join(filter(None, main_module()))
return code
exit() exit()
assert isinstance(res, ast.Module) assert isinstance(res, ast.Module)
......
# coding: utf-8
import ast
import enum
from enum import Flag
from itertools import chain
from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType, \
TypeType, TypeOperator, TY_FLOAT, FunctionType, UnionType
from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor:
def visit(self, node):
"""Visit a node."""
__TB__ = f"emitting C++ code for {highlight(node)}"
# __TB_SKIP__ = True
if type(node) == list:
for n in node:
yield from self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node)
break
else:
yield from self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
class NodeVisitor(UniversalVisitor):
def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None):
raise NotImplementedError(node, field)
if not node.args:
return "", "()", []
f_args = [(self.fix_name(arg.arg), f"T{i + 1}") for i, arg in enumerate(node.args)]
return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {n}" for n, t in f_args) + ")",
[n for n, _ in f_args]
)
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
if node is TY_INT:
yield "int"
elif node is TY_FLOAT:
yield "double"
elif node is TY_BOOL:
yield "bool"
elif node is TY_NONE:
yield "void"
elif node is TY_STR:
yield "PyStr"
elif isinstance(node, UserType):
# if node.is_reference:
# yield "PyObj<"
#yield "auto"
yield f"referencemodel::Rc<__main____oo<>::{node.name}__oo<>::Obj>"
# if node.is_reference:
# yield "::py_type>"
elif isinstance(node, TypeType):
yield "auto" # TODO
elif isinstance(node, FunctionType):
yield "std::function<"
yield from self.visit(node.return_type)
yield "("
yield from join(", ", map(self.visit, node.parameters))
yield ")>"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
yield "Task"
elif node.kind == PromiseKind.JOIN:
yield "Join"
elif node.kind == PromiseKind.FUTURE:
yield "Future"
elif node.kind == PromiseKind.FORKED:
yield "Forked"
elif node.kind == PromiseKind.GENERATOR:
yield "Generator"
else:
raise NotImplementedError(node)
yield "<"
yield from self.visit(node.return_type)
yield ">"
elif isinstance(node, TypeVariable):
# yield f"TYPEVAR_{node.name}";return
raise UnresolvedTypeVariableError(node)
elif isinstance(node, UnionType) and (ty := node.is_optional()):
yield "std::optional<"
yield from self.visit(ty)
yield ">"
elif isinstance(node, TypeOperator):
yield "typon::Py" + node.name.title()
if node.args:
yield "<"
yield from join(", ", map(self.visit, node.args))
yield ">"
else:
raise NotImplementedError(node)
class CoroutineMode(Flag):
SYNC = 1
FAKE = 2 | SYNC
ASYNC = 4
GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
JOIN = 32 | ASYNC
class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto()
DEFINITION = enum.auto()
METHOD = enum.auto()
LAMBDA = enum.auto()
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items)
try:
it = next(items)
if type(it) is str:
yield it
else:
yield from it
for item in items:
yield sep
it = item
if type(it) is str:
yield it
else:
yield from it
except StopIteration:
return
def flatmap(f, items):
return chain.from_iterable(map(f, items))
# coding: utf-8 from typing import Iterable
import ast
from typing import Iterable
def emit_class(clazz) -> Iterable[str]:
from dataclasses import dataclass yield f"template <typename _Base0 = referencemodel::object>"
from transpiler.phases.typing.scope import Scope yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{"
from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind yield f"static constexpr std::string_view name = \"{node.name}\";"
inner = ClassInnerVisitor2(node.inner_scope)
class ClassVisitor(NodeVisitor): for stmt in node.body:
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]: yield from inner.visit(stmt)
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items(): yield f"struct Obj : referencemodel::instance<{node.name}__oo<>, Obj> {{"
yield from self.visit_ClassDef(inst)
return inner = ClassInnerVisitor4(node.inner_scope)
for stmt in node.body:
yield f"struct {node.name}_s;" yield from inner.visit(stmt)
yield f"extern {node.name}_s {node.name};"
yield f"struct {node.name}_s {{" yield "template <typename... U>"
yield "Obj(U&&... args) {"
yield "struct py_type {" yield "dot(this, __init__)(this, std::forward<U>(args)...);"
inner = ClassInnerVisitor(node.inner_scope) yield "}"
for stmt in node.body:
yield from inner.visit(stmt)
yield "};"
yield "template<typename... T> py_type(T&&... args) {"
yield "__init__(this, std::forward<T>(args)...);" yield "template <typename... T>"
yield "}" yield "auto operator() (T&&... args) const {"
yield "py_type() {}" yield "return referencemodel::rc(Obj{std::forward<T>(args)...});"
yield "py_type(const py_type&) = delete;" yield "}"
yield "py_type(py_type&&) = delete;"
yield f"}};"
if getattr(node.type, "is_enum", False): yield f"static constexpr {node.name}__oo<> {node.name} {{}};"
yield "int value;" yield f"static_assert(sizeof {node.name} == 1);"
yield "operator int() const { return value; }" \ No newline at end of file
yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}.";'
yield "}"
else:
yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}(";'
for i, (name, memb) in enumerate(node.type.get_members().items()):
if i != 0:
yield 's << ", ";'
yield f's << "{name}=";'
yield f"repr_to({name}, s);"
yield "s << ')';"
yield "}"
yield "void py_print(std::ostream &s) const {"
yield "py_repr(s);"
yield "}"
yield "};"
yield "template<typename... T> auto operator()(T&&... args) {"
yield "return pyobj<py_type>(std::forward<T>(args)...);"
yield "}"
# outer = ClassOuterVisitor(node.inner_scope)
# for stmt in node.body:
# yield from outer.visit(stmt)
yield f"}} {node.name};"
@dataclass
class ClassInnerVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member.type)
yield node.target.id
yield ";"
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
yield "static constexpr"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_Assign(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield "struct {"
# yield "type* self;"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True)
# yield f"}} {node.name} {{ this }};"
yield f"struct {node.name}_m_s : referencemodel::method {{"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
yield f"}} static constexpr {node.name} {{}};"
yield ""
@dataclass
class ClassOuterVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "template<typename Self, typename... T>"
yield "auto operator()(Self self, T&&... args) {"
yield f"return dotp(self, {node.name})(std::forward<T>(args)...);"
yield "}"
yield f"}} {node.name};"
yield ""
# yield "struct : function {"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
# yield f"}} static constexpr {node.name} {{}};"
@dataclass
class ClassInnerVisitor2(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct : referencemodel::method {"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
yield f"}} static constexpr {node.name} {{}};"
@dataclass
class ClassInnerVisitor4(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member.type)
yield node.target.id
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield ""
\ No newline at end of file
import ast
from itertools import chain
from typing import Iterable
from transpiler import highlight
import transpiler.phases.typing.types as types
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType
from transpiler.utils import UnsupportedNodeError
class UniversalVisitor:
def visit(self, node):
"""Visit a node."""
__TB__ = f"emitting C++ code for {highlight(node)}"
# __TB_SKIP__ = True
if type(node) == list:
for n in node:
yield from self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node)
break
else:
yield from self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
class NodeVisitor(UniversalVisitor):
def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None):
raise NotImplementedError(node, field)
if not node.args:
return "", "()", []
f_args = [(self.fix_name(arg.arg), f"T{i + 1}") for i, arg in enumerate(node.args)]
return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {n}" for n, t in f_args) + ")",
[n for n, _ in f_args]
)
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
match node:
case types.TY_INT:
yield "int"
case types.TY_FLOAT:
yield "double"
case types.TY_BOOL:
yield "bool"
case types.TY_NONE:
yield "void"
case types.TY_STR:
yield "PyStr"
case types.TypeVariable(name):
raise UnresolvedTypeVariableError(node)
case _:
raise NotImplementedError(node)
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items)
try:
it = next(items)
if type(it) is str:
yield it
else:
yield from it
for item in items:
yield sep
it = item
if type(it) is str:
yield it
else:
yield from it
except StopIteration:
return
def flatmap(f, items):
return chain.from_iterable(map(f, items))
# coding: utf-8
import ast
import enum
from enum import Flag
from itertools import chain
from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType, \
TypeType, TypeOperator, TY_FLOAT, FunctionType, UnionType
from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor:
def visit(self, node):
"""Visit a node."""
__TB__ = f"emitting C++ code for {highlight(node)}"
# __TB_SKIP__ = True
if type(node) == list:
for n in node:
yield from self.visit(n)
else:
for parent in node.__class__.__mro__:
if visitor := getattr(self, 'visit_' + parent.__name__, None):
yield from visitor(node)
break
else:
yield from self.missing_impl(node)
def missing_impl(self, node):
raise UnsupportedNodeError(node)
class NodeVisitor(UniversalVisitor):
def process_args(self, node: ast.arguments) -> (str, str, str):
for field in ("posonlyargs", "vararg", "kwonlyargs", "kw_defaults", "kwarg", "defaults"):
if getattr(node, field, None):
raise NotImplementedError(node, field)
if not node.args:
return "", "()", []
f_args = [(self.fix_name(arg.arg), f"T{i + 1}") for i, arg in enumerate(node.args)]
return (
"<" + ", ".join(f"typename {t}" for _, t in f_args) + ">",
"(" + ", ".join(f"{t} {n}" for n, t in f_args) + ")",
[n for n, _ in f_args]
)
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
if node is TY_INT:
yield "int"
elif node is TY_FLOAT:
yield "double"
elif node is TY_BOOL:
yield "bool"
elif node is TY_NONE:
yield "void"
elif node is TY_STR:
yield "PyStr"
elif isinstance(node, UserType):
# if node.is_reference:
# yield "PyObj<"
#yield "auto"
yield f"referencemodel::Rc<__main____oo<>::{node.name}__oo<>::Obj>"
# if node.is_reference:
# yield "::py_type>"
elif isinstance(node, TypeType):
yield "auto" # TODO
elif isinstance(node, FunctionType):
yield "std::function<"
yield from self.visit(node.return_type)
yield "("
yield from join(", ", map(self.visit, node.parameters))
yield ")>"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
yield "Task"
elif node.kind == PromiseKind.JOIN:
yield "Join"
elif node.kind == PromiseKind.FUTURE:
yield "Future"
elif node.kind == PromiseKind.FORKED:
yield "Forked"
elif node.kind == PromiseKind.GENERATOR:
yield "Generator"
else:
raise NotImplementedError(node)
yield "<"
yield from self.visit(node.return_type)
yield ">"
elif isinstance(node, TypeVariable):
# yield f"TYPEVAR_{node.name}";return
raise UnresolvedTypeVariableError(node)
elif isinstance(node, UnionType) and (ty := node.is_optional()):
yield "std::optional<"
yield from self.visit(ty)
yield ">"
elif isinstance(node, TypeOperator):
yield "typon::Py" + node.name.title()
if node.args:
yield "<"
yield from join(", ", map(self.visit, node.args))
yield ">"
else:
raise NotImplementedError(node)
class CoroutineMode(Flag):
SYNC = 1
FAKE = 2 | SYNC
ASYNC = 4
GENERATOR = 8 | ASYNC
TASK = 16 | ASYNC
JOIN = 32 | ASYNC
class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto()
DEFINITION = enum.auto()
METHOD = enum.auto()
LAMBDA = enum.auto()
...@@ -5,7 +5,7 @@ from typing import Iterable, Optional ...@@ -5,7 +5,7 @@ from typing import Iterable, Optional
from transpiler.phases.typing.common import is_builtin from transpiler.phases.typing.common import is_builtin
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise, TypeType from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable
from transpiler.utils import compare_ast from transpiler.utils import compare_ast
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind
from transpiler.phases.emit_cpp.expr import ExpressionVisitor from transpiler.phases.emit_cpp.expr import ExpressionVisitor
......
# coding: utf-8
import ast
from typing import Iterable
from dataclasses import dataclass
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind
class ClassVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"struct {node.name}_s;"
yield f"extern {node.name}_s {node.name};"
yield f"struct {node.name}_s {{"
yield "struct py_type {"
inner = ClassInnerVisitor(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield "template<typename... T> py_type(T&&... args) {"
yield "__init__(this, std::forward<T>(args)...);"
yield "}"
yield "py_type() {}"
yield "py_type(const py_type&) = delete;"
yield "py_type(py_type&&) = delete;"
if getattr(node.type, "is_enum", False):
yield "int value;"
yield "operator int() const { return value; }"
yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}.";'
yield "}"
else:
yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}(";'
for i, (name, memb) in enumerate(node.type.get_members().items()):
if i != 0:
yield 's << ", ";'
yield f's << "{name}=";'
yield f"repr_to({name}, s);"
yield "s << ')';"
yield "}"
yield "void py_print(std::ostream &s) const {"
yield "py_repr(s);"
yield "}"
yield "};"
yield "template<typename... T> auto operator()(T&&... args) {"
yield "return pyobj<py_type>(std::forward<T>(args)...);"
yield "}"
# outer = ClassOuterVisitor(node.inner_scope)
# for stmt in node.body:
# yield from outer.visit(stmt)
yield f"}} {node.name};"
@dataclass
class ClassInnerVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member.type)
yield node.target.id
yield ";"
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
yield "static constexpr"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_Assign(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield "struct {"
# yield "type* self;"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True)
# yield f"}} {node.name} {{ this }};"
yield f"struct {node.name}_m_s : referencemodel::method {{"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
yield f"}} static constexpr {node.name} {{}};"
yield ""
@dataclass
class ClassOuterVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {"
yield "template<typename Self, typename... T>"
yield "auto operator()(Self self, T&&... args) {"
yield f"return dotp(self, {node.name})(std::forward<T>(args)...);"
yield "}"
yield f"}} {node.name};"
yield ""
# yield "struct : function {"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
# yield f"}} static constexpr {node.name} {{}};"
@dataclass
class ClassInnerVisitor2(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
yield ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct : referencemodel::method {"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
yield f"}} static constexpr {node.name} {{}};"
@dataclass
class ClassInnerVisitor4(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member.type)
yield node.target.id
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield ""
\ No newline at end of file
...@@ -219,30 +219,12 @@ class ExpressionVisitor(NodeVisitor): ...@@ -219,30 +219,12 @@ class ExpressionVisitor(NodeVisitor):
yield ")" yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]: def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
use_dot = None yield "dot"
if type(node.value.type) == TypeType: yield "(("
use_dot = "dots" yield from self.visit(node.value)
elif isinstance(node.type, FunctionType) and node.type.is_method and not isinstance(node.value.type, Promise): yield "), "
if node.value.type.resolve().is_reference: yield self.fix_name(node.attr)
use_dot = "dotp" yield ")"
else:
use_dot = "dot"
if use_dot:
use_dot = "dot"
if use_dot:
yield use_dot
yield "(("
yield from self.visit(node.value)
yield "), "
yield self.fix_name(node.attr)
yield ")"
else:
yield from self.prec(".").visit(node.value)
if node.value.type.resolve().is_reference:
yield "->"
else:
yield "."
yield self.fix_name(node.attr)
def visit_List(self, node: ast.List) -> Iterable[str]: def visit_List(self, node: ast.List) -> Iterable[str]:
if node.elts: if node.elts:
......
# coding: utf-8
import ast
from dataclasses import dataclass
from typing import Iterable
from transpiler.phases.emit_cpp.consts import SYMBOLS
from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.typing.scope import Scope
from transpiler.phases.utils import PlainBlock
# noinspection PyPep8Naming
@dataclass
class FunctionVisitor(BlockVisitor):
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from self.expr().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target, False)
yield SYMBOLS[type(node.op)] + "="
yield from self.expr().visit(node.value)
yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
if node.orelse:
yield "auto"
yield node.orelse_variable
yield "= true;"
yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.inner_scope, node.body) # TODO: why not reuse the scope used for analysis? same in while
if node.orelse:
yield "if ("
yield node.orelse_variable
yield ")"
yield from self.emit_block(node.inner_scope, node.orelse)
def visit_If(self, node: ast.If) -> Iterable[str]:
yield "if ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
yield "else "
if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse)
else:
yield from self.emit_block(node.orelse_scope, node.orelse)
def visit_PlainBlock(self, node: PlainBlock) -> Iterable[str]:
yield from self.emit_block(node.inner_scope, node.body)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
if CoroutineMode.ASYNC in self.generator:
yield "co_return "
else:
yield "return "
if node.value:
yield from self.expr().visit(node.value)
yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]:
if node.orelse:
yield "auto"
yield node.orelse_variable
yield "= true;"
yield "while ("
yield from self.expr().visit(node.test)
yield ")"
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
yield "if ("
yield node.orelse_variable
yield ")"
yield from self.emit_block(node.inner_scope, node.orelse)
def visit_Global(self, node: ast.Global) -> Iterable[str]:
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
yield ""
def block2(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope.
return FunctionVisitor(self.scope.child_share(), generator=self.generator)
def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{"
for child in items:
yield from FunctionVisitor(scope, generator=self.generator).visit(child)
yield "}"
def visit_Break(self, node: ast.Break) -> Iterable[str]:
if (loop := self.scope.is_in_loop()).orelse:
yield loop.orelse_variable
yield " = false;"
yield "break;"
def visit_Try(self, node: ast.Try) -> Iterable[str]:
yield from self.emit_block(node.inner_scope, node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
if node.finalbody:
raise NotImplementedError(node, "finalbody")
for handler in node.handlers:
#yield from self.visit(handler)
pass
# todo
def visit_Raise(self, node: ast.Raise) -> Iterable[str]:
yield "// raise"
# TODO
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "auto"
yield self.fix_name(node.name)
yield "="
yield from self.visit_func_new(node, FunctionEmissionKind.LAMBDA)
yield ";"
# coding: utf-8
import ast
from typing import Iterable
from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind, NodeVisitor, join
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor, ClassInnerVisitor, ClassInnerVisitor2, ClassInnerVisitor4
from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast, highlight
IGNORED_IMPORTS = {"typon", "typing", "__future__", "dataclasses", "enum"}
# noinspection PyPep8Naming
@dataclass
class ModuleVisitor(BlockVisitor):
includes: list[str] = field(default_factory=list)
def visit_Import(self, node: ast.Import) -> Iterable[str]:
__TB__ = f"emitting C++ code for {highlight(node)}"
for alias in node.names:
concrete = self.fix_name(alias.asname or alias.name)
if alias.module_obj.is_python:
yield f"namespace py_{concrete} {{"
yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.fields.items():
ty = obj.type.resolve()
if getattr(ty, "python_func_used", False):
yield from self.emit_python_func(alias.name, name, name, ty)
yield "} all;"
yield f"auto& get_all() {{ return all; }}"
yield "}"
yield f'auto& {concrete} = py_{concrete}::get_all();'
elif alias.name in IGNORED_IMPORTS:
yield ""
else:
yield from self.import_module(alias.name)
yield f'auto& {concrete} = py_{alias.name}::get_all();'
def import_module(self, name: str) -> Iterable[str]:
self.includes.append(f'#include <python/{name}.hpp>')
yield ""
def emit_python_func(self, mod: str, name: str, alias: str, fty: FunctionType) -> Iterable[str]:
__TB__ = f"emitting C++ code for Python function {highlight(f'{mod}.{name}')}"
yield "struct {"
yield f"auto operator()("
for i, argty in enumerate(fty.parameters):
if i != 0:
yield ", "
yield "lvalue_or_rvalue<"
yield from self.visit(argty)
yield f"> arg{i}"
yield ") {"
yield "InterpGuard guard{};"
yield "try {"
yield f"return py::module_::import(\"{mod}\").attr(\"{name}\")("
for i, argty in enumerate(fty.parameters):
if i != 0:
yield ", "
yield f"*arg{i}"
yield ").cast<"
yield from self.visit(fty.return_type)
yield ">();"
yield "} catch (py::error_already_set& e) {"
yield 'std::cerr << "Python exception: " << e.what() << std::endl;'
yield "throw;"
yield "}"
yield "}"
yield f"}} {alias};"
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module in IGNORED_IMPORTS:
yield ""
elif node.module_obj.is_python:
for alias in node.names:
fty = alias.item_obj.resolve()
#assert isinstance(fty, FunctionType)
yield from self.emit_python_func(node.module, alias.name, alias.asname or alias.name, fty)
else:
yield from self.import_module(node.module)
for alias in node.names:
yield f"auto& {alias.asname or alias.name} = py_{node.module}::get_all().{alias.name};"
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
if isinstance(node.value, ast.Str):
if "\n" in node.value.s:
yield f"/*{node.value.s}*/"
else:
yield f"//{node.value.s}"
else:
raise NotImplementedError(node)
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
#yield from ClassVisitor().visit(node)
yield from ()
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from super().visit_free_func(node, FunctionEmissionKind.DECLARATION)
@dataclass
class ModuleVisitor2(NodeVisitor):
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield ""
pass
@dataclass
class ModuleVisitor3(NodeVisitor):
scope: Scope
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield from ()
return
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"static constexpr _detail_<__main__>::{node.name}<> {node.name} {{}};"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
@dataclass
class ModuleVisitor4(NodeVisitor):
scope: Scope
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"template <typename _Base0 = referencemodel::object>"
yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{"
yield f"static constexpr std::string_view name = \"{node.name}\";"
inner = ClassInnerVisitor2(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield f"struct Obj : referencemodel::instance<{node.name}__oo<>, Obj> {{"
inner = ClassInnerVisitor4(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield "template <typename... U>"
yield "Obj(U&&... args) {"
yield "dot(this, __init__)(this, std::forward<U>(args)...);"
yield "}"
yield "};"
yield "template <typename... T>"
yield "auto operator() (T&&... args) const {"
yield "return referencemodel::rc(Obj{std::forward<T>(args)...});"
yield "}"
yield f"}};"
yield f"static constexpr {node.name}__oo<> {node.name} {{}};"
yield f"static_assert(sizeof {node.name} == 1);"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from ()
@dataclass
class ModuleVisitorExt(NodeVisitor):
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
if getattr(node, "is_main", False):
yield from ()
return
#yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
#yield f'm.def("{node.name}", CoroWrapper(PROGRAMNS::{node.name}));'
yield f'm.def("{node.name}", PROGRAMNS::{node.name});'
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
if gen_instances := getattr(node, "gen_instances", None):
for args, inst in gen_instances.items():
yield from self.visit_ClassDef(inst)
return
yield f"py::class_<PROGRAMNS::{node.name}_s::py_type>(m, \"{node.name}\")"
if init := node.type.fields.get("__init__", None):
init = init.type.resolve().remove_self()
init_params = init.parameters
yield ".def(py::init<"
yield from join(", ", map(self.visit, init_params))
yield ">())"
yield f'.def("__repr__", [](const PROGRAMNS::{node.name}_s::py_type& self)'
yield "{ return repr(self); })"
for f, v in node.type.fields.items():
if f == "__init__":
continue
if isinstance(v.type, FunctionType):
meth = v.type.remove_self()
yield f'.def("{f}", [](const PROGRAMNS::{node.name}_s::py_type& a'
if meth.parameters:
yield ","
vis = BlockVisitor(node.scope)
yield from vis.visit_func_params(((f"arg{i}", ty, None) for i, ty in enumerate(meth.parameters)), FunctionEmissionKind.LAMBDA)
yield f') {{ return dotp(&a, {f})('
if meth.parameters:
yield from join(", ", (f"arg{i}" for i, _ in enumerate(meth.parameters)))
yield ').call(); })'
else:
yield f'.def_readwrite("{f}", &PROGRAMNS::{node.name}_s::py_type::{f})'
yield ";"
pass
def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield from ()
pass
...@@ -5,18 +5,21 @@ from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind ...@@ -5,18 +5,21 @@ from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin
def make_module(name: str, scope: Scope) -> ResolvedConcreteType: class ModuleType(UniqueTypeMixin, ResolvedConcreteType):
class CreatedType(UniqueTypeMixin, ResolvedConcreteType): pass
def make_module(name: str, scope: Scope) -> ModuleType:
class CreatedType(ModuleType):
def name(self): def name(self):
return name return name
ty = CreatedType() ty = CreatedType()
for n, v in scope.vars.items(): for n, v in scope.vars.items():
ty.fields[n] = MemberDef(v.type, v.val, False) ty.fields[n] = MemberDef(v.type, v.val, v.is_item_decl)
return ty return ty
visited_modules = {} visited_modules = {}
def parse_module(mod_name: str, python_path: Path, scope=None): def parse_module(mod_name: str, python_path: Path, scope=None, preprocess=None):
path = python_path / mod_name path = python_path / mod_name
if not path.exists(): if not path.exists():
...@@ -41,12 +44,15 @@ def parse_module(mod_name: str, python_path: Path, scope=None): ...@@ -41,12 +44,15 @@ def parse_module(mod_name: str, python_path: Path, scope=None):
if real_path.suffix == ".py": if real_path.suffix == ".py":
from transpiler.phases.typing.stdlib import StdlibVisitor from transpiler.phases.typing.stdlib import StdlibVisitor
StdlibVisitor(python_path, mod_scope).visit(ast.parse(real_path.read_text())) node = ast.parse(real_path.read_text())
if preprocess:
node = preprocess(node)
StdlibVisitor(python_path, mod_scope).visit(node)
else: else:
raise NotImplementedError(f"Unsupported file type {path.suffix}") raise NotImplementedError(f"Unsupported file type {path.suffix}")
mod = make_mod_decl(mod_name, mod_scope) mod = make_module(mod_name, mod_scope)
visited_modules[real_path.as_posix()] = mod visited_modules[real_path.as_posix()] = VarDecl(VarKind.LOCAL, mod, {k: v.type for k, v in mod_scope.vars.items()})
return mod return mod
# def process_module(mod_path: Path, scope): # def process_module(mod_path: Path, scope):
...@@ -82,7 +88,3 @@ def parse_module(mod_name: str, python_path: Path, scope=None): ...@@ -82,7 +88,3 @@ def parse_module(mod_name: str, python_path: Path, scope=None):
# child_mod = process_module(child, scope) # child_mod = process_module(child, scope)
# visited_modules[child.as_posix()] = child_mod # visited_modules[child.as_posix()] = child_mod
# return mod # return mod
def make_mod_decl(child, mod_scope):
return VarDecl(VarKind.LOCAL, make_module(child, mod_scope), {k: v.type for k, v in mod_scope.vars.items()})
\ No newline at end of file
...@@ -28,6 +28,7 @@ class VarDecl: ...@@ -28,6 +28,7 @@ class VarDecl:
kind: VarKind kind: VarKind
type: BaseType type: BaseType
val: Any = RuntimeValue() val: Any = RuntimeValue()
is_item_decl: bool = False
class ScopeKind(Enum): class ScopeKind(Enum):
......
...@@ -15,7 +15,7 @@ from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind ...@@ -15,7 +15,7 @@ from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, BuiltinGenericType, BuiltinType, create_builtin_generic_type, \ from transpiler.phases.typing.types import BaseType, BuiltinGenericType, BuiltinType, create_builtin_generic_type, \
create_builtin_type, ConcreteType, GenericInstanceType, TypeListType, TypeTupleType, GenericParameter, \ create_builtin_type, ConcreteType, GenericInstanceType, TypeListType, TypeTupleType, GenericParameter, \
GenericParameterKind, TypeVariable, ResolvedConcreteType, MemberDef, ClassTypeType, CallableInstanceType, \ GenericParameterKind, TypeVariable, ResolvedConcreteType, MemberDef, ClassTypeType, CallableInstanceType, \
MethodType, UniqueTypeMixin, GenericType MethodType, UniqueTypeMixin, GenericType, BlockData
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
def visit_generic_item( def visit_generic_item(
...@@ -79,7 +79,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -79,7 +79,7 @@ class StdlibVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE) scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[ResolvedConcreteType] = None cur_class: Optional[ResolvedConcreteType] = None
def resolve_module_import(self, name: str) -> VarDecl: def resolve_module_import(self, name: str):
# tries = [ # tries = [
# self.python_path.parent / f"{name}.py", # self.python_path.parent / f"{name}.py",
# self.python_path.parent / name # self.python_path.parent / name
...@@ -109,19 +109,19 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -109,19 +109,19 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_ImportFrom(self, node: ast.ImportFrom): def visit_ImportFrom(self, node: ast.ImportFrom):
module = self.resolve_module_import(node.module) module = self.resolve_module_import(node.module)
node.module_obj = module.type node.module_obj = module
for alias in node.names: for alias in node.names:
thing = module.val.get(alias.name) thing = module.fields.get(alias.name)
if not thing: if not thing:
from transpiler.phases.typing.exceptions import UnknownModuleMemberError from transpiler.phases.typing.exceptions import UnknownModuleMemberError
raise UnknownModuleMemberError(node.module, alias.name) raise UnknownModuleMemberError(node.module, alias.name)
alias.item_obj = thing alias.item_obj = thing.type
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing.type)
def visit_Import(self, node: ast.Import): def visit_Import(self, node: ast.Import):
for alias in node.names: for alias in node.names:
mod = self.resolve_module_import(alias.name) mod = self.resolve_module_import(alias.name)
alias.module_obj = mod.type alias.module_obj = mod
self.scope.vars[alias.asname or alias.name] = mod self.scope.vars[alias.asname or alias.name] = mod
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
...@@ -131,11 +131,12 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -131,11 +131,12 @@ class StdlibVisitor(NodeVisitorSeq):
else: else:
base_class = create_builtin_generic_type if node.type_params else create_builtin_type base_class = create_builtin_generic_type if node.type_params else create_builtin_type
NewType = base_class(node.name) NewType = base_class(node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType.type_type()) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType.type_type(), is_item_decl=True)
def visit_nongeneric(scope: Scope, output: ResolvedConcreteType): def visit_nongeneric(scope: Scope, output: ResolvedConcreteType):
cl_scope = scope.child(ScopeKind.CLASS) cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type()) cl_scope.declare_local("Self", output.type_type())
output.block_data = BlockData(node, scope)
visitor = StdlibVisitor(self.python_path, cl_scope, output) visitor = StdlibVisitor(self.python_path, cl_scope, output)
bases = [self.anno().visit(base) for base in node.bases] bases = [self.anno().visit(base) for base in node.bases]
match bases: match bases:
...@@ -157,6 +158,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -157,6 +158,7 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_nongeneric(scope, output: CallableInstanceType): def visit_nongeneric(scope, output: CallableInstanceType):
scope = scope.child(ScopeKind.FUNCTION) scope = scope.child(ScopeKind.FUNCTION)
arg_visitor = TypeAnnotationVisitor(scope) arg_visitor = TypeAnnotationVisitor(scope)
output.block_data = BlockData(node, scope)
output.parameters = [arg_visitor.visit(arg.annotation) for arg in node.args.args] output.parameters = [arg_visitor.visit(arg.annotation) for arg in node.args.args]
output.return_type = arg_visitor.visit(node.returns) output.return_type = arg_visitor.visit(node.returns)
output.optional_at = len(node.args.args) - len(node.args.defaults) output.optional_at = len(node.args.args) - len(node.args.defaults)
...@@ -198,6 +200,10 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -198,6 +200,10 @@ class StdlibVisitor(NodeVisitorSeq):
node.type_params.append(ast.TypeVar(arg_name, arg.annotation)) node.type_params.append(ast.TypeVar(arg_name, arg.annotation))
arg.annotation = ast.Name(arg_name, ast.Load()) arg.annotation = ast.Name(arg_name, ast.Load())
if node.returns is None:
node.returns = ast.Name("AutoVar$return", ast.Load())
node.type_params.append(ast.TypeVar("AutoVar$return", None))
# if self.cur_class is not None: # if self.cur_class is not None:
# node.type_params.append(ast.TypeVar("Self", None)) # node.type_params.append(ast.TypeVar("Self", None))
...@@ -232,9 +238,9 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -232,9 +238,9 @@ class StdlibVisitor(NodeVisitorSeq):
NewType = base_class() NewType = base_class()
FuncType.__name__ = NewType.name() FuncType.__name__ = NewType.name()
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType, is_item_decl=True)
if self.cur_class is not None: if self.cur_class is not None:
self.cur_class.fields[node.name] = MemberDef(NewType, node) self.cur_class.fields[node.name] = MemberDef(NewType, node, in_class_def=True)
visit_generic_item(visit_nongeneric, node, NewType, self.scope, InstanceType, True) visit_generic_item(visit_nongeneric, node, NewType, self.scope, InstanceType, True)
......
...@@ -287,7 +287,9 @@ class GenericParameterKind(enum.Enum): ...@@ -287,7 +287,9 @@ class GenericParameterKind(enum.Enum):
@dataclass @dataclass
class GenericParameter: class GenericParameter:
name: str name: str
kind: GenericParameterKind kind: GenericParameterKind = GenericParameterKind.NORMAL
gparam = GenericParameter
@dataclass @dataclass
class GenericConstraint: class GenericConstraint:
...@@ -589,6 +591,8 @@ class ClassTypeType(GenericInstanceType): ...@@ -589,6 +591,8 @@ class ClassTypeType(GenericInstanceType):
inner_type: BaseType inner_type: BaseType
class ClassType(UniqueTypeMixin, GenericType): class ClassType(UniqueTypeMixin, GenericType):
parameters = [gparam("T")]
def name(self): def name(self):
return "Type" return "Type"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment