Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
T
typon-compiler
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
typon
typon-compiler
Commits
759c796e
Commit
759c796e
authored
Aug 29, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Merge .members and .methods; fix unification for hierarchy lookup
parent
993809ef
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
135 additions
and
75 deletions
+135
-75
trans/tests/a_a_enumtest.py
trans/tests/a_a_enumtest.py
+11
-0
trans/tests/a_calcbasic.py2
trans/tests/a_calcbasic.py2
+0
-0
trans/transpiler/phases/emit_cpp/block.py
trans/transpiler/phases/emit_cpp/block.py
+1
-5
trans/transpiler/phases/emit_cpp/class_.py
trans/transpiler/phases/emit_cpp/class_.py
+4
-4
trans/transpiler/phases/emit_cpp/module.py
trans/transpiler/phases/emit_cpp/module.py
+3
-3
trans/transpiler/phases/emit_cpp/search.py
trans/transpiler/phases/emit_cpp/search.py
+2
-0
trans/transpiler/phases/typing/__init__.py
trans/transpiler/phases/typing/__init__.py
+2
-2
trans/transpiler/phases/typing/annotations.py
trans/transpiler/phases/typing/annotations.py
+1
-1
trans/transpiler/phases/typing/block.py
trans/transpiler/phases/typing/block.py
+9
-6
trans/transpiler/phases/typing/class_.py
trans/transpiler/phases/typing/class_.py
+5
-5
trans/transpiler/phases/typing/common.py
trans/transpiler/phases/typing/common.py
+2
-2
trans/transpiler/phases/typing/expr.py
trans/transpiler/phases/typing/expr.py
+27
-10
trans/transpiler/phases/typing/scope.py
trans/transpiler/phases/typing/scope.py
+1
-5
trans/transpiler/phases/typing/stdlib.py
trans/transpiler/phases/typing/stdlib.py
+4
-3
trans/transpiler/phases/typing/types.py
trans/transpiler/phases/typing/types.py
+63
-29
No files found.
trans/tests/a_a_enumtest.py
0 → 100644
View file @
759c796e
# coding: utf-8
from
enum
import
Enum
class
TokenType
(
Enum
):
NUMBER
=
1
PARENTHESIS
=
2
OPERATION
=
3
if
__name__
==
"__main__"
:
x
=
TokenType
.
NUMBER
\ No newline at end of file
trans/tests/
calcbasic.py
→
trans/tests/
a_calcbasic.py2
View file @
759c796e
File moved
trans/transpiler/phases/emit_cpp/block.py
View file @
759c796e
...
...
@@ -124,11 +124,7 @@ class BlockVisitor(NodeVisitor):
def
visit_ClassDef
(
self
,
node
:
ast
.
ClassDef
):
yield
from
()
def
check
(
self
,
f
):
for
b
in
node
.
body
:
yield
from
self
.
match
(
node
)
has_return
=
next
(
ReturnVisitor
().
check
(
node
),
False
)
has_return
=
ReturnVisitor
().
match
(
node
.
body
)
yield
from
self
.
visit_func_decls
(
node
.
body
,
inner_scope
)
...
...
trans/transpiler/phases/emit_cpp/class_.py
View file @
759c796e
...
...
@@ -29,12 +29,12 @@ class ClassVisitor(NodeVisitor):
yield
"int value;"
yield
"operator int() const { return value; }"
yield
"void py_repr(std::ostream &s) const {"
yield
f's << "
{
node
.
name
}
."
<< value
;'
yield
f's << "
{
node
.
name
}
.";'
yield
"}"
else
:
yield
"void py_repr(std::ostream &s) const {"
yield
"s << '{';"
for
i
,
(
name
,
memb
)
in
enumerate
(
node
.
type
.
member
s
.
items
()):
for
i
,
(
name
,
memb
)
in
enumerate
(
node
.
type
.
field
s
.
items
()):
if
i
!=
0
:
yield
's << ", ";'
yield
f's << "
\
\
"
{
name
}\
\
": ";'
...
...
@@ -63,8 +63,8 @@ class ClassInnerVisitor(NodeVisitor):
scope
:
Scope
def
visit_AnnAssign
(
self
,
node
:
ast
.
AnnAssign
)
->
Iterable
[
str
]:
member
=
self
.
scope
.
obj_type
.
member
s
[
node
.
target
.
id
]
yield
from
self
.
visit
(
member
)
member
=
self
.
scope
.
obj_type
.
field
s
[
node
.
target
.
id
]
yield
from
self
.
visit
(
member
.
type
)
yield
node
.
target
.
id
yield
";"
...
...
trans/transpiler/phases/emit_cpp/module.py
View file @
759c796e
...
...
@@ -26,9 +26,9 @@ class ModuleVisitor(BlockVisitor):
yield
f"namespace py_
{
concrete
}
{{"
yield
f"struct
{
concrete
}
_t {{"
for
name
,
obj
in
alias
.
module_obj
.
member
s
.
items
():
if
obj
.
python_func_used
:
yield
from
self
.
emit_python_func
(
alias
.
name
,
name
,
name
,
obj
)
for
name
,
obj
in
alias
.
module_obj
.
field
s
.
items
():
if
obj
.
type
.
python_func_used
:
yield
from
self
.
emit_python_func
(
alias
.
name
,
name
,
name
,
obj
.
type
)
yield
"} all;"
yield
f"auto& get_all() {{ return all; }}"
...
...
trans/transpiler/phases/emit_cpp/search.py
View file @
759c796e
...
...
@@ -15,4 +15,6 @@ class SearchVisitor(ast.NodeVisitor):
yield
from
self
.
visit
(
value
)
def
match
(
self
,
node
)
->
bool
:
if
type
(
node
)
==
list
:
return
any
(
self
.
match
(
n
)
for
n
in
node
)
return
next
(
self
.
visit
(
node
),
False
)
trans/transpiler/phases/typing/__init__.py
View file @
759c796e
...
...
@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from
transpiler.phases.typing.stdlib
import
PRELUDE
,
StdlibVisitor
from
transpiler.phases.typing.types
import
TY_TYPE
,
TY_INT
,
TY_STR
,
TY_BOOL
,
TY_COMPLEX
,
TY_NONE
,
FunctionType
,
\
TypeVariable
,
CppType
,
PyList
,
TypeType
,
Forked
,
Task
,
Future
,
PyIterator
,
TupleType
,
TypeOperator
,
BaseType
,
\
ModuleType
,
TY_BYTES
,
TY_FLOAT
,
PyDict
,
TY_SLICE
,
TY_OBJECT
,
BuiltinFeature
,
UnionType
ModuleType
,
TY_BYTES
,
TY_FLOAT
,
PyDict
,
TY_SLICE
,
TY_OBJECT
,
BuiltinFeature
,
UnionType
,
MemberDef
PRELUDE
.
vars
.
update
({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...
...
@@ -46,7 +46,7 @@ typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def
make_module
(
name
:
str
,
scope
:
Scope
)
->
BaseType
:
ty
=
ModuleType
([],
f"
{
name
}
"
)
for
n
,
v
in
scope
.
vars
.
items
():
ty
.
members
[
n
]
=
v
.
type
ty
.
fields
[
n
]
=
MemberDef
(
v
.
type
,
v
.
val
,
False
)
return
ty
...
...
trans/transpiler/phases/typing/annotations.py
View file @
759c796e
...
...
@@ -57,7 +57,7 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
BaseType
:
left
=
self
.
visit
(
node
.
value
)
res
=
left
.
members
[
node
.
attr
]
res
=
left
.
fields
[
node
.
attr
].
type
assert
isinstance
(
res
,
TypeType
)
return
res
.
type_object
...
...
trans/transpiler/phases/typing/block.py
View file @
759c796e
...
...
@@ -11,7 +11,8 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from
transpiler.phases.typing.class_
import
ScoperClassVisitor
from
transpiler.phases.typing.scope
import
VarDecl
,
VarKind
,
ScopeKind
,
Scope
from
transpiler.phases.typing.types
import
BaseType
,
TypeVariable
,
FunctionType
,
\
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
,
UserType
,
TypeType
,
ModuleType
,
BuiltinFeature
,
TY_INT
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
,
UserType
,
TypeType
,
ModuleType
,
BuiltinFeature
,
TY_INT
,
MemberDef
,
\
RuntimeValue
from
transpiler.phases.utils
import
PlainBlock
,
AnnotationName
...
...
@@ -167,7 +168,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method
=
ast
.
FunctionDef
(
name
=
"__init__"
,
args
=
ast
.
arguments
(
args
=
[
ast
.
arg
(
arg
=
"self"
),
*
[
ast
.
arg
(
arg
=
n
)
for
n
in
ctype
.
members
]],
args
=
[
ast
.
arg
(
arg
=
"self"
),
*
[
ast
.
arg
(
arg
=
n
)
for
n
in
ctype
.
get_members
()
]],
defaults
=
[],
kw_defaults
=
[],
kwarg
=
None
,
...
...
@@ -179,7 +180,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets
=
[
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
"self"
),
attr
=
n
)],
value
=
ast
.
Name
(
id
=
n
),
**
lnd
)
for
n
in
ctype
.
members
)
for
n
in
ctype
.
get_members
()
],
decorator_list
=
[],
returns
=
None
,
...
...
@@ -195,9 +196,11 @@ class ScoperBlockVisitor(ScoperVisitor):
base
=
self
.
expr
().
visit
(
base
)
if
is_builtin
(
base
,
"Enum"
):
ctype
.
parents
.
append
(
TY_INT
)
for
k
in
ctype
.
members
:
ctype
.
members
[
k
]
=
ctype
ctype
.
members
[
"value"
]
=
TY_INT
for
k
,
m
in
ctype
.
fields
.
items
():
m
.
type
=
ctype
m
.
val
=
ast
.
literal_eval
(
m
.
val
)
assert
type
(
m
.
val
)
==
int
ctype
.
fields
[
"value"
]
=
MemberDef
(
TY_INT
)
lnd
=
linenodata
(
node
)
init_method
=
ast
.
FunctionDef
(
name
=
"__init__"
,
...
...
trans/transpiler/phases/typing/class_.py
View file @
759c796e
...
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from
transpiler.phases.typing
import
FunctionType
,
ScopeKind
,
VarDecl
,
VarKind
,
TY_NONE
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.types
import
PromiseKind
,
Promise
,
BaseType
from
transpiler.phases.typing.types
import
PromiseKind
,
Promise
,
BaseType
,
MemberDef
@
dataclass
...
...
@@ -15,15 +15,15 @@ class ScoperClassVisitor(ScoperVisitor):
assert
node
.
value
is
None
,
"Class field should not have a value"
assert
node
.
simple
==
1
,
"Class field should be simple (identifier, not parenthesized)"
assert
isinstance
(
node
.
target
,
ast
.
Name
)
self
.
scope
.
obj_type
.
members
[
node
.
target
.
id
]
=
self
.
visit_annotation
(
node
.
annotation
)
self
.
scope
.
obj_type
.
fields
[
node
.
target
.
id
]
=
MemberDef
(
self
.
visit_annotation
(
node
.
annotation
)
)
def
visit_Assign
(
self
,
node
:
ast
.
Assign
):
assert
len
(
node
.
targets
)
==
1
,
"C
lass field should be assigned to only once
"
assert
len
(
node
.
targets
)
==
1
,
"C
an't use destructuring in class static member
"
assert
isinstance
(
node
.
targets
[
0
],
ast
.
Name
)
node
.
is_declare
=
True
valtype
=
self
.
expr
().
visit
(
node
.
value
)
node
.
targets
[
0
].
type
=
valtype
self
.
scope
.
obj_type
.
members
[
node
.
targets
[
0
].
id
]
=
valtype
self
.
scope
.
obj_type
.
fields
[
node
.
targets
[
0
].
id
]
=
MemberDef
(
valtype
,
node
.
value
)
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
):
ftype
=
self
.
parse_function
(
node
)
...
...
@@ -32,5 +32,5 @@ class ScoperClassVisitor(ScoperVisitor):
if
node
.
name
!=
"__init__"
:
ftype
.
return_type
=
Promise
(
ftype
.
return_type
,
PromiseKind
.
TASK
)
ftype
.
is_method
=
True
self
.
scope
.
obj_type
.
methods
[
node
.
name
]
=
ftype
self
.
scope
.
obj_type
.
fields
[
node
.
name
]
=
MemberDef
(
ftype
,
node
)
return
(
node
,
inner
)
trans/transpiler/phases/typing/common.py
View file @
759c796e
...
...
@@ -108,7 +108,7 @@ class ScoperVisitor(NodeVisitorSeq):
def
get_iter
(
seq_type
):
try
:
iter_type
=
seq_type
.
methods
[
"__iter__"
]
.
return_type
iter_type
=
seq_type
.
fields
[
"__iter__"
].
type
.
return_type
except
:
from
transpiler.phases.typing.exceptions
import
NotIterableError
raise
NotIterableError
(
seq_type
)
...
...
@@ -116,7 +116,7 @@ def get_iter(seq_type):
def
get_next
(
iter_type
):
try
:
next_type
=
iter_type
.
methods
[
"__next__"
]
.
return_type
next_type
=
iter_type
.
fields
[
"__next__"
].
type
.
return_type
except
:
from
transpiler.phases.typing.exceptions
import
NotIteratorError
raise
NotIteratorError
(
iter_type
)
...
...
trans/transpiler/phases/typing/expr.py
View file @
759c796e
...
...
@@ -174,6 +174,11 @@ class ScoperExprVisitor(ScoperVisitor):
def
visit_getattr
(
self
,
ltype
:
BaseType
,
name
:
str
)
->
BaseType
:
bound
=
True
if
isinstance
(
ltype
,
TypeType
):
# if mdecl := ltype.static_members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
ltype
=
ltype
.
type_object
bound
=
False
if
isinstance
(
ltype
,
abc
.
ABCMeta
):
...
...
@@ -182,16 +187,28 @@ class ScoperExprVisitor(ScoperVisitor):
if
not
all
(
arg
.
annotation
==
BaseType
for
arg
in
args
):
raise
NotImplementedError
(
"I don't know how to handle this type"
)
ltype
=
ltype
(
*
(
TypeVariable
()
for
_
in
args
))
if
attr
:
=
ltype
.
members
.
get
(
name
):
if
getattr
(
attr
,
"is_python_func"
,
False
):
attr
.
python_func_used
=
True
return
attr
if
meth
:
=
ltype
.
methods
.
get
(
name
):
meth
=
meth
.
gen_sub
(
ltype
,
{})
if
bound
:
return
meth
.
remove_self
()
else
:
return
meth
# if mdecl := ltype.members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
# if meth := ltype.methods.get(name):
# meth = meth.gen_sub(ltype, {})
# if bound:
# return meth.remove_self()
# else:
# return meth
if
field
:
=
ltype
.
fields
.
get
(
name
):
ty
=
field
.
type
if
getattr
(
ty
,
"is_python_func"
,
False
):
ty
.
python_func_used
=
True
if
isinstance
(
ty
,
FunctionType
):
ty
=
ty
.
gen_sub
(
ltype
,
{})
if
bound
and
field
.
in_class_def
:
return
ty
.
remove_self
()
return
ty
from
transpiler.phases.typing.exceptions
import
MissingAttributeError
parents
=
ltype
.
iter_hierarchy_recursive
()
next
(
parents
)
...
...
trans/transpiler/phases/typing/scope.py
View file @
759c796e
...
...
@@ -3,7 +3,7 @@ from dataclasses import field, dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
,
Any
from
transpiler.phases.typing.types
import
BaseType
from
transpiler.phases.typing.types
import
BaseType
,
RuntimeValue
class
VarKind
(
Enum
):
...
...
@@ -23,10 +23,6 @@ class VarType:
pass
class
RuntimeValue
:
pass
@
dataclass
class
VarDecl
:
kind
:
VarKind
...
...
trans/transpiler/phases/typing/stdlib.py
View file @
759c796e
...
...
@@ -8,7 +8,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from
transpiler.phases.typing.common
import
PRELUDE
from
transpiler.phases.typing.expr
import
ScoperExprVisitor
from
transpiler.phases.typing.scope
import
Scope
,
VarDecl
,
VarKind
,
ScopeKind
from
transpiler.phases.typing.types
import
BaseType
,
TypeOperator
,
FunctionType
,
TY_VARARG
,
TypeType
,
TypeVariable
from
transpiler.phases.typing.types
import
BaseType
,
TypeOperator
,
FunctionType
,
TY_VARARG
,
TypeType
,
TypeVariable
,
\
MemberDef
from
transpiler.phases.utils
import
NodeVisitorSeq
...
...
@@ -36,7 +37,7 @@ class StdlibVisitor(NodeVisitorSeq):
if
isinstance
(
self
.
cur_class
.
type_object
,
ABCMeta
):
raise
NotImplementedError
else
:
self
.
cur_class
.
type_object
.
members
[
node
.
target
.
id
]
=
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
self
.
cur_class
.
type_object
.
fields
[
node
.
target
.
id
]
=
MemberDef
(
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
)
self
.
scope
.
vars
[
node
.
target
.
id
]
=
VarDecl
(
VarKind
.
LOCAL
,
ty
)
def
visit_ImportFrom
(
self
,
node
:
ast
.
ImportFrom
):
...
...
@@ -110,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if
isinstance
(
self
.
cur_class
.
type_object
,
ABCMeta
):
self
.
cur_class
.
type_object
.
gen_methods
[
node
.
name
]
=
lambda
t
:
ty
.
gen_sub
(
t
,
self
.
typevars
)
else
:
self
.
cur_class
.
type_object
.
methods
[
node
.
name
]
=
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
self
.
cur_class
.
type_object
.
fields
[
node
.
name
]
=
MemberDef
(
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
)
self
.
scope
.
vars
[
node
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
ty
)
def
visit_Assert
(
self
,
node
:
ast
.
Assert
):
...
...
trans/transpiler/phases/typing/types.py
View file @
759c796e
...
...
@@ -13,12 +13,36 @@ def get_default_parents():
return
[
obj
]
return
[]
class
RuntimeValue
:
pass
@
dataclass
class
MemberDef
:
type
:
"BaseType"
val
:
typing
.
Any
=
RuntimeValue
()
in_class_def
:
bool
=
True
@
dataclass
class
UnifyMode
:
search_hierarchy
:
bool
=
True
match_protocol
:
bool
=
True
UnifyMode
.
NORMAL
=
UnifyMode
()
UnifyMode
.
EXACT
=
UnifyMode
(
False
,
False
)
@
dataclass
(
eq
=
False
)
class
BaseType
(
ABC
):
members
:
Dict
[
str
,
"BaseType"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
methods
:
Dict
[
str
,
"FunctionType"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
#members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
#methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
fields
:
Dict
[
str
,
"MemberDef"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
parents
:
List
[
"BaseType"
]
=
field
(
default_factory
=
get_default_parents
,
init
=
False
)
typevars
:
List
[
"TypeVariable"
]
=
field
(
default_factory
=
list
,
init
=
False
)
#static_members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
def
get_members
(
self
):
return
{
n
:
m
for
n
,
m
in
self
.
fields
.
items
()
if
type
(
m
.
val
)
is
RuntimeValue
}
def
get_parents
(
self
)
->
List
[
"BaseType"
]:
...
...
@@ -41,21 +65,29 @@ class BaseType(ABC):
queue
.
put
(
p
)
def
inherits_from
(
self
,
other
:
"BaseType"
)
->
bool
:
return
other
in
self
.
iter_hierarchy_recursive
()
from
transpiler.exceptions
import
CompileError
for
parent
in
self
.
iter_hierarchy_recursive
():
try
:
parent
.
unify
(
other
,
UnifyMode
.
EXACT
)
except
CompileError
:
pass
else
:
return
True
return
False
def
resolve
(
self
)
->
"BaseType"
:
return
self
@
abstractmethod
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
pass
def
unify
(
self
,
other
:
"BaseType"
):
def
unify
(
self
,
other
:
"BaseType"
,
mode
=
UnifyMode
.
NORMAL
):
a
,
b
=
self
.
resolve
(),
other
.
resolve
()
TB
=
f"unifying
{
highlight
(
a
)
}
and
{
highlight
(
b
)
}
"
if
isinstance
(
b
,
TypeVariable
):
a
,
b
=
b
,
a
a
.
unify_internal
(
b
)
a
.
unify_internal
(
b
,
mode
)
def
contains
(
self
,
other
:
"BaseType"
)
->
bool
:
needle
,
haystack
=
other
.
resolve
(),
self
.
resolve
()
...
...
@@ -86,7 +118,7 @@ class MagicType(BaseType, typing.Generic[T]):
super
().
__init__
()
self
.
val
=
val
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
if
type
(
self
)
!=
type
(
other
)
or
self
.
val
!=
other
.
val
:
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
...
...
@@ -128,7 +160,7 @@ class TypeVariable(BaseType):
return
self
return
self
.
resolved
.
resolve
()
def
unify_internal
(
self
,
other
:
BaseType
):
def
unify_internal
(
self
,
other
:
BaseType
,
mode
:
UnifyMode
):
if
self
is
not
other
:
if
other
.
contains
(
self
):
from
transpiler.phases.typing.exceptions
import
RecursiveTypeUnificationError
...
...
@@ -178,19 +210,19 @@ class TypeOperator(BaseType, ABC):
if
self
.
name
is
None
:
self
.
name
=
self
.
__class__
.
__name__
for
name
,
factory
in
self
.
gen_methods
.
items
():
self
.
methods
[
name
]
=
factory
(
self
)
self
.
fields
[
name
]
=
MemberDef
(
factory
(
self
)
)
for
gp
in
self
.
gen_parents
:
if
not
isinstance
(
gp
,
BaseType
):
gp
=
gp
(
self
.
args
)
self
.
parents
.
append
(
gp
)
self
.
methods
=
{
**
gp
.
methods
,
**
self
.
metho
ds
}
self
.
fields
=
{
**
gp
.
fields
,
**
self
.
fiel
ds
}
self
.
is_protocol
=
self
.
is_protocol
or
self
.
is_protocol_gen
self
.
_add_default_eq
()
def
_add_default_eq
(
self
):
if
"__eq__"
not
in
self
.
metho
ds
:
if
"__eq__"
not
in
self
.
fiel
ds
:
if
"DEFAULT_EQ"
in
globals
():
self
.
methods
[
"__eq__"
]
=
DEFAULT_EQ
self
.
fields
[
"__eq__"
]
=
MemberDef
(
DEFAULT_EQ
)
def
matches_protocol
(
self
,
protocol
:
"TypeOperator"
):
if
hash
(
protocol
)
in
self
.
match_cache
:
...
...
@@ -199,33 +231,35 @@ class TypeOperator(BaseType, ABC):
try
:
dupl
=
protocol
.
gen_sub
(
self
,
{
v
.
name
:
(
TypeVariable
(
v
.
name
)
if
isinstance
(
v
.
resolve
(),
TypeVariable
)
else
v
)
for
v
in
protocol
.
args
})
self
.
match_cache
.
add
(
hash
(
protocol
))
for
name
,
ty
in
dupl
.
metho
ds
.
items
():
for
name
,
ty
in
dupl
.
fiel
ds
.
items
():
if
name
==
"__eq__"
:
continue
if
name
not
in
self
.
metho
ds
:
if
name
not
in
self
.
fiel
ds
:
raise
ProtocolMismatchError
(
self
,
protocol
,
f"missing method
{
name
}
"
)
corresp
=
self
.
methods
[
name
]
corresp
.
remove_self
().
unify
(
ty
.
remove_self
())
corresp
=
self
.
fields
[
name
].
type
corresp
.
remove_self
().
unify
(
ty
.
type
.
remove_self
())
except
TypeMismatchError
as
e
:
if
hash
(
protocol
)
in
self
.
match_cache
:
self
.
match_cache
.
remove
(
hash
(
protocol
))
raise
ProtocolMismatchError
(
self
,
protocol
,
e
)
def
unify_internal
(
self
,
other
:
BaseType
):
def
unify_internal
(
self
,
other
:
BaseType
,
mode
:
UnifyMode
):
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
# TODO(zdimension): this is really broken... but it would be nice
# if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None):
# TB_NODE = from_node
if
not
isinstance
(
other
,
TypeOperator
):
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
if
other
.
is_protocol
and
not
self
.
is_protocol
:
return
other
.
unify_internal
(
self
)
if
self
.
is_protocol
and
not
other
.
is_protocol
:
return
other
.
matches_protocol
(
self
)
# TODO: doesn't print the correct type in the error message
if
mode
.
match_protocol
:
if
other
.
is_protocol
and
not
self
.
is_protocol
:
return
other
.
unify_internal
(
self
,
mode
)
if
self
.
is_protocol
and
not
other
.
is_protocol
:
return
other
.
matches_protocol
(
self
)
# TODO: doesn't print the correct type in the error message
assert
self
.
is_protocol
==
other
.
is_protocol
if
type
(
self
)
!=
type
(
other
):
# and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)):
if
self
.
inherits_from
(
other
)
or
other
.
inherits_from
(
self
):
return
if
mode
.
search_hierarchy
:
if
self
.
inherits_from
(
other
)
or
other
.
inherits_from
(
self
):
return
# for parent in other.get_parents():
# try:
# self.unify(parent)
...
...
@@ -242,8 +276,8 @@ class TypeOperator(BaseType, ABC):
# return
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
if
len
(
self
.
args
)
<
len
(
other
.
args
):
return
other
.
unify_internal
(
self
)
if
len
(
self
.
args
)
==
0
:
return
other
.
unify_internal
(
self
,
mode
)
if
True
or
len
(
self
.
args
)
==
0
:
# todo: why check len?
if
self
.
name
!=
other
.
name
:
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
for
i
,
(
a
,
b
)
in
enumerate
(
zip_longest
(
self
.
args
,
other
.
args
)):
...
...
@@ -292,7 +326,7 @@ class TypeOperator(BaseType, ABC):
for
k
,
v
in
self
.
__dict__
.
items
():
setattr
(
res
,
k
,
v
)
res
.
args
=
[
arg
.
resolve
().
gen_sub
(
this
,
vardict
,
cache
)
for
arg
in
self
.
args
]
res
.
methods
=
{
k
:
v
.
gen_sub
(
this
,
vardict
,
cache
)
for
k
,
v
in
self
.
metho
ds
.
items
()}
res
.
fields
=
{
k
:
dataclasses
.
replace
(
v
,
type
=
v
.
type
.
gen_sub
(
this
,
vardict
,
cache
))
for
k
,
v
in
self
.
fiel
ds
.
items
()}
res
.
parents
=
[
p
.
gen_sub
(
this
,
vardict
,
cache
)
for
p
in
self
.
parents
]
#res.is_protocol = self.is_protocol
return
res
...
...
@@ -466,10 +500,10 @@ class Promise(TypeOperator, ABC):
if
value
==
PromiseKind
.
GENERATOR
:
f_iter
=
FunctionType
([],
self
)
f_iter
.
is_method
=
True
self
.
methods
[
"__iter__"
]
=
f_iter
self
.
fields
[
"__iter__"
]
=
MemberDef
(
f_iter
,
())
f_next
=
FunctionType
([],
self
.
return_type
)
f_next
.
is_method
=
True
self
.
methods
[
"__next__"
]
=
f_next
self
.
fields
[
"__next__"
]
=
MemberDef
(
f_next
,
())
self
.
args
[
1
].
val
=
value
def
__str__
(
self
):
...
...
@@ -506,7 +540,7 @@ class UserType(TypeOperator):
def
__init__
(
self
,
name
:
str
):
super
().
__init__
([],
name
=
name
,
is_reference
=
True
)
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
if
type
(
self
)
!=
type
(
other
):
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment