Commit 101d8867 authored by Tom Niget's avatar Tom Niget

User generics attempt #4

parent 6cd4878c
...@@ -47,3 +47,19 @@ Note: gcc may require the `-fcoroutines` flag. ...@@ -47,3 +47,19 @@ Note: gcc may require the `-fcoroutines` flag.
`cd` into the `trans` directory, set up your `.env` file (you can copy the `.env.example` file), and run `python3 test_runner.py`. `cd` into the `trans` directory, set up your `.env` file (you can copy the `.env.example` file), and run `python3 test_runner.py`.
If you're getting include errors, make sure you've cloned the runtime submodule. If you forgot, run `git submodule update --init --recursive`. If you're getting include errors, make sure you've cloned the runtime submodule. If you forgot, run `git submodule update --init --recursive`.
## TODO (2023-10-17)
Implement custom protocol definition so that we can define stuff like Iterable and it generates a concept in C++ and a type erased wrapper (like std::function)
If you call a function that takes Iterable => the function is a template so no overhead
Be able to store functions using their real functor type name so no overhead of std::function
If you store a function with no further information it uses our callable type
If you store an object into a protocol type => it uses the protocol type erased wrapper
Implement decorators as template inheriting classes as to allow people to implement their own decorators in the c++ side
Use "static" to store the python interpreter for lazy init of python functions
\ No newline at end of file
# coding: utf-8
Generic: BuiltinFeature["Generic"]
TypeVar: BuiltinFeature["TypeVar"]
\ No newline at end of file
# coding: utf-8
from typing import TypeVar, Generic
from dataclasses import dataclass
T = TypeVar("T")
@dataclass
class Thing():
x: int
if __name__ == "__main__":
a = Thing(1)
\ No newline at end of file
...@@ -166,6 +166,7 @@ class BlockVisitor(NodeVisitor): ...@@ -166,6 +166,7 @@ class BlockVisitor(NodeVisitor):
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
raise NotImplementedError(node) raise NotImplementedError(node)
#if node.value.type
yield from self.visit_lvalue(node.targets[0], node.is_declare) yield from self.visit_lvalue(node.targets[0], node.is_declare)
yield " = " yield " = "
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
......
...@@ -54,7 +54,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -54,7 +54,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL) self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL)
def visit_ImportFrom(self, node: ast.ImportFrom): def visit_ImportFrom(self, node: ast.ImportFrom):
if node.module in {"typing", "__future__"}: if node.module in {"typing2", "__future__"}:
return return
module = self.get_module(node.module) module = self.get_module(node.module)
node.module_obj = module.type node.module_obj = module.type
...@@ -151,26 +151,48 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -151,26 +151,48 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name) class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
ctype = NewUserType
cttype = TypeType(ctype) cttype = TypeType(ctype)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.slice), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if not typevars:
cttype.type_object = cttype.type_object()
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
scope = self.scope.child(ScopeKind.CLASS) scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = ctype scope.obj_type = cttype.type_object
scope.class_ = scope scope.class_ = scope
node.inner_scope = scope node.inner_scope = scope
node.type = ctype node.type = cttype.type_object
visitor = ScoperClassVisitor(scope, cur_class=cttype) visitor = ScoperClassVisitor(scope, cur_class=cttype)
visitor.visit_block(node.body) visitor.visit_block(node.body)
for deco in node.decorator_list: for base in bases_after:
deco = self.expr().visit(deco) base = self.expr().visit(base)
if is_builtin(deco, "dataclass"): if is_builtin(base, "Enum"):
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable()) cttype.type_object.parents.append(TY_INT)
# cttype.methods["__init__"] = init_type for k, m in cttype.type_object.fields.items():
m.type = cttype.type_object
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
cttype.type_object.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node) lnd = linenodata(node)
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
args=ast.arguments( args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]], args=[ast.arg(arg="self"), ast.arg(arg="value")],
defaults=[], defaults=[],
kw_defaults=[], kw_defaults=[],
kwarg=None, kwarg=None,
...@@ -179,10 +201,10 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -179,10 +201,10 @@ class ScoperBlockVisitor(ScoperVisitor):
), ),
body=[ body=[
ast.Assign( ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)], targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")],
value=ast.Name(id=n), value=ast.Name(id="value"),
**lnd **lnd
) for n in ctype.get_members() )
], ],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
...@@ -192,22 +214,19 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -192,22 +214,19 @@ class ScoperBlockVisitor(ScoperVisitor):
_, rtype = visitor.visit_FunctionDef(init_method) _, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype) visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method) node.body.append(init_method)
cttype.type_object.is_enum = True
else: else:
raise NotImplementedError(deco) raise NotImplementedError(base)
for base in node.bases: for deco in node.decorator_list:
base = self.expr().visit(base) deco = self.expr().visit(deco)
if is_builtin(base, "Enum"): if is_builtin(deco, "dataclass"):
ctype.parents.append(TY_INT) # init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
for k, m in ctype.fields.items(): # cttype.methods["__init__"] = init_type
m.type = ctype
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node) lnd = linenodata(node)
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
args=ast.arguments( args=ast.arguments(
args=[ast.arg(arg="self"), ast.arg(arg="value")], args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in cttype.type_object.get_members()]],
defaults=[], defaults=[],
kw_defaults=[], kw_defaults=[],
kwarg=None, kwarg=None,
...@@ -216,10 +235,10 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -216,10 +235,10 @@ class ScoperBlockVisitor(ScoperVisitor):
), ),
body=[ body=[
ast.Assign( ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")], targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id="value"), value=ast.Name(id=n),
**lnd **lnd
) ) for n in cttype.type_object.get_members()
], ],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
...@@ -229,9 +248,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -229,9 +248,8 @@ class ScoperBlockVisitor(ScoperVisitor):
_, rtype = visitor.visit_FunctionDef(init_method) _, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype) visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method) node.body.append(init_method)
ctype.is_enum = True
else: else:
raise NotImplementedError(base) raise NotImplementedError(deco)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......
...@@ -4,10 +4,10 @@ import inspect ...@@ -4,10 +4,10 @@ import inspect
from typing import List from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \ TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE, TY_FLOAT, RuntimeValue TY_SLICE, TY_FLOAT, RuntimeValue, BuiltinFeature
from transpiler.utils import linenodata from transpiler.utils import linenodata
DUNDER = { DUNDER = {
...@@ -106,13 +106,11 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -106,13 +106,11 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func) ftype = self.visit(node.func)
if is_builtin(ftype, "TypeVar"):
return TypeType(TypeVariable(*[ast.literal_eval(arg) for arg in node.args]))
if ftype.typevars: if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars}) ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
from transpiler.exceptions import CompileError from transpiler.exceptions import CompileError
try:
argtypes = [self.visit(arg) for arg in node.args]
except CompileError as e:
pass
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args]) rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype actual = rtype
node.is_await = False node.is_await = False
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment