Commit f8de89ac authored by Tom Niget's avatar Tom Niget

Add new generics work

parent fd777ef8
.idea
cmake-build-*
\ No newline at end of file
cmake-build-*
typon.egg-info
\ No newline at end of file
......@@ -36,7 +36,7 @@ class socket:
pass
def getaddrinfo(host: str, port: int, family: int = 0, type: int = 0, proto: int = 0, flags: int = 0) -> \
Task[list[tuple[int, int, int, str, tuple[str, int] | str]]]:
Task[list[tuple[int, int, int, str, str]]]: # todo: incomplete return type
pass
AF_UNIX: int
\ No newline at end of file
# coding: utf-8
from typing import TypeVar, Generic
from dataclasses import dataclass
T = TypeVar("T")
@dataclass
class Thing(Generic[T]):
class Thing[T]:
x: T
def f(x: T):
def f[T](x: T):
pass
......
......@@ -196,7 +196,8 @@ def transpile(source, name: str, path=None):
for var in scope.vars.items():
debug(" " * (indent + 1), var)
# disp_scope(res.scope)
disp_scope(res.scope)
exit()
assert isinstance(res, ast.Module)
res.name = "__main__"
......
......@@ -3,48 +3,41 @@ from pathlib import Path
from logging import debug
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType, MemberDef
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, ResolvedConcreteType, \
MemberDef
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
# "str": VarDecl(VarKind.LOCAL, TY_TYPE, TY_STR),
# "bool": VarDecl(VarKind.LOCAL, TY_TYPE, TY_BOOL),
# "complex": VarDecl(VarKind.LOCAL, TY_TYPE, TY_COMPLEX),
# "None": VarDecl(VarKind.LOCAL, TY_NONE, None),
# "Callable": VarDecl(VarKind.LOCAL, TY_TYPE, FunctionType),
# "TypeVar": VarDecl(VarKind.LOCAL, TY_TYPE, TypeVariable),
# "CppType": VarDecl(VarKind.LOCAL, TY_TYPE, CppType),
# "list": VarDecl(VarKind.LOCAL, TY_TYPE, PyList),
"int": VarDecl(VarKind.LOCAL, TypeType(TY_INT)),
"float": VarDecl(VarKind.LOCAL, TypeType(TY_FLOAT)),
"str": VarDecl(VarKind.LOCAL, TypeType(TY_STR)),
"bytes": VarDecl(VarKind.LOCAL, TypeType(TY_BYTES)),
"bool": VarDecl(VarKind.LOCAL, TypeType(TY_BOOL)),
"complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
"None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
"Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
#"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
"CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
"list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
"Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
"Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator)),
"tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
"slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
"object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
"Any": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"Optional": VarDecl(VarKind.LOCAL, TypeType(lambda x: UnionType(x, TY_NONE))),
})
# PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TypeType(TY_INT)),
# "float": VarDecl(VarKind.LOCAL, TypeType(TY_FLOAT)),
# "str": VarDecl(VarKind.LOCAL, TypeType(TY_STR)),
# "bytes": VarDecl(VarKind.LOCAL, TypeType(TY_BYTES)),
# "bool": VarDecl(VarKind.LOCAL, TypeType(TY_BOOL)),
# "complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
# "None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
# "Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
# #"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
# "CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
# "list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
# "dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
# "Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
# "Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
# "Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
# "Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator)),
# "tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
# "slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
# "object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
# "BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
# "Any": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
# "Optional": VarDecl(VarKind.LOCAL, TypeType(lambda x: UnionType(x, TY_NONE))),
# })
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> BaseType:
ty = ModuleType([], f"{name}")
def make_module(name: str, scope: Scope) -> ResolvedConcreteType:
class CreatedType(ResolvedConcreteType):
def __str__(self):
return name
ty = CreatedType()
for n, v in scope.vars.items():
ty.fields[n] = MemberDef(v.type, v.val, False)
return ty
......
......@@ -4,33 +4,20 @@ from dataclasses import dataclass, field
from typing import Optional, List
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF, TypeVariable, UnionType
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeVariable, TY_TYPE, ResolvedConcreteType, TypeListType
from transpiler.phases.utils import NodeVisitorSeq
@dataclass
class TypeAnnotationVisitor(NodeVisitorSeq):
scope: Scope
cur_class: Optional[TypeType] = None
typevars: List[TypeVariable] = field(default_factory=list)
def visit_str(self, node: str) -> BaseType:
if node in ("Self", "self") and self.cur_class:
if isinstance(self.cur_class.type_object, abc.ABCMeta) or self.cur_class.type_object.is_protocol_gen or self.cur_class.type_object.is_protocol:
return TY_SELF
else:
return self.cur_class.type_object
if existing := self.scope.get(node):
ty = existing.type
if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
if existing is not self.scope.vars.get(node, None):
# Type variable from outer scope, so we copy it
ty = TypeVariable(ty.type_object.name)
self.scope.declare_local(node, TypeType(ty)) # todo: unneeded?
self.typevars.append(ty)
if isinstance(ty, TypeType):
return ty.type_object
return ty
assert isinstance(ty, ResolvedConcreteType)
assert ty.inherits(TY_TYPE)
return ty.inner_type
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node)
......@@ -46,22 +33,24 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
raise NotImplementedError
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
ty_op = self.visit(node.value)
args = list(node.slice.elts) if type(node.slice) == ast.Tuple else [node.slice]
args = [self.visit(arg) for arg in args]
return ty_op(*args)
# ty_op = self.visit(node.value)
# args = list(node.slice.elts) if type(node.slice) == ast.Tuple else [node.slice]
# args = [self.visit(arg) for arg in args]
# return ty_op(*args)
raise NotImplementedError()
# return TypeOperator([self.visit(node.value)], self.visit(node.slice.value))
def visit_List(self, node: ast.List) -> List[BaseType]:
return [self.visit(elt) for elt in node.elts]
def visit_List(self, node: ast.List) -> BaseType:
return TypeListType([self.visit(elt) for elt in node.elts])
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
left = self.visit(node.value)
res = left.fields[node.attr].type
assert isinstance(res, TypeType)
return res.type_object
raise NotImplementedError()
# left = self.visit(node.value)
# res = left.fields[node.attr].type
# assert isinstance(res, TypeType)
# return res.type_object
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
if isinstance(node.op, ast.BitOr):
return UnionType(self.visit(node.left), self.visit(node.right))
# if isinstance(node.op, ast.BitOr):
# return UnionType(self.visit(node.left), self.visit(node.right))
raise NotImplementedError(node.op)
......@@ -5,8 +5,7 @@ from typing import Dict, Optional
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature, FunctionType, \
Promise, PromiseKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE
from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global()
......@@ -15,14 +14,13 @@ PRELUDE = Scope.make_global()
class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
cur_class: Optional[TypeType] = None
def expr(self) -> "ScoperExprVisitor":
from transpiler.phases.typing.expr import ScoperExprVisitor
return ScoperExprVisitor(self.scope, self.root_decls)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class)
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
res = self.anno().visit(expr) if expr else TypeVariable()
......
......@@ -302,4 +302,21 @@ class MissingReturnError(CompileError):
{highlight(' if x > 0:')}
{highlight(' return 1')}
{highlight(' # if x <= 0, the function returns nothing')}
"""
@dataclass
class InconsistentMroError(CompileError):
bases: list[BaseType]
def __str__(self) -> str:
return f"Cannot create a cnossitent method resolution order (MRO) for bases {'\n'.join(map(highlight, self.bases))}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that a class has an inconsistent method resolution order (MRO).
For example:
{highlight('class A: pass')}
{highlight('class B(A): pass')}
{highlight('class C(B, A): pass')}
"""
\ No newline at end of file
......@@ -5,9 +5,7 @@ from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE, TY_FLOAT, RuntimeValue, BuiltinFeature
from transpiler.phases.typing.types import BaseType
from transpiler.utils import linenodata
DUNDER = {
......
......@@ -2,16 +2,67 @@ import ast
import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from typing import Optional, List, Dict, Callable
from logging import debug
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef, BuiltinFeature
from transpiler.phases.typing.types import BaseType, BuiltinGenericType, BuiltinType, create_builtin_generic_type, \
create_builtin_type, ConcreteType, GenericInstanceType, TypeListType, TypeTupleType, GenericParameter, \
GenericParameterKind, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq
def visit_generic_item(
visit_nongeneric: Callable[[Scope, ConcreteType], None],
node,
output_type: BuiltinGenericType,
scope: Scope,
instance_type = None):
if node.type_params:
output_type.parameters = []
for param in node.type_params:
match param:
case ast.TypeVar(_, _):
kind = GenericParameterKind.NORMAL
case ast.ParamSpec(_):
kind = GenericParameterKind.PARAMETERS
case ast.TypeVarTuple(_):
kind = GenericParameterKind.TUPLE
output_type.parameters.append(GenericParameter(param.name, kind))
if instance_type is None:
class instance_type(GenericInstanceType):
pass
instance_type.__name__ = f"GenericInstance${node.name}"
def instantiate(args: list[ConcreteType]) -> GenericInstanceType:
new_scope = scope.child(ScopeKind.GLOBAL)
args_iter = iter(args)
constraints = []
anno = TypeAnnotationVisitor(new_scope)
for param in node.type_params:
op_val = next(args_iter, None)
if op_val is None:
op_val = TypeVariable()
match param:
case ast.TypeVar(name, bound):
new_scope.declare_local(name, op_val)
if bound is not None:
constraints.append((op_val, anno.visit(bound)))
case ast.ParamSpec(name):
assert isinstance(op_val, TypeListType)
new_scope.declare_local(name, op_val)
case ast.TypeVarTuple(name):
new_scope.declare_local(name, TypeTupleType(list(args_iter)))
for a, b in constraints:
raise NotImplementedError()
new_output_type = instance_type()
visit_nongeneric(new_scope, new_output_type)
return new_output_type
output_type.instantiate_ = instantiate
else:
visit_nongeneric(scope, output_type)
@dataclass
......@@ -34,10 +85,7 @@ class StdlibVisitor(NodeVisitorSeq):
ty = self.anno().visit(node.annotation)
if self.cur_class:
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
raise NotImplementedError
else:
self.cur_class.type_object.fields[node.target.id] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.cur_class.type_object.fields[node.target.id] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty)
def visit_ImportFrom(self, node: ast.ImportFrom):
......@@ -48,71 +96,54 @@ class StdlibVisitor(NodeVisitorSeq):
def visit_ClassDef(self, node: ast.ClassDef):
if existing := self.scope.get(node.name):
ty = existing.type
NewType = existing.type
else:
class BuiltinClassType(TypeOperator):
def __init__(self, *args):
super().__init__(args, node.name, is_reference=True)
ty = TypeType(BuiltinClassType)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
typevars = []
for b in node.bases:
if isinstance(b, ast.Subscript):
if isinstance(b.slice, ast.Name):
sliceval = [b.slice.id]
elif isinstance(b.slice, ast.Tuple):
sliceval = [n.id for n in b.slice.elts]
if isinstance(b.value, ast.Name) and b.value.id == "Generic":
typevars = sliceval
elif isinstance(b.value, ast.Name) and b.value.id == "Protocol":
typevars = sliceval
ty.type_object.is_protocol_gen = True
else:
idxs = [typevars.index(v) for v in sliceval]
parent = self.visit(b.value)
assert isinstance(parent, TypeType)
assert isinstance(ty.type_object, ABCMeta)
ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs]))
else:
if isinstance(b, ast.Name) and b.id == "Protocol":
ty.type_object.is_protocol_gen = True
else:
parent = self.visit(b)
assert isinstance(parent, TypeType)
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
else:
ty.type_object.parents.append(parent.type_object)
if not typevars and not existing:
ty.type_object = ty.type_object()
cl_scope = self.scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, ty)
for var in typevars:
visitor.typevars[var] = TypeType(TypeVariable(var))
for stmt in node.body:
visitor.visit(stmt)
base_class = create_builtin_generic_type if node.type_params else create_builtin_type
NewType = base_class(node.name)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, NewType.type_type())
def visit_nongeneric(scope, output: ConcreteType):
cl_scope = scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, output)
for stmt in node.body:
visitor.visit(stmt)
visit_generic_item(visit_nongeneric, node, NewType, self.scope)
def visit_Pass(self, node: ast.Pass):
pass
def visit_FunctionDef(self, node: ast.FunctionDef):
tc = node.type_comment # todo : lire les commetnaries de type pour les fonctions génériques sinon trouver autre chose
arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
ret_type = arg_visitor.visit(node.returns)
ty = FunctionType(arg_types, ret_type)
ty.typevars = arg_visitor.typevars
if node.args.vararg:
ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class:
ty.is_method = True
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_nongeneric(scope, output: ConcreteType):
cl_scope = scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, output)
for stmt in node.body:
visitor.visit(stmt)
'''
class arguments(__ast.AST):
""" arguments(arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults, arg? kwarg, expr* defaults) """'''
visit_generic_item(visit_nongeneric, node, NewType, self.scope)
# tc = node.type_comment # todo : lire les commetnaries de type pour les fonctions génériques sinon trouver autre chose
# arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
# arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
# ret_type = arg_visitor.visit(node.returns)
# ty = FunctionType(arg_types, ret_type)
# ty.typevars = arg_visitor.typevars
# if node.args.vararg:
# ty.variadic = True
# ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
# if self.cur_class:
# ty.is_method = True
# assert isinstance(self.cur_class, TypeType)
# if isinstance(self.cur_class.type_object, ABCMeta):
# self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
# else:
# self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
# self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
if isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not):
......
import ast
import dataclasses
import enum
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from itertools import zip_longest
from typing import Dict, Optional, List, ClassVar, Callable
from transpiler.utils import highlight
from typing import Dict, Optional, Callable
def get_default_parents():
......@@ -14,596 +10,279 @@ def get_default_parents():
return [obj]
return []
class RuntimeValue:
pass
@dataclass
class MemberDef:
type: "BaseType"
val: typing.Any = RuntimeValue()
in_class_def: bool = True
@dataclass(eq=False)
class BaseType(ABC):
pass
@dataclass(eq=False)
class TypeVariable(BaseType):
pass
@dataclass
class ConcreteType(BaseType):
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
@dataclass
class GenericType(BaseType):
parameters: list[str]
def instanciate(self, args: list[BaseType]) -> BaseType:
raise NotImplementedError()
@dataclass
class UnifyMode:
search_hierarchy: bool = True
match_protocol: bool = True
UnifyMode.NORMAL = UnifyMode()
UnifyMode.EXACT = UnifyMode(False, False)
@dataclass(eq=False)
class BaseType(ABC):
#members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
#methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=get_default_parents, init=False)
typevars: List["TypeVariable"] = field(default_factory=list, init=False)
#static_members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
def get_members(self):
return {n: m for n, m in self.fields.items() if type(m.val) is RuntimeValue}
def get_parents(self) -> List["BaseType"]:
return self.parents
def iter_hierarchy_recursive(self) -> typing.Iterator["BaseType"]:
cache = set()
from queue import Queue
queue = Queue()
queue.put(self)
while not queue.empty():
cur = queue.get()
yield cur
if cur in cache:
continue
cache.add(cur)
if cur == TY_OBJECT:
continue
for p in cur.get_parents():
queue.put(p)
def inherits_from(self, other: "BaseType") -> bool:
from transpiler.exceptions import CompileError
for parent in self.iter_hierarchy_recursive():
try:
parent.unify(other, UnifyMode.EXACT)
except CompileError:
pass
else:
return True
return False
def resolve(self) -> "BaseType":
return self
@abstractmethod
def unify_internal(self, other: "BaseType", mode: UnifyMode):
pass
def unify(self, other: "BaseType", mode = UnifyMode.NORMAL):
a, b = self.resolve(), other.resolve()
__TB__ = f"unifying {highlight(a)} and {highlight(b)}"
if isinstance(b, TypeVariable):
a, b = b, a
a.unify_internal(b, mode)
def contains(self, other: "BaseType") -> bool:
needle, haystack = other.resolve(), self.resolve()
return (needle is haystack) or haystack.contains_internal(needle)
@abstractmethod
def contains_internal(self, other: "BaseType") -> bool:
pass
# @abstractmethod
# def clone(self) -> "BaseType":
# pass
def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"], cache=None) -> "Self":
return self
def to_list(self) -> List["BaseType"]:
return [self]
T = typing.TypeVar("T")
class MagicType(BaseType, typing.Generic[T]):
val: T
def __init__(self, val: T):
super().__init__()
self.val = val
def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other) or self.val != other.val:
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
def contains_internal(self, other: "BaseType") -> bool:
return False
def __str__(self):
return str(self.val)
def clone(self) -> "BaseType":
return type(self)(self.val)
class BuiltinFeature(MagicType):
pass
def type_type(self) -> "ClassTypeType":
return TY_TYPE.instantiate([self])
cur_var = 0
def next_var_id():
global cur_var
cur_var += 1
return cur_var
@dataclass(eq=False)
class ConcreteType(BaseType):
"""
A concrete type is the type of a concrete value.
It has fields and a list of parent concrete types.
Examples: int, str, list[int]
"""
@dataclass(eq=False)
class TypeVariable(BaseType):
class TypeVariable(ConcreteType):
name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[BaseType] = None
patch_attrs: dict = field(default_factory=dict)
def __setattr__(self, key, value):
if "patch_attrs" in self.__dict__ and key not in self.__dict__:
self.patch_attrs[key] = value
else:
super().__setattr__(key, value)
resolved: Optional[ConcreteType] = None
def __getattr__(self, item):
if "patch_attrs" in self.__dict__ and item in self.patch_attrs:
return self.patch_attrs[item]
raise AttributeError(item)
def resolve(self) -> ConcreteType:
if self.resolved is None:
return self
return self.resolved.resolve()
def __str__(self):
if self.resolved is None:
#return f"TypeVar[\"{self.name}\"]"
# return f"TypeVar[\"{self.name}\"]"
return f"_{self.name}"
return str(self.resolved)
def resolve(self) -> BaseType:
def __eq__(self, other):
if not isinstance(other, BaseType):
return False
if self.resolved is None:
return self
return self.resolved.resolve()
def unify_internal(self, other: BaseType, mode: UnifyMode):
if self is not other:
if other.contains(self):
from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError
raise RecursiveTypeUnificationError(self, other)
self.resolved = other
for k, v in self.patch_attrs.items():
setattr(other, k, v)
def contains_internal(self, other: BaseType) -> bool:
return self.resolve() is other.resolve()
def gen_sub(self, this: "BaseType", typevars, cache=None) -> "Self":
if match := typevars.get(self.name):
return match
return self
GenMethodFactory = Callable[["BaseType"], "FunctionType"]
return self == other
return self.resolved == other.resolve()
@dataclass(eq=False)
class TypeOperator(BaseType, ABC):
args: List[BaseType]
name: str = None
variadic: bool = False
optional_at: Optional[int] = None
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
gen_parents: ClassVar[List[BaseType]] = []
is_protocol: bool = False
is_protocol_gen: ClassVar[bool] = False
match_cache: set["TypeOperator"] = field(default_factory=set, init=False)
is_reference: bool = False
is_intermediary: bool = False
@staticmethod
def make_type(name: str):
class BuiltinType(TypeOperator):
def __init__(self):
super().__init__([], name)
return BuiltinType()
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.gen_methods = {}
cls.gen_parents = []
def __post_init__(self):
assert all(x is not None for x in self.args)
if self.name is None:
self.name = self.__class__.__name__
for name, factory in self.gen_methods.items():
self.fields[name] = MemberDef(factory(self), ())
for gp in self.gen_parents:
if not isinstance(gp, BaseType):
gp = gp(self.args)
self.parents.append(gp)
self.fields = {**gp.fields, **self.fields}
self.is_protocol = self.is_protocol or self.is_protocol_gen
self._add_default_eq()
def _add_default_eq(self):
if "__eq__" not in self.fields:
if "DEFAULT_EQ" in globals():
self.fields["__eq__"] = MemberDef(DEFAULT_EQ)
def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache:
return
from transpiler.phases.typing.exceptions import ProtocolMismatchError, TypeMismatchError
try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol))
for name, ty in dupl.fields.items():
if name == "__eq__":
continue
if name not in self.fields:
raise ProtocolMismatchError(self, protocol, f"missing method {name}")
corresp = self.fields[name].type
corresp.remove_self().unify(ty.type.remove_self())
except TypeMismatchError as e:
if hash(protocol) in self.match_cache:
self.match_cache.remove(hash(protocol))
raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType, mode: UnifyMode):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
# TODO(zdimension): this is really broken... but it would be nice
# if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None):
# __TB_NODE__ = from_node
if not isinstance(other, TypeOperator):
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if mode.match_protocol:
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self, mode)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
assert self.is_protocol == other.is_protocol
if type(self) != type(other): # and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)):
if mode.search_hierarchy:
if self.inherits_from(other) or other.inherits_from(self):
return
# for parent in other.get_parents():
# try:
# self.unify(parent)
# except TypeMismatchError:
# pass
# else:
# return
# for parent in self.get_parents():
# try:
# parent.unify(other)
# except TypeMismatchError:
# pass
# else:
# return
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if len(self.args) < len(other.args):
return other.unify_internal(self, mode)
if True or len(self.args) == 0: # todo: why check len?
if self.name != other.name:
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
for i, (a, b) in enumerate(zip_longest(self.args, other.args)):
if a is None and self.variadic or b is None and other.variadic:
continue
if a is not None and b is None:
if self.optional_at is not None and i >= self.optional_at:
continue
else:
if getattr(other, "is_python_func", False):
other.args.append(a)
continue
else:
from transpiler.phases.typing.exceptions import ArgumentCountMismatchError
raise ArgumentCountMismatchError(*sorted((self, other), key=lambda x: x.is_intermediary))
if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b)
else:
if a != b:
raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE)
def contains_internal(self, other: "BaseType") -> bool:
return any(arg.contains(other) for arg in self.args)
def __str__(self):
return self.name + (f"<{', '.join(map(str, self.args))}>" if self.args else "")
class ResolvedConcreteType(ConcreteType):
"""
A concrete type is the type of a concrete value.
def __repr__(self):
return self.__str__()
It has fields and a list of parent concrete types.
def __hash__(self):
return hash((self.name, tuple(self.args)))
Examples: int, str, list[int]
"""
def gen_sub(self, this: BaseType, typevars, cache=None) -> "Self":
cache = cache or {}
if me := cache.get(self):
return me
if len(self.args) == 0:
return self
assert all(x is not None for x in self.args)
res = object.__new__(self.__class__) # todo: ugly... should make a clone()
cache[self] = res
if isinstance(this, TypeOperator) and not isinstance(this, FunctionType):
vardict = dict(zip(typevars.keys(), this.args))
else:
vardict = typevars
for k, v in self.__dict__.items():
setattr(res, k, v)
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.fields = {k: dataclasses.replace(v, type=v.type.gen_sub(this, vardict, cache)) for k, v in self.fields.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
#res.is_protocol = self.is_protocol
return res
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: list["ResolvedConcreteType"] = field(default_factory=lambda: [TY_OBJECT], init=False)
def to_list(self) -> List["BaseType"]:
return [self, *self.args]
def get_mro(self):
"""
Performs linearization according to the MRO spec.
@dataclass
class ModuleType(TypeOperator):
is_python: bool = False
https://www.python.org/download/releases/2.3/mro/
"""
def merge(*lists):
lists = [l for l in lists if len(l) > 0]
for i, l in enumerate(lists):
first = l[0]
for j, l2 in enumerate(lists):
if j == i:
continue
if first in l2:
break
else:
return [first] + merge(*[x[1:] for x in lists if x[0] != first])
# unable to find a next element
from transpiler.phases.typing.exceptions import InconsistentMroError
raise InconsistentMroError(self.parents)
class FunctionType(TypeOperator):
is_python_func: bool = False
python_func_used: bool = False
is_method: bool = False
return [self] + merge(*[p.get_mro() for p in self.parents], self.parents)
def __iter__(self):
x = 5
pass
return iter([str(self)])
def inherits(self, parent: BaseType):
return self == parent or any(p.inherits(parent) for p in self.parents)
def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args])
@dataclass(eq=False)
class BuiltinType(ResolvedConcreteType):
pass
@property
def parameters(self):
return self.args[1:]
@property
def return_type(self):
return self.args[0]
@dataclass(eq=False, init=False)
class GenericInstanceType(ResolvedConcreteType):
"""
An instance of a generic type.
@return_type.setter
def return_type(self, value):
self.args[0] = value
Examples: list[int], dict[str, object], Callable[[int, int], int]
"""
generic_parent: "GenericType" = field(init=False)
generic_args: list[ConcreteType] = field(init=False)
def __str__(self):
ret, *args = map(str, self.args)
if self.optional_at is not None:
args = args[:self.optional_at] + [f"{x}=..." for x in args[self.optional_at:]]
if self.variadic:
args.append("*args")
if args:
args = f"{', '.join(args)}"
else:
args = ""
return f"({args}) -> {ret}"
def remove_self(self):
res = FunctionType(self.parameters[1:], self.return_type)
res.is_method = self.is_method
res.variadic = self.variadic
res.optional_at = self.optional_at - 1 if self.optional_at is not None else None
return res
def __init__(self):
super().__init__()
def inherits(self, parent: BaseType):
return self.generic_parent == parent or super().inherits(parent)
class CppType(TypeOperator):
def __init__(self, name: str):
super().__init__([name], name)
def __eq__(self, other):
if isinstance(other, GenericInstanceType):
return self.generic_parent == other.generic_parent and self.generic_args == other.generic_args
return False
def __str__(self):
return self.name
return f"{self.generic_parent}[{', '.join(map(str, self.generic_args))}]"
class Union(TypeOperator):
def __init__(self, left: BaseType, right: BaseType):
super().__init__([left, right], "Union")
class TypeType(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "Type")
@property
def type_object(self) -> BaseType:
return self.args[0]
@type_object.setter
def type_object(self, value: BaseType):
self.args[0] = value
class GenericParameterKind(enum.Enum):
NORMAL = enum.auto()
TUPLE = enum.auto()
PARAMETERS = enum.auto()
@dataclass
class GenericParameter:
name: str
kind: GenericParameterKind
TY_OBJECT = TypeOperator.make_type("object")
TY_SELF = TypeOperator.make_type("Self")
def self_gen_sub(this, typevars, _):
if this is not None:
return this
return TY_SELF
TY_SELF.gen_sub = self_gen_sub
TY_BOOL = TypeOperator.make_type("bool")
TY_TYPE = TypeOperator.make_type("type")
TY_INT = TypeOperator.make_type("int")
TY_FLOAT = TypeOperator.make_type("float")
TY_STR = TypeOperator.make_type("str")
TY_BYTES = TypeOperator.make_type("bytes")
TY_COMPLEX = TypeOperator.make_type("complex")
TY_NONE = TypeOperator.make_type("NoneType")
#TY_MODULE = TypeOperator([], "module")
TY_VARARG = TypeOperator.make_type("vararg")
TY_SLICE = TypeOperator.make_type("slice")
@dataclass
class GenericConstraint:
left: ResolvedConcreteType
right: ResolvedConcreteType
class PyList(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "list")
@dataclass(eq=False, init=False)
class GenericType(BaseType):
parameters: list[GenericParameter]
@property
def element_type(self):
return self.args[0]
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return []
@abstractmethod
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
raise NotImplementedError()
class PySet(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "set")
def instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
res = self._instantiate(args)
res.generic_args = args
res.generic_parent = self
return res
@property
def element_type(self):
return self.args[0]
@dataclass(eq=False, init=False)
class BuiltinGenericType(GenericType):
constraints_: Callable[[list[ConcreteType]], list[GenericConstraint]]
instantiate_: Callable[[list[ConcreteType]], GenericInstanceType]
class PyDict(TypeOperator):
def __init__(self, key: BaseType, value: BaseType):
super().__init__([key, value], "dict")
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return self.constraints_(args)
@property
def key_type(self):
return self.args[0]
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
return self.instantiate_(args)
@property
def value_type(self):
return self.args[1]
def create_builtin_type(name: str):
class CreatedType(BuiltinType):
def __str__(self):
return name
CreatedType.__name__ = f"BuiltinType${name}"
res = CreatedType()
return res
class PyIterator(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "iter")
@property
def element_type(self):
return self.args[0]
TY_OBJECT = None
TY_OBJECT = create_builtin_type("object")
TY_OBJECT.parents = []
TY_BOOL = create_builtin_type("bool")
TY_INT = create_builtin_type("int")
TY_FLOAT = create_builtin_type("float")
TY_STR = create_builtin_type("str")
TY_BYTES = create_builtin_type("bytes")
TY_COMPLEX = create_builtin_type("complex")
TY_NONE = create_builtin_type("NoneType")
class TupleType(TypeOperator):
def __init__(self, *args: List[BaseType]):
super().__init__(args, "tuple")
def unimpl(*args, **kwargs):
raise NotImplementedError()
class PromiseKind(Enum):
TASK = 0
JOIN = 1
FUTURE = 2
FORKED = 3
GENERATOR = 4
def create_builtin_generic_type(name: str):
class CreatedType(BuiltinGenericType):
def __str__(self):
return name
CreatedType.__name__ = f"BuiltinGenericType${name}"
res = CreatedType()
return res
class Promise(TypeOperator, ABC):
def __init__(self, ret: BaseType, kind: PromiseKind):
super().__init__([ret, MagicType(kind)])
TY_LIST = create_builtin_generic_type("list")
TY_SET = create_builtin_generic_type("set")
TY_DICT = create_builtin_generic_type("dict")
TY_TUPLE = create_builtin_generic_type("tuple")
@property
def return_type(self) -> BaseType:
return self.args[0]
@dataclass(unsafe_hash=False)
class TypeTupleType(ConcreteType):
"""
Special type used to represent a tuple of types.
@property
def kind(self) -> PromiseKind:
return self.args[1].val
Used in tuple types: tuple[int, str, bool]
@kind.setter
def kind(self, value: PromiseKind):
if value == PromiseKind.GENERATOR:
f_iter = FunctionType([], self)
f_iter.is_method = True
self.fields["__iter__"] = MemberDef(f_iter, ())
f_next = FunctionType([], self.return_type)
f_next.is_method = True
self.fields["__next__"] = MemberDef(f_next, ())
self.args[1].val = value
Can only be used unpacked: type A[*P] = tuple[*P]
"""
contents: list[ConcreteType]
def __str__(self):
return f"{self.kind.name.lower()}<{self.return_type}>"
def get_parents(self) -> List["BaseType"]:
if self.kind == PromiseKind.GENERATOR:
return [*super().get_parents()]
return super().get_parents()
class Forked(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FORKED)
return f"*[{', '.join(map(str, self.contents))}]"
@dataclass(unsafe_hash=False)
class TypeListType(ConcreteType):
"""
Special type used to represent a list of types.
class Task(Promise):
"""Only use this for type specs"""
Used in function types for the parameters: Callable[[int, int], int]
"""
contents: list[ConcreteType]
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.TASK)
def __str__(self):
return f"[{', '.join(map(str, self.contents))}]"
@dataclass(eq=False)
class CallableInstanceType(GenericInstanceType):
parameters: list[ConcreteType]
return_type: ConcreteType
class Future(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FUTURE)
class CallableType(GenericType):
def __str__(self):
return "Callable"
def _instantiate(self, args: list[ConcreteType]) -> CallableInstanceType:
match args:
case [TypeListType([*args]), ret]:
return CallableInstanceType(args, ret)
case _:
raise ValueError
class UserType(TypeOperator):
def __init__(self, name: str):
super().__init__([], name=name, is_reference=True)
def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
TY_CALLABLE = CallableType()
class UnionType(TypeOperator):
def __init__(self, *args: List[BaseType]):
super().__init__(args, "Union")
self.parents.extend(set(args))
@dataclass(eq=False)
class ClassTypeType(GenericInstanceType):
inner_type: BaseType
def is_optional(self):
if len(self.args) == 2 and TY_NONE in self.args:
return (set(self.args) - {TY_NONE}).pop()
return False
class ClassType(GenericType):
def __str__(self):
return "Type"
class GenericUserType(UserType):
pass
def _instantiate(self, args: list[ConcreteType]) -> ClassTypeType:
return ClassTypeType(*args)
class MonomorphizedUserType(UserType):
pass
\ No newline at end of file
TY_TYPE = ClassType()
import ast
from pathlib import Path
from logging import debug
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType, MemberDef
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
# "str": VarDecl(VarKind.LOCAL, TY_TYPE, TY_STR),
# "bool": VarDecl(VarKind.LOCAL, TY_TYPE, TY_BOOL),
# "complex": VarDecl(VarKind.LOCAL, TY_TYPE, TY_COMPLEX),
# "None": VarDecl(VarKind.LOCAL, TY_NONE, None),
# "Callable": VarDecl(VarKind.LOCAL, TY_TYPE, FunctionType),
# "TypeVar": VarDecl(VarKind.LOCAL, TY_TYPE, TypeVariable),
# "CppType": VarDecl(VarKind.LOCAL, TY_TYPE, CppType),
# "list": VarDecl(VarKind.LOCAL, TY_TYPE, PyList),
"int": VarDecl(VarKind.LOCAL, TypeType(TY_INT)),
"float": VarDecl(VarKind.LOCAL, TypeType(TY_FLOAT)),
"str": VarDecl(VarKind.LOCAL, TypeType(TY_STR)),
"bytes": VarDecl(VarKind.LOCAL, TypeType(TY_BYTES)),
"bool": VarDecl(VarKind.LOCAL, TypeType(TY_BOOL)),
"complex": VarDecl(VarKind.LOCAL, TypeType(TY_COMPLEX)),
"None": VarDecl(VarKind.LOCAL, TypeType(TY_NONE)),
"Callable": VarDecl(VarKind.LOCAL, TypeType(FunctionType)),
#"TypeVar": VarDecl(VarKind.LOCAL, TypeType(TypeVariable)),
"CppType": VarDecl(VarKind.LOCAL, TypeType(CppType)),
"list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"dict": VarDecl(VarKind.LOCAL, TypeType(PyDict)),
"Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
"Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator)),
"tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
"slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
"object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
"Any": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"Optional": VarDecl(VarKind.LOCAL, TypeType(lambda x: UnionType(x, TY_NONE))),
})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> BaseType:
ty = ModuleType([], f"{name}")
for n, v in scope.vars.items():
ty.fields[n] = MemberDef(v.type, v.val, False)
return ty
def discover_module(path: Path, scope):
for child in sorted(path.iterdir()):
if child.is_dir():
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
discover_module(child, mod_scope)
scope.vars[child.name] = make_mod_decl(child.name, mod_scope)
elif child.name == "__init__.py":
StdlibVisitor(scope).visit(ast.parse(child.read_text()))
debug(f"Visited {child}")
elif child.suffix == ".py":
mod_scope = PRELUDE.child(ScopeKind.GLOBAL)
StdlibVisitor(mod_scope).visit(ast.parse(child.read_text()))
if child.stem[-1] == "_":
child = child.with_name(child.stem[:-1])
scope.vars[child.stem] = make_mod_decl(child.name, mod_scope)
debug(f"Visited {child}")
def make_mod_decl(child, mod_scope):
return VarDecl(VarKind.MODULE, make_module(child, mod_scope), {k: v.type for k, v in mod_scope.vars.items()})
discover_module(typon_std, PRELUDE)
debug("Stdlib visited!")
#exit()
\ No newline at end of file
import abc
import ast
from dataclasses import dataclass, field
from typing import Optional, List
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeVariable, TY_TYPE, ResolvedConcreteType, TypeListType
from transpiler.phases.utils import NodeVisitorSeq
@dataclass
class TypeAnnotationVisitor(NodeVisitorSeq):
scope: Scope
def visit_str(self, node: str) -> BaseType:
if existing := self.scope.get(node):
ty = existing.type
assert isinstance(ty, ResolvedConcreteType)
assert ty.inherits(TY_TYPE)
return ty.inner_type
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType:
return self.visit_str(node.id)
def visit_Constant(self, node: ast.Constant) -> BaseType:
if node.value is None:
return TY_NONE
if type(node.value) == str:
return node.value
raise NotImplementedError
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
# ty_op = self.visit(node.value)
# args = list(node.slice.elts) if type(node.slice) == ast.Tuple else [node.slice]
# args = [self.visit(arg) for arg in args]
# return ty_op(*args)
raise NotImplementedError()
# return TypeOperator([self.visit(node.value)], self.visit(node.slice.value))
def visit_List(self, node: ast.List) -> BaseType:
return TypeListType([self.visit(elt) for elt in node.elts])
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
raise NotImplementedError()
# left = self.visit(node.value)
# res = left.fields[node.attr].type
# assert isinstance(res, TypeType)
# return res.type_object
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
# if isinstance(node.op, ast.BitOr):
# return UnionType(self.visit(node.left), self.visit(node.right))
raise NotImplementedError(node.op)
import ast
import copy
import dataclasses
import importlib
from dataclasses import dataclass
from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue, GenericUserType, MonomorphizedUserType
from transpiler.phases.utils import PlainBlock, AnnotationName
@dataclass
class ScoperBlockVisitor(ScoperVisitor):
stdlib: bool = False
def visit_Pass(self, node: ast.Pass):
pass
def get_module(self, name: str) -> VarDecl:
mod = self.scope.get(name, VarKind.MODULE)
if mod is None:
# try lookup with importlib
py_mod = importlib.import_module(name)
mod_scope = Scope()
# copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items():
if callable(obj):
# fty = FunctionType([], TypeVariable())
# fty.is_python_func = True
fty = TypeVariable()
fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope)
mod.type.is_python = True
self.scope.vars[name] = mod
if mod is None:
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(name)
assert isinstance(mod, VarDecl), mod
assert isinstance(mod.type, ModuleType), mod.type
return mod
def visit_Import(self, node: ast.Import):
for alias in node.names:
mod = self.get_module(alias.name)
alias.module_obj = mod.type
self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL)
def visit_ImportFrom(self, node: ast.ImportFrom):
if node.module in {"typing2", "__future__"}:
return
module = self.get_module(node.module)
node.module_obj = module.type
for alias in node.names:
thing = module.val.get(alias.name)
if not thing:
from transpiler.phases.typing.exceptions import UnknownModuleMemberError
raise UnknownModuleMemberError(node.module, alias.name)
alias.item_obj = thing
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Module(self, node: ast.Module):
self.visit_block(node.body)
def get_type(self, node: ast.expr) -> BaseType:
if type := getattr(node, "type", None):
return type
self.expr().visit(node)
return node.type
# ntype = TypeVariable()
# node.type = ntype
# return ntype
def visit_Assign(self, node: ast.Assign):
if len(node.targets) != 1:
raise NotImplementedError(node)
target = node.targets[0]
ty = self.get_type(node.value)
decl = self.visit_assign_target(target, ty)
if not hasattr(node, "is_declare"):
node.is_declare = decl
def visit_AnnAssign(self, node: ast.AnnAssign):
if node.simple != 1:
raise NotImplementedError(node)
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
ty = self.visit_annotation(node.annotation)
decl = self.visit_assign_target(node.target, ty)
if not hasattr(node, "is_declare"):
node.is_declare = decl
if node.value is not None:
ty_val = self.get_type(node.value)
__TB__ = f"unifying annotation {highlight(node.annotation)} with value {highlight(node.value)} of type {highlight(ty_val)}"
ty.unify(ty_val)
def visit_assign_target(self, target, decl_val: BaseType) -> bool:
__TB__ = f"analyzing assignment target {highlight(target)} with value {highlight(decl_val)}"
if isinstance(target, ast.Name):
if target.id == "_":
return False
target.type = decl_val
if vdecl := self.scope.get(target.id, {VarKind.LOCAL, VarKind.GLOBAL, VarKind.NONLOCAL}, restrict_function=True):
__TB__ = f"unifying existing variable {highlight(target.id)} of type {highlight(vdecl.type)} with assigned value {highlight(decl_val)}"
vdecl.type.unify(decl_val)
return False
else:
self.scope.vars[target.id] = VarDecl(VarKind.LOCAL, decl_val)
if self.scope.kind == ScopeKind.FUNCTION_INNER:
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return False
return True
elif isinstance(target, ast.Tuple):
if not isinstance(decl_val, TupleType):
from transpiler.phases.typing.exceptions import InvalidUnpackError
raise InvalidUnpackError(decl_val)
if len(target.elts) != len(decl_val.args):
from transpiler.phases.typing.exceptions import InvalidUnpackCountError
raise InvalidUnpackCountError(decl_val, len(target.elts))
target.type = decl_val
decls = [self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args)] # eager evaluated
return decls
elif isinstance(target, ast.Attribute):
attr_type = self.expr().visit(target)
attr_type.unify(decl_val)
return False
elif isinstance(target, ast.Subscript):
expr = self.expr()
left = expr.visit(target.value)
args = target.slice if type(target.slice) == tuple else [target.slice]
args = [expr.visit(e) for e in args]
if len(args) == 1:
args = args[0]
expr.make_dunder([left, args, decl_val], "setitem")
return False
else:
raise NotImplementedError(ast.unparse(target))
def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node)
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
def process_class_ast(self, ctype: BaseType, node: ast.ClassDef, bases_after: list[ast.expr]):
scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = ctype
scope.class_ = scope
node.inner_scope = scope
node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=TypeType(ctype))
visitor.visit_block(node.body)
for base in bases_after:
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT)
for k, m in ctype.fields.items():
m.type = ctype
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), ast.arg(arg="value")],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")],
value=ast.Name(id="value"),
**lnd
)
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
ctype.is_enum = True
else:
raise NotImplementedError(base)
for deco in node.decorator_list:
deco = self.expr().visit(deco)
if is_builtin(deco, "dataclass"):
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
# cttype.methods["__init__"] = init_type
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in ctype.get_members()
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
else:
raise NotImplementedError(deco)
return ctype
def visit_ClassDef(self, node: ast.ClassDef):
copied = copy.deepcopy(node)
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.value), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if typevars:
# generic
#ctype = GenericUserType(node.name, typevars, node)
var_scope = self.scope.child(ScopeKind.GLOBAL)
var_visitor = ScoperBlockVisitor(var_scope, self.root_decls)
node.gen_instances = {}
class OurGenericType(GenericUserType):
# def __init__(self, *args):
# super().__init__(node.name)
# for tv, arg in zip(typevars, args):
# var_scope.declare_local(tv, arg)
# var_visitor.process_class_ast(self, node, bases_after)
def __new__(cls, *args, **kwargs):
res = MonomorphizedUserType(node.name + "$$" + "__".join(map(str, args)) + "$$")
for tv, arg in zip(typevars, args):
var_scope.declare_local(tv, arg)
new_node = copy.deepcopy(copied)
new_node.name = res.name
var_visitor.process_class_ast(res, new_node, bases_after)
node.gen_instances[tuple(args)] = new_node
return res
ctype = OurGenericType
else:
# not generic
ctype = self.process_class_ast(UserType(node.name), node, bases_after)
cttype = TypeType(ctype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
self.expr().visit(node.test)
then_scope = scope.child(ScopeKind.FUNCTION_INNER)
then_visitor = ScoperBlockVisitor(then_scope, self.root_decls)
then_visitor.visit_block(node.body)
if node.orelse:
else_scope = scope.child(ScopeKind.FUNCTION_INNER)
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
else_visitor.visit_block(node.orelse)
node.orelse_scope = else_scope
if then_scope.diverges and else_scope.diverges:
self.scope.diverges = True
def visit_While(self, node: ast.While):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
scope.is_loop = node
node.inner_scope = scope
self.expr().visit(node.test)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
body_visitor.visit_block(node.body)
if node.orelse:
orelse_scope = scope.child(ScopeKind.FUNCTION_INNER)
orelse_visitor = ScoperBlockVisitor(orelse_scope, self.root_decls)
orelse_visitor.visit_block(node.orelse)
node.orelse_variable = f"orelse_{id(node)}"
def visit_PlainBlock(self, node: PlainBlock):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
body_visitor = ScoperBlockVisitor(scope, self.root_decls)
body_visitor.visit_block(node.body)
def visit_For(self, node: ast.For):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
scope.is_loop = node
node.inner_scope = scope
assert isinstance(node.target, ast.Name)
var_var = TypeVariable()
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var)
seq_type = self.expr().visit(node.iter)
iter_type = get_iter(seq_type)
next_type = get_next(iter_type)
var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
body_visitor.visit_block(node.body)
if node.orelse:
orelse_scope = scope.child(ScopeKind.FUNCTION_INNER)
orelse_visitor = ScoperBlockVisitor(orelse_scope, self.root_decls)
orelse_visitor.visit_block(node.orelse)
node.orelse_variable = f"orelse_{id(node)}"
def visit_Expr(self, node: ast.Expr):
self.expr().visit(node.value)
def visit_Return(self, node: ast.Return):
fct = self.scope.function
if fct is None:
from transpiler.phases.typing.exceptions import OutsideFunctionError
raise OutsideFunctionError()
ftype = fct.obj_type
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else TY_NONE
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
self.scope.diverges = True
#fct.has_return = True
def visit_Global(self, node: ast.Global):
for name in node.names:
self.scope.function.vars[name] = VarDecl(VarKind.GLOBAL, None)
if name not in self.scope.global_scope.vars:
self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_Nonlocal(self, node: ast.Global):
fct = self.scope.function
if fct is None:
from transpiler.phases.typing.exceptions import OutsideFunctionError
raise OutsideFunctionError()
for name in node.names:
fct.vars[name] = VarDecl(VarKind.NONLOCAL, None)
if name not in fct.parent.vars:
fct.parent.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_AugAssign(self, node: ast.AugAssign):
target, value = map(self.get_type, (node.target, node.value))
try:
self.expr().make_dunder([target, value], "i" + DUNDER[type(node.op)])
except CompileError as e:
self.visit_assign_target(node.target, self.expr().make_dunder([target, value], DUNDER[type(node.op)]))
# equivalent = ast.Assign(
# targets=[node.target],
# value=ast.BinOp(left=node.target, op=node.op, right=node.value, **linenodata(node)),
# **linenodata(node))
# self.visit(equivalent)
def visit(self, node: ast.AST):
if isinstance(node, ast.AST):
__TB_SKIP__ = True
super().visit(node)
node.scope = self.scope
else:
raise NotImplementedError(node)
def visit_Break(self, _node: ast.Break):
if not self.scope.is_in_loop():
from transpiler.phases.typing.exceptions import OutsideLoopError
raise OutsideLoopError()
def visit_Continue(self, _node: ast.Continue):
if not self.scope.is_in_loop():
from transpiler.phases.typing.exceptions import OutsideLoopError
raise OutsideLoopError()
def visit_Try(self, node: ast.Try):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
body_visitor.visit_block(node.body)
# todo
for handler in node.handlers:
handler_scope = scope.child(ScopeKind.FUNCTION_INNER)
handler_visitor = ScoperBlockVisitor(handler_scope, self.root_decls)
handler_visitor.visit_block(handler.body)
if node.orelse:
else_scope = scope.child(ScopeKind.FUNCTION_INNER)
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
else_visitor.visit_block(node.orelse)
if node.finalbody:
raise NotImplementedError(node.finalbody)
def visit_Raise(self, node: ast.Raise):
self.scope.diverges = True
if node.exc:
self.expr().visit(node.exc)
if node.cause:
self.expr().visit(node.cause)
\ No newline at end of file
# coding: utf-8
import ast
from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef
@dataclass
class ScoperClassVisitor(ScoperVisitor):
fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list)
def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name)
self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation))
def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Can't use destructuring in class static member"
assert isinstance(node.targets[0], ast.Name)
node.is_declare = True
valtype = self.expr().visit(node.value)
node.targets[0].type = valtype
self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value)
def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node)
ftype.parameters[0].unify(self.scope.obj_type)
inner = ftype.return_type
if node.name != "__init__":
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True
self.scope.obj_type.fields[node.name] = MemberDef(ftype, node)
return (node, inner)
import ast
from dataclasses import dataclass, field
from typing import Dict, Optional
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature, FunctionType, \
Promise, PromiseKind
from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global()
@dataclass
class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
root_decls: Dict[str, VarDecl] = field(default_factory=dict)
cur_class: Optional[TypeType] = None
def expr(self) -> "ScoperExprVisitor":
from transpiler.phases.typing.expr import ScoperExprVisitor
return ScoperExprVisitor(self.scope, self.root_decls)
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
res = self.anno().visit(expr) if expr else TypeVariable()
assert not isinstance(res, TypeType)
return res
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None or isinstance(arg.annotation, AnnotationName):
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
else:
return self.visit_annotation(arg.annotation)
def parse_function(self, node: ast.FunctionDef):
argtypes = [self.annotate_arg(arg) for arg in node.args.args]
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = len(node.args.args) - len(node.args.defaults)
for ty, default in zip(argtypes[ftype.optional_at:], node.args.defaults):
self.expr().visit(default).unify(ty)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, rtype))
return ftype
def visit_block(self, block: list[ast.AST]):
if not block:
return
__TB__ = f"running type analysis on block starting with {highlight(block[0])}"
self.fdecls = []
for b in block:
self.visit(b)
if self.fdecls:
old_list = self.fdecls
exc = None
while True:
new_list = []
for node, rtype in old_list:
from transpiler.exceptions import CompileError
try:
self.visit_function_definition(node, rtype)
except CompileError as e:
new_list.append((node, rtype))
if not exc or getattr(node, "is_main", False):
exc = e
if len(new_list) == len(old_list):
raise exc
if not new_list:
break
old_list = new_list
exc = None
def visit_function_definition(self, node, rtype):
__TB__ = f"running type analysis on the body of {highlight(node)}"
__TB_NODE__ = node
from transpiler.phases.typing.block import ScoperBlockVisitor
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.fdecls = []
visitor.visit(b)
if len(visitor.fdecls) > 1:
raise NotImplementedError("?")
elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype)
#del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls
if not node.inner_scope.diverges and not (isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR):
from transpiler.phases.typing.exceptions import TypeMismatchError
try:
rtype.unify(TY_NONE)
except TypeMismatchError as e:
from transpiler.phases.typing.exceptions import MissingReturnError
raise MissingReturnError(node) from e
def get_iter(seq_type):
try:
iter_type = seq_type.fields["__iter__"].type.return_type
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
return iter_type
def get_next(iter_type):
try:
next_type = iter_type.fields["__next__"].type.return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
return next_type
def is_builtin(x, feature):
return isinstance(x, BuiltinFeature) and x.val == feature
\ No newline at end of file
import ast
import enum
from dataclasses import dataclass
from transpiler.utils import highlight
from transpiler.exceptions import CompileError
from transpiler.phases.typing.types import TypeVariable, BaseType, TypeOperator
@dataclass
class UnresolvedTypeVariableError(CompileError):
variable: TypeVariable
def __str__(self) -> str:
return f"Unresolved type variable: {self.variable}"
def detail(self, last_node: ast.AST = None) -> str:
if isinstance(last_node, (ast.Import, ast.ImportFrom)):
return f"""
This indicates the compiler was unable to infer the type of a function in a module.
Currently, Typon cannot determine the type of Python functions imported from other modules, except
for the standard library.
As such, you need to give enough information to the compiler to infer the type of the function.
For example:
↓↓↓ this tells the compiler that {highlight('math.factorial')} returns an {highlight('int')}
{highlight('res: int = math.factorial(5)')}"""
return f"""
This generally indicates the compiler was unable to infer the type of a variable or expression.
A common fix is to add a type annotation to the variable or function.
For example:
↓↓↓ this tells the compiler that {highlight('x')} is an {highlight('int')}
{highlight('def f(x: int):')}
"""
@dataclass
class RecursiveTypeUnificationError(CompileError):
needle: BaseType
haystack: BaseType
def __str__(self) -> str:
return f"Recursive type unification: {highlight(self.needle)} and {highlight(self.haystack)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a recursive type definition. Such types are not currently supported.
For example:
{highlight('T = tuple[T]')}
In the current case, {highlight(self.haystack)} contains type {highlight(self.needle)}, but an attempt was made to
unify them.
"""
class TypeMismatchKind(enum.Enum):
NO_COMMON_PARENT = enum.auto()
DIFFERENT_TYPE = enum.auto()
@dataclass
class TypeMismatchError(CompileError):
expected: BaseType
got: BaseType
reason: TypeMismatchKind
def __str__(self) -> str:
return f"Type mismatch: expected {highlight(self.expected)}, got {highlight(self.got)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a type error.
For example:
{highlight('def f(x: int): ...')}
{highlight('f("hello")')}
In the current case, the compiler expected an expression of type {highlight(self.expected)}, but instead got
an expression of type {highlight(self.got)}.
"""
@dataclass
class ArgumentCountMismatchError(CompileError):
func: TypeOperator
arguments: TypeOperator
def __str__(self) -> str:
fcount = str(len(self.func.args))
if self.func.variadic:
fcount = f"at least {fcount}"
return f"Argument count mismatch: expected {fcount}, got {len(self.arguments.args)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates missing or extraneous arguments in a function call or type instantiation.
The called or instantiated signature was {highlight(self.func)}.
Other examples:
{highlight('def f(x: int): ...')}
{highlight('f(1, 2)')}
Here, the function {highlight('f')} expects one argument, but was called with two.
{highlight('x: list[int, str]')}
Here, the type {highlight('list')} expects one argument, but was instantiated with two.
"""
@dataclass
class ProtocolMismatchError(CompileError):
value: BaseType
protocol: BaseType
reason: Exception | str
def __str__(self) -> str:
return f"Protocol mismatch: {str(self.value)} does not implement {str(self.protocol)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a type error.
For example:
{highlight('def f(x: Iterable[int]): ...')}
{highlight('f("hello")')}
In the current case, the compiler expected an expression whose type implements {highlight(self.protocol)}, but
instead got an expression of type {highlight(self.value)}.
"""
@dataclass
class NotCallableError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Trying to call a non-function type: {highlight(self.value)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to call an object that is not a function.
For example:
{highlight('x = 1')}
{highlight('x()')}
"""
@dataclass
class MissingAttributeError(CompileError):
value: BaseType
attribute: str
def __str__(self) -> str:
return f"Missing attribute: {highlight(self.value)} has no attribute {highlight(self.attribute)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to access an attribute that does not exist.
For example:
{highlight('x = 1')}
{highlight('print(x.y)')}
"""
@dataclass
class UnknownNameError(CompileError):
name: str
def __str__(self) -> str:
return f"Unknown name: {highlight(self.name)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to access a name that does not exist.
For example:
{highlight('print(abcd)')}
{highlight('import foobar')}
"""
@dataclass
class UnknownModuleMemberError(CompileError):
module: str
name: str
def __str__(self) -> str:
return f"Unknown module member: Module {highlight(self.module)} does not contain {highlight(self.name)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to import
For example:
{highlight('from math import abcd')}
"""
@dataclass
class InvalidUnpackCountError(CompileError):
value: BaseType
count: int
def __str__(self) -> str:
return f"Invalid unpack: {highlight(self.value)} cannot be unpacked into {self.count} variables"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to unpack a value that cannot be unpacked into the given number of
variables.
For example:
{highlight('a, b, c = 1, 2')}
"""
@dataclass
class InvalidUnpackError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Invalid unpack: {highlight(self.value)} cannot be unpacked"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to unpack a value that cannot be unpacked.
For example:
{highlight('a, b, c = 1')}
Moreover, currently typon only supports unpacking tuples.
"""
@dataclass
class NotIterableError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Not iterable: {highlight(self.value)} is not iterable"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to iterate over a value that is not iterable.
For example:
{highlight('for x in 1: ...')}
Iterable types must implement the Python {highlight('Iterable')} protocol, which requires the presence of a
{highlight('__iter__')} method.
"""
@dataclass
class NotIteratorError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Not iterator: {highlight(self.value)} is not an iterator"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to iterate over a value that is not an iterator.
For example:
{highlight('x = next(5)')}
Iterator types must implement the Python {highlight('Iterator')} protocol, which requires the presence of a
{highlight('__next__')} method.
"""
@dataclass
class OutsideFunctionError(CompileError):
def __str__(self) -> str:
return f"{highlight('return')} and {highlight('nonlocal')} cannot be used outside of a function"
def detail(self, last_node: ast.AST = None) -> str:
return ""
@dataclass
class OutsideLoopError(CompileError):
def __str__(self) -> str:
return f"{highlight('break')} and {highlight('continue')} can only be used inside a loop"
def detail(self, last_node: ast.AST = None) -> str:
return ""
@dataclass
class MissingReturnError(CompileError):
node: ast.FunctionDef
def __str__(self) -> str:
return f"Missing return: not all code paths in {highlight(self.node)} return"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that a function is missing a {highlight('return')} statement in one or more of its code paths.
For example:
{highlight('def f(x: int):')}
{highlight(' if x > 0:')}
{highlight(' return 1')}
{highlight(' # if x <= 0, the function returns nothing')}
"""
@dataclass
class InconsistentMroError(CompileError):
bases: list[BaseType]
def __str__(self) -> str:
return f"Cannot create a cnossitent method resolution order (MRO) for bases {'\n'.join(map(highlight, self.bases))}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that a class has an inconsistent method resolution order (MRO).
For example:
{highlight('class A: pass')}
{highlight('class B(A): pass')}
{highlight('class C(B, A): pass')}
"""
\ No newline at end of file
import abc
import ast
import inspect
from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE, TY_FLOAT, RuntimeValue, BuiltinFeature
from transpiler.utils import linenodata
DUNDER = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Mult: "mul",
ast.Add: "add",
ast.Sub: "sub",
ast.Div: "truediv",
ast.FloorDiv: "floordiv",
ast.Mod: "mod",
ast.Lt: "lt",
ast.Gt: "gt",
ast.GtE: "ge",
ast.LtE: "le",
ast.LShift: "lshift",
ast.RShift: "rshift",
ast.BitXor: "xor",
ast.BitOr: "or",
ast.BitAnd: "and",
ast.USub: "neg",
ast.UAdd: "pos",
ast.Invert: "invert",
ast.In: "contains",
}
class ScoperExprVisitor(ScoperVisitor):
def visit(self, node) -> BaseType:
if existing := getattr(node, "type", None):
return existing.resolve()
__TB_SKIP__ = True
res = super().visit(node)
if not res:
__TB_SKIP__ = False
raise NotImplementedError(f"`{ast.unparse(node)}` {type(node)}")
res = res.resolve()
if True or not hasattr(res, "from_node"):
res.from_node = node
node.type = res
return res
def visit_Tuple(self, node: ast.Tuple) -> BaseType:
return TupleType(*[self.visit(e) for e in node.elts])
def visit_Slice(self, node: ast.Slice) -> BaseType:
for n in ("lower", "upper", "step"):
if arg := getattr(node, n):
self.visit(arg).unify(TY_INT)
return TY_SLICE
def visit_Yield(self, node: ast.Yield) -> BaseType:
ytype = self.visit(node.value)
ftype = self.scope.function.obj_type.return_type
assert isinstance(ftype, Promise)
assert ftype.kind == PromiseKind.TASK
ftype.kind = PromiseKind.GENERATOR
ftype.return_type.unify(ytype)
self.scope.function.has_yield = True
return TY_NONE
def visit_Constant(self, node: ast.Constant) -> BaseType:
if isinstance(node.value, str):
return TY_STR
elif isinstance(node.value, bool):
return TY_BOOL
elif isinstance(node.value, int):
return TY_INT
elif isinstance(node.value, complex):
return TY_COMPLEX
elif isinstance(node.value, float):
return TY_FLOAT
elif node.value is None:
return TY_NONE
else:
raise NotImplementedError(node, type(node))
def visit_Name(self, node: ast.Name) -> BaseType:
obj = self.scope.get(node.id)
if not obj:
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node.id)
ty = obj.type.resolve()
if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
raise NameError(f"Use of type variable") # todo: when does this happen exactly?
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
return ty
def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
for value in node.values:
self.visit(value)
return TY_BOOL
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
if is_builtin(ftype, "TypeVar"):
return TypeType(TypeVariable(*[ast.literal_eval(arg) for arg in node.args]))
if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
from transpiler.exceptions import CompileError
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
if isinstance(actual, Promise) and actual.kind != PromiseKind.GENERATOR:
node.is_await = True
actual = actual.return_type.resolve()
if self.scope.function and isinstance(actual, Promise) and actual.kind == PromiseKind.FORKED \
and isinstance(fty := self.scope.function.obj_type.return_type, Promise):
fty.kind = PromiseKind.JOIN
return actual
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self()
init.return_type = ftype.type_object
return self.visit_function_call(init, arguments)
if isinstance(ftype, FunctionType):
ret = ftype.return_type
elif isinstance(ftype, TypeVariable):
ret = TypeVariable()
else:
from transpiler.phases.typing.exceptions import NotCallableError
raise NotCallableError(ftype)
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
equivalent = FunctionType(arguments, ret)
equivalent.is_intermediary = True
ftype.unify(equivalent)
return equivalent.return_type
def visit_Lambda(self, node: ast.Lambda) -> BaseType:
argtypes = [TypeVariable() for _ in node.args.args]
rtype = TypeVariable()
ftype = FunctionType(argtypes, rtype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.body.scope = scope
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
decls = {}
visitor = ScoperExprVisitor(scope, decls)
rtype.unify(visitor.visit(node.body))
node.body.decls = decls
return ftype
# def visit_BinOp(self, node: ast.BinOp) -> BaseType:
# left, right = map(self.visit, (node.left, node.right))
# return self.make_dunder([left, right], DUNDER[type(node.op)])
# def visit_Compare(self, node: ast.Compare) -> BaseType:
# left, right = map(self.visit, (node.left, node.comparators[0]))
# op = node.ops[0]
# if type(op) == ast.In:
# left, right = right, left
# return self.make_dunder([left, right], DUNDER[type(op)])
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
ltype = self.visit(node.value)
return self.visit_getattr(ltype, node.attr)
def visit_getattr(self, ltype: BaseType, name: str) -> BaseType:
bound = True
if isinstance(ltype, TypeType):
# if mdecl := ltype.static_members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
ltype = ltype.type_object
bound = False
if isinstance(ltype, abc.ABCMeta):
ctor = ltype.__init__
args = list(inspect.signature(ctor).parameters.values())[1:]
if not all(arg.annotation == BaseType for arg in args):
raise NotImplementedError("I don't know how to handle this type")
ltype = ltype(*(TypeVariable() for _ in args))
# if mdecl := ltype.members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
# if meth := ltype.methods.get(name):
# meth = meth.gen_sub(ltype, {})
# if bound:
# return meth.remove_self()
# else:
# return meth
if field := ltype.fields.get(name):
ty = field.type.resolve()
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
if isinstance(ty, FunctionType):
ty = ty.gen_sub(ltype, {})
if bound and field.in_class_def and type(field.val) != RuntimeValue:
return ty.remove_self()
return ty
from transpiler.phases.typing.exceptions import MissingAttributeError
parents = ltype.iter_hierarchy_recursive()
next(parents)
for p in parents:
try:
return self.visit_getattr(p, name)
except MissingAttributeError as e:
pass
# class MemberProtocol(TypeOperator):
# pass
raise MissingAttributeError(ltype, name)
def visit_List(self, node: ast.List) -> BaseType:
if not node.elts:
return PyList(TypeVariable())
elems = [self.visit(e) for e in node.elts]
first, *rest = elems
for e in rest:
try:
first.unify(e)
except:
raise NotImplementedError(f"List with different types not handled yet: {', '.join(map(str, elems))}")
return PyList(elems[0])
def visit_Set(self, node: ast.Set) -> BaseType:
if not node.elts:
return PySet(TypeVariable())
elems = [self.visit(e) for e in node.elts]
if len(set(elems)) != 1:
raise NotImplementedError("Set with different types not handled yet")
return PySet(elems[0])
def visit_Dict(self, node: ast.Dict) -> BaseType:
if not node.keys:
return PyDict(TypeVariable(), TypeVariable())
keys = [self.visit(e) for e in node.keys]
values = [self.visit(e) for e in node.values]
if len(set(keys)) != 1 or len(set(values)) != 1:
raise NotImplementedError(f"Dict with different types not handled yet in `{ast.unparse(node)}`")
return PyDict(keys[0], values[0])
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
left = self.visit(node.value)
args = node.slice if type(node.slice) == tuple else [node.slice]
args = [self.visit(e) for e in args]
if isinstance(left, TypeType) and isinstance(left.type_object, abc.ABCMeta):
# generic
return TypeType(left.type_object(*[arg.type_object if isinstance(arg, TypeType) else arg for arg in args]))
pass
return self.make_dunder([left, *args], "getitem")
def visit_UnaryOp(self, node: ast.UnaryOp) -> BaseType:
val = self.visit(node.operand)
if isinstance(node.op, ast.Not):
return TY_BOOL
return self.make_dunder([val], DUNDER[type(node.op)])
def visit_IfExp(self, node: ast.IfExp) -> BaseType:
self.visit(node.test)
then = self.visit(node.body)
else_ = self.visit(node.orelse)
if then != else_:
raise NotImplementedError("IfExp with different types not handled yet")
return then
def make_dunder(self, args: List[BaseType], name: str) -> BaseType:
return self.visit_function_call(
self.visit_getattr(TypeType(args[0]), f"__{name}__"),
args
)
def visit_ListComp(self, node: ast.ListComp) -> BaseType:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
iter_type = get_iter(self.visit(gen.iter))
node.input_item_type = get_next(iter_type)
virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
from transpiler import ScoperBlockVisitor
visitor = ScoperBlockVisitor(virt_scope)
visitor.visit_assign_target(gen.target, node.input_item_type)
node.item_type = visitor.expr().visit(node.elt)
for if_ in gen.ifs:
visitor.expr().visit(if_)
gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node))
return PyList(node.item_type)
\ No newline at end of file
import ast
from dataclasses import field, dataclass
from enum import Enum
from typing import Optional, Dict, List, Any
from transpiler.phases.typing.types import BaseType, RuntimeValue
class VarKind(Enum):
"""Kind of variable."""
LOCAL = 1
"""`xxx = ...`"""
GLOBAL = 2
"""`global xxx"""
NONLOCAL = 3
"""`nonlocal xxx`"""
SELF = 4
OUTER_DECL = 5
MODULE = 6
class VarType:
pass
@dataclass
class VarDecl:
kind: VarKind
type: BaseType
val: Any = RuntimeValue()
class ScopeKind(Enum):
GLOBAL = 1
"""Global (module) scope"""
FUNCTION = 2
"""Function scope"""
FUNCTION_INNER = 3
"""Block (if, for, ...) scope inside a function"""
CLASS = 4
"""Class scope"""
@dataclass
class Scope:
parent: Optional["Scope"] = None
kind: ScopeKind = ScopeKind.GLOBAL
function: Optional["Scope"] = None
global_scope: Optional["Scope"] = None
vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None
diverges: bool = False
class_: Optional["Scope"] = None
is_loop: Optional[ast.For | ast.While] = None
@staticmethod
def make_global():
res = Scope()
res.global_scope = res
return res
def is_in_loop(self) -> Optional[ast.For | ast.While]:
if self.is_loop:
return self.is_loop
if self.parent is not None and self.kind != ScopeKind.FUNCTION:
return self.parent.is_in_loop()
return None
def child(self, kind: ScopeKind):
res = Scope(self, kind, self.function, self.global_scope)
if kind == ScopeKind.GLOBAL:
res.global_scope = res
self.children.append(res)
return res
def declare_local(self, name: str, type: BaseType):
"""Declares a local variable"""
self.vars[name] = VarDecl(VarKind.LOCAL, type)
def get(self, name: str, kind: VarKind | set[VarKind] = VarKind.LOCAL, restrict_function: bool = False) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if type(kind) is VarKind:
kind = {kind}
if (res := self.vars.get(name)) and res.kind in kind:
if res.kind == VarKind.GLOBAL:
return self.global_scope.get(name, kind)
elif res.kind == VarKind.NONLOCAL:
return self.function.parent.get(name, VarKind.LOCAL, True)
return res
if self.parent is not None and not (self.kind == ScopeKind.FUNCTION and restrict_function):
return self.parent.get(name, kind, restrict_function)
return None
import ast
import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from logging import debug
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq
@dataclass
class StdlibVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[BaseType] = None
typevars: Dict[str, BaseType] = field(default_factory=dict)
def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope)
def visit_Module(self, node: ast.Module):
for stmt in node.body:
self.visit(stmt)
def visit_Assign(self, node: ast.Assign):
self.scope.vars[node.targets[0].id] = VarDecl(VarKind.LOCAL, self.visit(node.value))
def visit_AnnAssign(self, node: ast.AnnAssign):
ty = self.anno().visit(node.annotation)
if self.cur_class:
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
raise NotImplementedError
else:
self.cur_class.type_object.fields[node.target.id] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty)
def visit_ImportFrom(self, node: ast.ImportFrom):
pass
def visit_Import(self, node: ast.Import):
pass
def visit_ClassDef(self, node: ast.ClassDef):
if existing := self.scope.get(node.name):
ty = existing.type
else:
class BuiltinClassType(TypeOperator):
def __init__(self, *args):
super().__init__(args, node.name, is_reference=True)
ty = TypeType(BuiltinClassType)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
typevars = []
for b in node.bases:
if isinstance(b, ast.Subscript):
if isinstance(b.slice, ast.Name):
sliceval = [b.slice.id]
elif isinstance(b.slice, ast.Tuple):
sliceval = [n.id for n in b.slice.elts]
if isinstance(b.value, ast.Name) and b.value.id == "Generic":
typevars = sliceval
elif isinstance(b.value, ast.Name) and b.value.id == "Protocol":
typevars = sliceval
ty.type_object.is_protocol_gen = True
else:
idxs = [typevars.index(v) for v in sliceval]
parent = self.visit(b.value)
assert isinstance(parent, TypeType)
assert isinstance(ty.type_object, ABCMeta)
ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs]))
else:
if isinstance(b, ast.Name) and b.id == "Protocol":
ty.type_object.is_protocol_gen = True
else:
parent = self.visit(b)
assert isinstance(parent, TypeType)
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
else:
ty.type_object.parents.append(parent.type_object)
if not typevars and not existing:
ty.type_object = ty.type_object()
cl_scope = self.scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, ty)
for var in typevars:
visitor.typevars[var] = TypeType(TypeVariable(var))
for stmt in node.body:
visitor.visit(stmt)
def visit_Pass(self, node: ast.Pass):
pass
def visit_FunctionDef(self, node: ast.FunctionDef):
tc = node.type_comment # todo : lire les commetnaries de type pour les fonctions génériques sinon trouver autre chose
arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
ret_type = arg_visitor.visit(node.returns)
ty = FunctionType(arg_types, ret_type)
ty.typevars = arg_visitor.typevars
if node.args.vararg:
ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class:
ty.is_method = True
assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
if isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not):
oper = node.test.operand
try:
res = self.expr().visit(oper)
except:
debug(f"Type of {ast.unparse(oper)} := INVALID")
else:
raise AssertionError(f"Assertion should fail, got {res} for {ast.unparse(oper)}")
else:
debug(f"Type of {ast.unparse(node.test)} := {self.expr().visit(node.test)}")
def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func)
if is_builtin(ty_op, "TypeVar"):
return TypeType(TypeVariable(*[ast.literal_eval(arg) for arg in node.args]))
if isinstance(ty_op, TypeType):
return TypeType(ty_op.type_object(*[ast.literal_eval(arg) for arg in node.args]))
raise NotImplementedError(ast.unparse(node))
def anno(self) -> "TypeAnnotationVisitor":
return TypeAnnotationVisitor(self.scope, self.cur_class)
def visit_str(self, node: str) -> BaseType:
if existing := self.scope.get(node):
return existing.type
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType:
if node.id == "TypeVar":
return BuiltinFeature("TypeVar")
return self.visit_str(node.id)
\ No newline at end of file
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Optional, Callable
def get_default_parents():
if obj := globals().get("TY_OBJECT"):
return [obj]
return []
class RuntimeValue:
pass
@dataclass
class MemberDef:
type: "BaseType"
val: typing.Any = RuntimeValue()
in_class_def: bool = True
@dataclass(eq=False)
class BaseType(ABC):
def resolve(self) -> "BaseType":
return self
cur_var = 0
def next_var_id():
global cur_var
cur_var += 1
return cur_var
@dataclass(eq=False)
class ConcreteType(BaseType):
"""
A concrete type is the type of a concrete value.
It has fields and a list of parent concrete types.
Examples: int, str, list[int]
"""
@dataclass(eq=False)
class TypeVariable(ConcreteType):
name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[ConcreteType] = None
def resolve(self) -> ConcreteType:
if self.resolved is None:
return self
return self.resolved.resolve()
def __str__(self):
if self.resolved is None:
# return f"TypeVar[\"{self.name}\"]"
return f"_{self.name}"
return str(self.resolved)
def __eq__(self, other):
if not isinstance(other, BaseType):
return False
if self.resolved is None:
return self == other
return self.resolved == other.resolve()
@dataclass(eq=False)
class ResolvedConcreteType(ConcreteType):
"""
A concrete type is the type of a concrete value.
It has fields and a list of parent concrete types.
Examples: int, str, list[int]
"""
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: list["ResolvedConcreteType"] = field(default_factory=lambda: [TY_OBJECT], init=False)
def get_mro(self):
"""
Performs linearization according to the MRO spec.
https://www.python.org/download/releases/2.3/mro/
"""
def merge(*lists):
lists = [l for l in lists if len(l) > 0]
for i, l in enumerate(lists):
first = l[0]
for j, l2 in enumerate(lists):
if j == i:
continue
if first in l2:
break
else:
return [first] + merge(*[x[1:] for x in lists if x[0] != first])
# unable to find a next element
from transpiler.phases.typing.exceptions import InconsistentMroError
raise InconsistentMroError(self.parents)
return [self] + merge(*[p.get_mro() for p in self.parents], self.parents)
def inherits(self, parent: BaseType):
return self == parent or any(p.inherits(parent) for p in self.parents)
@dataclass(eq=False, init=False)
class GenericInstanceType(ResolvedConcreteType):
"""
An instance of a generic type.
Examples: list[int], dict[str, object], Callable[[int, int], int]
"""
generic_parent: "GenericType"
generic_args: list[ConcreteType]
def __init__(self):
super().__init__()
def inherits(self, parent: BaseType):
return self.generic_parent == parent or super().inherits(parent)
def __eq__(self, other):
if isinstance(other, GenericInstanceType):
return self.generic_parent == other.generic_parent and self.generic_args == other.generic_args
return False
def __str__(self):
return f"{self.generic_parent}[{', '.join(map(str, self.generic_args))}]"
@dataclass
class GenericConstraint:
left: ResolvedConcreteType
right: ResolvedConcreteType
@dataclass(eq=False, init=False)
class GenericType(BaseType):
parameters: list[str]
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return []
@abstractmethod
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
raise NotImplementedError()
def instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
res = self._instantiate(args)
res.generic_args = args
res.generic_parent = self
return res
@dataclass(eq=False, init=False)
class BuiltinGenericType(GenericType):
constraints_: Callable[[list[ConcreteType]], list[GenericConstraint]]
instantiate_: Callable[[list[ConcreteType]], GenericInstanceType]
def constraints(self, args: list[ConcreteType]) -> list[GenericConstraint]:
return self.constraints_(args)
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
return self.instantiate_(args)
def create_builtin_type(name: str):
class CreatedType(BuiltinGenericType):
def __str__(self):
return name
res = CreatedType()
return res
TY_OBJECT = None
TY_OBJECT = create_builtin_type("object")
TY_OBJECT.parents = []
TY_BOOL = create_builtin_type("bool")
TY_INT = create_builtin_type("int")
TY_FLOAT = create_builtin_type("float")
TY_STR = create_builtin_type("str")
TY_BYTES = create_builtin_type("bytes")
TY_COMPLEX = create_builtin_type("complex")
TY_NONE = create_builtin_type("NoneType")
def unimpl(*args, **kwargs):
raise NotImplementedError()
def create_builtin_generic_type(name: str):
class CreatedType(BuiltinGenericType):
def __str__(self):
return name
res = CreatedType()
return res
TY_LIST = create_builtin_generic_type("list")
TY_SET = create_builtin_generic_type("set")
TY_DICT = create_builtin_generic_type("dict")
TY_TUPLE = create_builtin_generic_type("tuple")
@dataclass(unsafe_hash=False)
class TypeListType(ConcreteType):
"""
Special type used to represent a list of types.
Used in function types: Callable[[int, int], int]
"""
contents: list[ConcreteType]
def __str__(self):
return f"[{', '.join(map(str, self.contents))}]"
@dataclass(eq=False)
class CallableInstanceType(GenericInstanceType):
parameters: list[ConcreteType]
return_type: ConcreteType
class CallableType(GenericType):
def __str__(self):
return "Callable"
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
match args:
case [TypeListType([*args]), ret]:
return CallableInstanceType(args, ret)
case _:
raise ValueError
TY_CALLABLE = CallableType()
@dataclass(eq=False)
class ClassTypeType(GenericInstanceType):
inner_type: BaseType
class ClassType(GenericType):
def __str__(self):
return "Type"
def _instantiate(self, args: list[ConcreteType]) -> GenericInstanceType:
return ClassTypeType(*args)
TY_TYPE = ClassType()
......@@ -66,7 +66,7 @@ def highlight(code, full=False):
return cf.yellow("<None>")
if type(code) == list:
return repr([highlight(x) for x in code])
from transpiler.phases.typing import BaseType
from transpiler.phases.typing.types import BaseType
if isinstance(code, ast.AST):
return cf.italic_grey60(f"[{type(code).__name__}] ") + highlight(ast.unparse(code))
elif isinstance(code, BaseType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment