Commit 95b8479f authored by Tom Niget's avatar Tom Niget

Few things work

parent ef37573d
from typing import Self, Generic, Protocol, Optional
assert 5
class object:
def __eq__(self, other: Self) -> bool: ...
def __ne__(self, other: Self) -> bool: ...
class int:
def __add__(self, other: Self) -> Self: ...
def __sub__(self, other: Self) -> Self: ...
def __mul__(self, other: Self) -> Self: ...
def __and__(self, other: Self) -> Self: ...
def __neg__(self) -> Self: ...
def __init__(self, x: object) -> None: ...
def __lt__(self, other: Self) -> bool: ...
def __gt__(self, other: Self) -> bool: ...
def __mod__(self, other: Self) -> Self: ...
def __ge__(self, other: Self) -> bool: ...
class float:
def __init__(self, x: object) -> None: ...
assert int.__add__
assert (5).__add__
class slice:
pass
class HasLen(Protocol):
def __len__(self) -> int: ...
def len(x: HasLen) -> int:
...
class Iterator[U](Protocol):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
class Iterable[U](Protocol):
def __iter__(self) -> Iterator[U]: ...
class str:
def find(self, sub: Self) -> int: ...
def format(self, *args) -> Self: ...
def encode(self, encoding: Self) -> bytes: ...
def __len__(self) -> int: ...
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
def startswith(self, prefix: Self) -> bool: ...
def __getitem__(self, item: int | slice) -> Self: ...
def isspace(self) -> bool: ...
def __contains__(self, item: Self) -> bool: ...
assert len("a")
class bytes:
def decode(self, encoding: str) -> str: ...
def __len__(self) -> int: ...
class list[U]:
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
def __getitem__(self, index: int) -> U: ...
def __setitem__(self, index: int, value: U) -> None: ...
def pop(self, index: int = -1) -> U: ...
def __iter__(self) -> Iterator[U]: ...
def __len__(self) -> int: ...
def append(self, value: U) -> None: ...
def __contains__(self, item: U) -> bool: ...
def __init__(self, it: Iterator[U]) -> None: ...
assert [1, 2].__iter__()
assert list[int].__iter__
class dict[U, V]:
def __getitem__(self, key: U) -> V: ...
def __setitem__(self, key: U, value: V) -> None: ...
def __len__(self) -> int: ...
assert(len(["a"]))
#assert list.__getitem__
assert [].__getitem__
assert [4].__getitem__
assert [1, 2, 3][1]
def iter[U](x: Iterable[U]) -> Iterator[U]:
...
assert iter
def next[U](it: Iterator[U], default: Optional[U] = None) -> U:
...
# what happens with multiple functions
assert iter(["1", "2"])
def identity[U](x: U) -> U:
...
assert identity(1)
assert identity("a")
def identity_2[U, V](x: U, y: V) -> tuple[U, V]:
...
assert list.__add__
assert list.__add__([5], [[6][0]])
assert list[int].__add__
assert identity_2(1, "a")
assert lambda x, y: identity_2(x, y)
assert lambda x: identity_2(x, x)
def print(*args) -> None: ...
assert print
def input(prompt: str = "") -> str:
...
def range(*args) -> Iterator[int]: ...
assert [].__add__
assert [6].__add__
assert [True].__add__
assert lambda x: [x].__add__
assert next(range(6), None)
class file:
def read(self, size: int=0) -> Task[str]: ...
def close(self) -> Task[None]: ...
def __enter__(self) -> Self: ...
def __exit__(self) -> Task[bool]: ...
def open(filename: str, mode: str) -> Task[file]: ...
def __test_opt(x: int, y: int = 5) -> int:
...
assert __test_opt
assert __test_opt(5)
assert __test_opt(5, 6)
assert not __test_opt(5, 6, 7)
assert not __test_opt()
class __test_type:
def __init__(self) -> None: ...
def test_opt(self, x: int, y: int = 5) -> int:
...
assert __test_type().test_opt(5)
assert __test_type().test_opt(5, 6)
assert not __test_type().test_opt(5, 6, 7)
assert not __test_type().test_opt()
def exit(code: int | None = None) -> None: ...
class Exception:
def __init__(self, message: str) -> None: ...
\ No newline at end of file
from typing import Self, Protocol, Optional
assert 5
class object:
def __eq__(self, other: Self) -> bool: ...
def __ne__(self, other: Self) -> bool: ...
class int:
def __add__(self, other: Self) -> Self: ...
def __sub__(self, other: Self) -> Self: ...
def __mul__(self, other: Self) -> Self: ...
def __and__(self, other: Self) -> Self: ...
def __neg__(self) -> Self: ...
def __init__(self, x: object) -> None: ...
def __lt__(self, other: Self) -> bool: ...
def __gt__(self, other: Self) -> bool: ...
def __mod__(self, other: Self) -> Self: ...
def __ge__(self, other: Self) -> bool: ...
class float:
def __init__(self, x: object) -> None: ...
assert int.__add__
assert (5).__add__
class slice:
pass
class HasLen(Protocol):
def __len__(self) -> int: ...
def len(x: HasLen) -> int:
...
class Iterator[U](Protocol):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
class Iterable[U](Protocol):
def __iter__(self) -> Iterator[U]: ...
class str:
def find(self, sub: Self) -> int: ...
def format(self, *args) -> Self: ...
def encode(self, encoding: Self) -> bytes: ...
def __len__(self) -> int: ...
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
def startswith(self, prefix: Self) -> bool: ...
def __getitem__(self, item: int | slice) -> Self: ...
def isspace(self) -> bool: ...
def __contains__(self, item: Self) -> bool: ...
assert len("a")
class bytes:
def decode(self, encoding: str) -> str: ...
def __len__(self) -> int: ...
class list[U]:
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
def __getitem__(self, index: int) -> U: ...
def __setitem__(self, index: int, value: U) -> None: ...
def pop(self, index: int = -1) -> U: ...
def __iter__(self) -> Iterator[U]: ...
def __len__(self) -> int: ...
def append(self, value: U) -> None: ...
def __contains__(self, item: U) -> bool: ...
def __init__(self, it: Iterator[U]) -> None: ...
assert [1, 2].__iter__()
assert list[int].__iter__
class dict[U, V]:
def __getitem__(self, key: U) -> V: ...
def __setitem__(self, key: U, value: V) -> None: ...
def __len__(self) -> int: ...
assert(len(["a"]))
#assert list.__getitem__
assert [].__getitem__
assert [4].__getitem__
assert [1, 2, 3][1]
def iter[U](x: Iterable[U]) -> Iterator[U]:
...
assert iter
def next[U](it: Iterator[U], default: Optional[U] = None) -> U:
...
# what happens with multiple functions
assert iter(["1", "2"])
def identity[U](x: U) -> U:
...
assert identity(1)
assert identity("a")
def identity_2[U, V](x: U, y: V) -> tuple[U, V]:
...
assert list.__add__
assert list.__add__([5], [[6][0]])
assert list[int].__add__
assert identity_2(1, "a")
assert lambda x, y: identity_2(x, y)
assert lambda x: identity_2(x, x)
def print(*args) -> None: ...
assert print
def input(prompt: str = "") -> str:
...
def range(*args) -> Iterator[int]: ...
assert [].__add__
assert [6].__add__
assert [True].__add__
assert lambda x: [x].__add__
assert next(range(6), None)
class file:
def read(self, size: int=0) -> Task[str]: ...
def close(self) -> Task[None]: ...
def __enter__(self) -> Self: ...
def __exit__(self) -> Task[bool]: ...
def open(filename: str, mode: str) -> Task[file]: ...
def __test_opt(x: int, y: int = 5) -> int:
...
assert __test_opt
assert __test_opt(5)
assert __test_opt(5, 6)
assert not __test_opt(5, 6, 7)
assert not __test_opt()
class __test_type:
def __init__(self) -> None: ...
def test_opt(self, x: int, y: int = 5) -> int:
...
assert __test_type().test_opt(5)
assert __test_type().test_opt(5, 6)
assert not __test_type().test_opt(5, 6, 7)
assert not __test_type().test_opt()
def exit(code: int | None = None) -> None: ...
class Exception:
def __init__(self, message: str) -> None: ...
\ No newline at end of file
# coding: utf-8
Protocol = BuiltinFeature["Protocol"]
Self = BuiltinFeature["Self"]
Optional = BuiltinFeature["Optional"]
\ No newline at end of file
......@@ -6,12 +6,15 @@ import inspect
import os
import traceback
from pathlib import Path
#os.environ["TERM"] = "xterm-256"
import colorama
from transpiler.phases.desugar_compare import DesugarCompare
from transpiler.phases.desugar_op import DesugarOp
from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.modules import parse_module
colorama.init()
......@@ -20,10 +23,9 @@ from transpiler.phases.desugar_with import DesugarWith
#from transpiler.phases.emit_cpp.file import FileVisitor
from transpiler.phases.if_main import IfMainVisitor
#from transpiler.phases.typing.block import ScoperBlockVisitor
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.scope import Scope, ScopeKind
from transpiler.utils import highlight
from itertools import islice
import sys
import colorful as cf
......@@ -43,7 +45,7 @@ def exception_hook(exc_type, exc_value, tb):
local_vars = tb.tb_frame.f_locals
name = tb.tb_frame.f_code.co_name
if name == "transpile":
if name in ("transpile", "parse_module"):
last_file = local_vars["path"]
if name == "visit" and (node := local_vars["node"]) and isinstance(node, ast.AST):
......@@ -66,7 +68,7 @@ def exception_hook(exc_type, exc_value, tb):
print()
tb = tb.tb_next
if last_node is not None:
if last_node is not None and last_file is not None:
print()
if not hasattr(last_node, "lineno"):
print(cf.red("Error: "), cf.white("No line number available"))
......@@ -175,6 +177,11 @@ else:
pydevd.original_excepthook = sys.excepthook
typon_std = Path(__file__).parent.parent / "stdlib"
#discover_module(typon_std, PRELUDE.child(ScopeKind.GLOBAL))
parse_module("builtins", typon_std, PRELUDE)
def transpile(source, name: str, path=None):
__TB__ = f"transpiling module {cf.white(name)}"
res = ast.parse(source, type_comments=True)
......
import ast
from pathlib import Path
from logging import debug
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, ResolvedConcreteType, \
MemberDef, TY_FLOAT, TY_BUILTIN_FEATURE, TY_TUPLE, TY_DICT, TY_SET, TY_LIST, TY_BYTES, TY_OBJECT, TY_CPP_TYPE, \
TY_OPTIONAL, UniqueTypeMixin, TY_CALLABLE, TY_TASK
# PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TypeType(TY_INT)),
# "float": VarDecl(VarKind.LOCAL, TypeType(TY_FLOAT)),
# "str": VarDecl(VarKind.LOCAL, TypeType(TY_STR)),
# "bytes": VarDecl(VarKind.LOCAL, TypeType(TY_BYTES)),
# "bool": VarDecl(VarKind.LOCAL, TypeType(TY_BOOL)),
# "complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
# "None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
# "Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
# #"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
# "CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
# "list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
# "dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
# "Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
# "Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
# "Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
# "Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator)),
# "tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
# "slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
# "object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
# "BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
# "Any": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
# "Optional": VarDecl(VarKind.LOCAL, TypeType(lambda x: UnionType(x, TY_NONE))),
# })
from transpiler.phases.typing.common import PRELUDE
from transpiler.phases.typing.scope import VarKind, VarDecl
from transpiler.phases.typing.types import TY_TASK, TY_CALLABLE, TY_OPTIONAL, TY_CPP_TYPE, TY_BUILTIN_FEATURE, TY_TUPLE, \
TY_DICT, TY_SET, TY_LIST, TY_COMPLEX, TY_BYTES, TY_STR, TY_FLOAT, TY_INT, TY_BOOL, TY_OBJECT
prelude_vars = {
"object": TY_OBJECT,
......@@ -46,50 +17,8 @@ prelude_vars = {
"tuple": TY_TUPLE,
"BuiltinFeature": TY_BUILTIN_FEATURE,
"CppType": TY_CPP_TYPE,
"Optional": TY_OPTIONAL,
"Callable": TY_CALLABLE,
"Task": TY_TASK
}
PRELUDE.vars.update({name: VarDecl(VarKind.LOCAL, ty.type_type()) for name, ty in prelude_vars.items()})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> ResolvedConcreteType:
class CreatedType(UniqueTypeMixin, ResolvedConcreteType):
def name(self):
return name
ty = CreatedType()
for n, v in scope.vars.items():
ty.fields[n] = MemberDef(v.type, v.val, False)
return ty
def discover_module(path: Path, scope):
for child in sorted(path.iterdir()):
if child.is_dir():
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
discover_module(child, mod_scope)
scope.vars[child.name] = make_mod_decl(child.name, mod_scope)
elif child.name == "__init__.py":
StdlibVisitor(scope).visit(ast.parse(child.read_text()))
debug(f"Visited {child}")
elif child.suffix == ".py":
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
StdlibVisitor(mod_scope).visit(ast.parse(child.read_text()))
if child.stem[-1] == "_":
child = child.with_name(child.stem[:-1])
scope.vars[child.stem] = make_mod_decl(child.name, mod_scope)
debug(f"Visited {child}")
def make_mod_decl(child, mod_scope):
return VarDecl(VarKind.MODULE, make_module(child, mod_scope), {k: v.type for k, v in mod_scope.vars.items()})
discover_module(typon_std, PRELUDE)
debug("Stdlib visited!")
#exit()
\ No newline at end of file
......@@ -43,10 +43,8 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
elif ty_op is TY_CPP_TYPE:
assert len(args) == 1
return make_cpp_type(args[0])
# return ty_op(*args)
assert isinstance(ty_op, GenericType)
return ty_op.instantiate(args)
# return TypeOperator([self.visit(node.value)], self.visit(node.slice.value))
def visit_List(self, node: ast.List) -> BaseType:
return TypeListType([self.visit(elt) for elt in node.elts])
......
......@@ -5,7 +5,7 @@ from typing import Dict, Optional
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, BuiltinFeatureType
from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global()
......@@ -121,4 +121,4 @@ def get_next(iter_type):
return next_type
def is_builtin(x, feature):
return isinstance(x, BuiltinFeature) and x.val == feature
\ No newline at end of file
return isinstance(x, BuiltinFeatureType) and x.feature() == feature
\ No newline at end of file
......@@ -180,9 +180,24 @@ class UnknownNameError(CompileError):
For example:
{highlight('print(abcd)')}
{highlight('import foobar')}
"""
@dataclass
class UnknownModuleError(CompileError):
name: str
def __str__(self) -> str:
return f"Unknown module: {highlight(self.name)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to import a module that does not exist.
For example:
{highlight('import abcd')}
"""
@dataclass
class UnknownModuleMemberError(CompileError):
......
......@@ -4,12 +4,12 @@ import inspect
from itertools import zip_longest
from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.exceptions import ArgumentCountMismatchError
from transpiler.phases.typing.exceptions import ArgumentCountMismatchError, TypeMismatchKind, TypeMismatchError
from transpiler.phases.typing.types import BaseType, TY_STR, TY_BOOL, TY_INT, TY_COMPLEX, TY_FLOAT, TY_NONE, \
ClassTypeType, ResolvedConcreteType, GenericType, CallableInstanceType, TY_LIST, TY_SET, TY_DICT, RuntimeValue, \
TypeVariable, TY_LAMBDA, TypeListType, MethodType
from transpiler.phases.typing.scope import ScopeKind, VarDecl, VarKind
from transpiler.utils import linenodata
DUNDER = {
......@@ -138,7 +138,8 @@ class ScoperExprVisitor(ScoperVisitor):
raise ArgumentCountMismatchError(ftype, arguments)
if a is None and ftype.is_variadic:
break
assert a.try_assign(b)
if not a.try_assign(b):
raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE)
return ftype.return_type
# if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
......@@ -222,7 +223,7 @@ class ScoperExprVisitor(ScoperVisitor):
ty.python_func_used = True
if isinstance(ty, MethodType):
if bound and field.in_class_def and type(field.val) != RuntimeValue:
return ty.remove_self()
return ty.remove_self(ltype)
return ty
......
import ast
from pathlib import Path
from logging import debug
from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin
def make_module(name: str, scope: Scope) -> ResolvedConcreteType:
class CreatedType(UniqueTypeMixin, ResolvedConcreteType):
def name(self):
return name
ty = CreatedType()
for n, v in scope.vars.items():
ty.fields[n] = MemberDef(v.type, v.val, False)
return ty
visited_modules = {}
def parse_module(mod_name: str, python_path: Path, scope=None):
path = python_path / mod_name
if not path.exists():
path = path.with_suffix(".py")
if not path.exists():
path = path.with_stem(mod_name + "_")
if not path.exists():
raise FileNotFoundError(f"Could not find {path}")
if path.is_dir():
real_path = path / "__init__.py"
else:
real_path = path
if mod := visited_modules.get(real_path.as_posix()):
return mod
from transpiler import PRELUDE
mod_scope = scope or PRELUDE.child(ScopeKind.GLOBAL)
if real_path.suffix == ".py":
from transpiler.phases.typing.stdlib import StdlibVisitor
StdlibVisitor(python_path, mod_scope).visit(ast.parse(real_path.read_text()))
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}")
mod = make_mod_decl(mod_name, mod_scope)
visited_modules[real_path.as_posix()] = mod
return mod
# def process_module(mod_path: Path, scope):
# if mod := visited_modules.get(mod_path.as_posix()):
# return mod
#
# if mod_path.is_dir():
# mod_scope = scope.child(ScopeKind.GLOBAL)
# discover_module(mod_path, mod_scope)
# mod = make_mod_decl(mod_path.name, mod_scope)
# scope.vars[mod_path.name] = mod
# elif mod_path.name == "__init__.py":
# StdlibVisitor(mod_path, scope).visit(ast.parse(mod_path.read_text()))
# mod = None
# debug(f"Visited {mod_path}")
# elif mod_path.suffix == ".py":
# mod_scope = scope.child(ScopeKind.GLOBAL)
# StdlibVisitor(mod_path, mod_scope).visit(ast.parse(mod_path.read_text()))
# if mod_path.stem[-1] == "_":
# mod_path = mod_path.with_name(mod_path.stem[:-1])
# mod = make_mod_decl(mod_path.name, mod_scope)
# scope.vars[mod_path.stem] = mod
# debug(f"Visited {mod_path}")
# return mod
#
# def discover_module(dir_path: Path, scope):
# mod = make_mod_decl(dir_path.name, scope)
# for child in sorted(dir_path.iterdir()):
# # if child.name == "__init__.py":
# # StdlibVisitor(mod_scope).visit(ast.parse(child.read_text()))
# # else:
# # process_module(child, mod_scope)
# child_mod = process_module(child, scope)
# visited_modules[child.as_posix()] = child_mod
# return mod
def make_mod_decl(child, mod_scope):
return VarDecl(VarKind.LOCAL, make_module(child, mod_scope), {k: v.type for k, v in mod_scope.vars.items()})
\ No newline at end of file
......@@ -2,9 +2,11 @@ import ast
import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, List, Dict, Callable
from logging import debug
from transpiler.phases.typing.modules import parse_module
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE, is_builtin
......@@ -17,7 +19,7 @@ from transpiler.phases.typing.types import BaseType, BuiltinGenericType, Builtin
from transpiler.phases.utils import NodeVisitorSeq
def visit_generic_item(
visit_nongeneric: Callable[[Scope, ConcreteType], None],
visit_nongeneric: Callable[[Scope, ResolvedConcreteType], None],
node,
output_type: BuiltinGenericType,
scope: Scope,
......@@ -59,8 +61,8 @@ def visit_generic_item(
new_scope.declare_local(name, op_val.type_type())
case ast.TypeVarTuple(name):
new_scope.declare_local(name, TypeTupleType(list(args_iter)).type_type())
# for a, b in constraints:
# raise NotImplementedError()
for a, b in constraints:
assert b.try_assign(a)
# todo
new_output_type = instance_type()
visit_nongeneric(new_scope, new_output_type)
......@@ -73,9 +75,20 @@ def visit_generic_item(
@dataclass
class StdlibVisitor(NodeVisitorSeq):
python_path: Path
scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[ResolvedConcreteType] = None
typevars: Dict[str, BaseType] = field(default_factory=dict)
def resolve_module_import(self, name: str) -> VarDecl:
# tries = [
# self.python_path.parent / f"{name}.py",
# self.python_path.parent / name
# ]
# for path in tries:
# if path.exists():
# return path
# raise FileNotFoundError(f"Could not find module {name}")
return parse_module(name, self.python_path)
def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope)
......@@ -85,7 +98,7 @@ class StdlibVisitor(NodeVisitorSeq):
self.visit(stmt)
def visit_Assign(self, node: ast.Assign):
self.scope.vars[node.targets[0].id] = VarDecl(VarKind.LOCAL, self.visit(node.value))
self.scope.vars[node.targets[0].id] = VarDecl(VarKind.LOCAL, self.anno().visit(node.value).type_type())
def visit_AnnAssign(self, node: ast.AnnAssign):
ty = self.anno().visit(node.annotation)
......@@ -95,10 +108,21 @@ class StdlibVisitor(NodeVisitorSeq):
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty)
def visit_ImportFrom(self, node: ast.ImportFrom):
pass
module = self.resolve_module_import(node.module)
node.module_obj = module.type
for alias in node.names:
thing = module.val.get(alias.name)
if not thing:
from transpiler.phases.typing.exceptions import UnknownModuleMemberError
raise UnknownModuleMemberError(node.module, alias.name)
alias.item_obj = thing
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Import(self, node: ast.Import):
pass
for alias in node.names:
mod = self.resolve_module_import(alias.name)
alias.module_obj = mod.type
self.scope.vars[alias.asname or alias.name] = mod
def visit_ClassDef(self, node: ast.ClassDef):
if existing := self.scope.get(node.name):
......@@ -109,10 +133,18 @@ class StdlibVisitor(NodeVisitorSeq):
NewType = base_class(node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType.type_type())
def visit_nongeneric(scope: Scope, output: ConcreteType):
def visit_nongeneric(scope: Scope, output: ResolvedConcreteType):
cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type())
visitor = StdlibVisitor(cl_scope, output)
visitor = StdlibVisitor(self.python_path, cl_scope, output)
bases = [self.anno().visit(base) for base in node.bases]
match bases:
case []:
pass
case [prot] if is_builtin(prot, "Protocol"):
output.is_protocol = True
case _:
raise NotImplementedError("parents not handled yet: " + ", ".join(map(ast.unparse, node.bases)))
for stmt in node.body:
visitor.visit(stmt)
......@@ -132,24 +164,17 @@ class StdlibVisitor(NodeVisitorSeq):
@dataclass(eq=False, init=False)
class InstanceType(CallableInstanceType):
def __init__(self):
super().__init__([], None, 0)
def __init__(self, **kwargs):
super().__init__(**{"parameters": None, "return_type": None, **kwargs})
def __str__(self):
return f"{node.name}{super().__str__()}"
'''
class arguments(__ast.AST):
""" arguments(arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults, arg? kwarg, expr* defaults) """'''
args = node.args
assert args.posonlyargs == []
#assert args.vararg is None TODO
assert args.kwonlyargs == []
assert args.kw_defaults == []
assert args.kwarg is None
#assert args.defaults == [] TODO
for i, arg in enumerate(args.args):
arg: ast.arg
......@@ -163,7 +188,7 @@ class StdlibVisitor(NodeVisitorSeq):
arg.annotation = ast.Name(arg_name, ast.Load())
else:
if isinstance(arg.annotation, ast.Name) and (
arg.annotation.id == "Self" or
#arg.annotation.id == "Self" or
any(k.name == arg.annotation.id for k in node.type_params)
):
# annotation is type variable so we keep it
......@@ -186,13 +211,13 @@ class StdlibVisitor(NodeVisitorSeq):
return f"FuncTypeGen${node.name}"
if cur_class_ref is not None:
def remove_self(self, new_return_type = None):
def remove_self(self, self_type):
class BoundFuncType(UniqueTypeMixin, GenericType):
def name(self) -> str:
return f"BoundFuncType${node.name}"
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
return NewType.instantiate(args).remove_self()
return NewType.instantiate(args).remove_self(self_type)
def __str__(self):
return str(self.instantiate_default())
......@@ -213,25 +238,6 @@ class StdlibVisitor(NodeVisitorSeq):
visit_generic_item(visit_nongeneric, node, NewType, self.scope, InstanceType, True)
# tc = node.type_comment # todo : lire les commetnaries de type pour les fonctions génériques sinon trouver autre chose
# arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
# arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
# ret_type = arg_visitor.visit(node.returns)
# ty = FunctionType(arg_types, ret_type)
# ty.typevars = arg_visitor.typevars
# if node.args.vararg:
# ty.variadic = True
# ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
# if self.cur_class:
# ty.is_method = True
# assert isinstance(self.cur_class, TypeType)
# if isinstance(self.cur_class.type_object, ABCMeta):
# self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
# else:
# self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
# self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
if isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not):
oper = node.test.operand
......
import ast
import dataclasses
import enum
import typing
from abc import ABC, abstractmethod
......@@ -32,14 +34,21 @@ class UnifyMode:
UnifyMode.NORMAL = UnifyMode()
UnifyMode.EXACT = UnifyMode(False, False)
@dataclass
class BlockData[N]:
node: N
scope: "Scope"
@dataclass(eq=False)
class BaseType(ABC):
block_data: Optional[BlockData] = field(default=None, init=False)
def resolve(self) -> "BaseType":
return self
def type_type(self) -> "ClassTypeType":
return TY_TYPE.instantiate([self])
return TY_TYPE.instantiate([self.resolve()])
@abstractmethod
def name(self) -> str:
......@@ -69,12 +78,17 @@ class BaseType(ABC):
return (needle is haystack) or haystack.contains_internal(needle)
def try_assign(self, other: "BaseType") -> bool:
return self.resolve().try_assign_internal(other.resolve())
def try_assign_internal(self, other: "BaseType") -> bool:
try:
self.unify(other)
return True
except:
return False
def deref(self):
return self
......@@ -144,6 +158,7 @@ class ResolvedConcreteType(ConcreteType):
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: list["ResolvedConcreteType"] = field(default_factory=lambda: [TY_OBJECT], init=False)
is_protocol: bool = field(default=False, init=False)
def get_mro(self):
"""
......@@ -176,8 +191,26 @@ class ResolvedConcreteType(ConcreteType):
def inherits(self, parent: BaseType):
return self == parent or any(p.inherits(parent) for p in self.parents)
def can_receive(self, value: BaseType):
return self == value
def try_assign_internal(self, other: BaseType) -> bool:
if self == other:
return True
if super().try_assign_internal(other):
return True
if self.is_protocol:
if isinstance(other, TypeVariable):
other.unify(self) # ? maybe, we'll see if it works
return True
assert isinstance(other, ResolvedConcreteType)
for name, member in self.fields.items():
corresponding = other.fields.get(name)
if corresponding is None:
#raise ProtocolMismatchError(self, protocol, f"missing method {name}")
return False
return member.type.deref().try_assign(corresponding.type.deref())
return False
class UniqueTypeMixin:
def unify_internal(self, other: "BaseType", mode: UnifyMode):
......@@ -234,6 +267,8 @@ class GenericInstanceType(ResolvedConcreteType):
if not isinstance(other, GenericInstanceType):
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if self.generic_parent != other.generic_parent:
if not (isinstance(self, CallableInstanceType) and isinstance(other, CallableInstanceType)):
# methods have different generic parent types but we don't really care
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if len(self.generic_args) != len(other.generic_args):
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
......@@ -419,6 +454,9 @@ class TypeListType(ConcreteType):
class UnionInstanceType(GenericInstanceType):
types: list[ConcreteType]
def try_assign_internal(self, other: BaseType) -> bool:
return super().try_assign_internal(other) or any(t for t in self.types if t.try_assign(other))
class UnionType(UniqueTypeMixin, GenericType):
def name(self):
return "Union"
......@@ -440,7 +478,7 @@ TY_OPTIONAL = OptionalType()
@typing.runtime_checkable
class MethodType(typing.Protocol):
def remove_self(self, new_return_type = None) -> ...:
def remove_self(self, self_type) -> ...:
raise NotImplementedError()
@dataclass(eq=False)
......@@ -451,19 +489,31 @@ class CallableInstanceType(GenericInstanceType, MethodType):
is_variadic: bool = False
def __post_init__(self):
if self.optional_at is None:
if self.optional_at is None and self.parameters is not None:
self.optional_at = len(self.parameters)
def remove_self(self, new_return_type = None):
res = CallableInstanceType(self.parameters[1:], new_return_type or self.return_type, self.optional_at - 1, self.is_variadic)
res.generic_parent = self.generic_parent
res.generic_args = self.generic_args
return res
#return self.generic_parent.instantiate([TypeListType(self.parameters[1:]), new_return_type or self.return_type])
def remove_self(self, self_type):
assert self.parameters[0].try_assign(self_type)
return dataclasses.replace(
self,
parameters=self.parameters[1:],
optional_at=self.optional_at - 1,
)
def __str__(self):
return f"({", ".join(map(str, self.parameters))}{", *args" if self.is_variadic else ""}) -> {self.return_type}"
return f"({", ".join(map(str, self.parameters + ([", *args"] if self.is_variadic else [])))}) -> {self.return_type}"
def try_assign_internal(self, other: BaseType) -> bool:
if not isinstance(other, CallableInstanceType):
return False
self.unify(other)
return True
# @dataclass(eq=False)
# class UserFunctionInstance(CallableInstanceType):
# scope: "transpiler.phases.typing.scope.Scope" = None
# node: ast.FunctionDef = None
class CallableType(UniqueTypeMixin, GenericType):
def name(self):
......@@ -494,8 +544,22 @@ class LambdaType(CallableType):
TY_LAMBDA = LambdaType()
class BuiltinFeatureType(BuiltinType):
@abstractmethod
def feature(self):
pass
def __eq__(self, other):
return type(self) == type(other)
def make_builtin_feature(name: str):
class CreatedType(BuiltinType):
match name:
case "Optional":
return TY_OPTIONAL
case "Union":
return TY_UNION
case _:
class CreatedType(BuiltinFeatureType):
def name(self):
return name
......@@ -532,4 +596,3 @@ class ClassType(UniqueTypeMixin, GenericType):
return ClassTypeType(*args)
TY_TYPE = ClassType()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment