Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
T
typon-compiler
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
typon
typon-compiler
Commits
28ff809a
Commit
28ff809a
authored
Jun 11, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add initial support for reference-counted user-defined classes
parent
e4d6f647
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
200 additions
and
19 deletions
+200
-19
trans/tests/a_a_usertype.py
trans/tests/a_a_usertype.py
+22
-0
trans/transpiler/phases/emit_cpp/__init__.py
trans/transpiler/phases/emit_cpp/__init__.py
+5
-1
trans/transpiler/phases/emit_cpp/block.py
trans/transpiler/phases/emit_cpp/block.py
+10
-2
trans/transpiler/phases/emit_cpp/class_.py
trans/transpiler/phases/emit_cpp/class_.py
+66
-0
trans/transpiler/phases/emit_cpp/expr.py
trans/transpiler/phases/emit_cpp/expr.py
+5
-1
trans/transpiler/phases/emit_cpp/module.py
trans/transpiler/phases/emit_cpp/module.py
+4
-0
trans/transpiler/phases/typing/block.py
trans/transpiler/phases/typing/block.py
+23
-10
trans/transpiler/phases/typing/class_.py
trans/transpiler/phases/typing/class_.py
+38
-0
trans/transpiler/phases/typing/common.py
trans/transpiler/phases/typing/common.py
+11
-2
trans/transpiler/phases/typing/expr.py
trans/transpiler/phases/typing/expr.py
+5
-1
trans/transpiler/phases/typing/scope.py
trans/transpiler/phases/typing/scope.py
+1
-0
trans/transpiler/phases/typing/types.py
trans/transpiler/phases/typing/types.py
+10
-2
No files found.
trans/tests/a_a_usertype.py
0 → 100644
View file @
28ff809a
# coding: utf-8
class
Person
:
name
:
str
age
:
int
def
__init__
(
self
,
name
:
str
,
age
:
int
):
self
.
name
=
name
self
.
age
=
age
def
afficher
(
self
):
print
(
self
.
name
,
self
.
age
)
def
creer
():
return
Person
(
"jean"
,
123
)
if
__name__
==
"__main__"
:
x
=
creer
()
print
(
x
.
name
)
print
(
x
.
age
)
x
.
afficher
()
trans/transpiler/phases/emit_cpp/__init__.py
View file @
28ff809a
...
@@ -6,7 +6,7 @@ from typing import Iterable
...
@@ -6,7 +6,7 @@ from typing import Iterable
from
transpiler.phases.emit_cpp.consts
import
MAPPINGS
from
transpiler.phases.emit_cpp.consts
import
MAPPINGS
from
transpiler.phases.typing
import
TypeVariable
from
transpiler.phases.typing
import
TypeVariable
from
transpiler.phases.typing.types
import
BaseType
,
TY_INT
,
TY_BOOL
,
TY_NONE
,
Promise
,
PromiseKind
from
transpiler.phases.typing.types
import
BaseType
,
TY_INT
,
TY_BOOL
,
TY_NONE
,
Promise
,
PromiseKind
,
TY_STR
,
UserType
from
transpiler.utils
import
UnsupportedNodeError
from
transpiler.utils
import
UnsupportedNodeError
class
UniversalVisitor
:
class
UniversalVisitor
:
...
@@ -55,6 +55,10 @@ class NodeVisitor(UniversalVisitor):
...
@@ -55,6 +55,10 @@ class NodeVisitor(UniversalVisitor):
yield
"bool"
yield
"bool"
elif
node
is
TY_NONE
:
elif
node
is
TY_NONE
:
yield
"void"
yield
"void"
elif
node
is
TY_STR
:
yield
"std::string"
elif
isinstance
(
node
,
UserType
):
yield
f"std::shared_ptr<decltype(
{
node
.
name
}
)::type>"
elif
isinstance
(
node
,
Promise
):
elif
isinstance
(
node
,
Promise
):
yield
"typon::"
yield
"typon::"
if
node
.
kind
==
PromiseKind
.
TASK
:
if
node
.
kind
==
PromiseKind
.
TASK
:
...
...
trans/transpiler/phases/emit_cpp/block.py
View file @
28ff809a
...
@@ -22,6 +22,9 @@ class BlockVisitor(NodeVisitor):
...
@@ -22,6 +22,9 @@ class BlockVisitor(NodeVisitor):
def
expr
(
self
)
->
ExpressionVisitor
:
def
expr
(
self
)
->
ExpressionVisitor
:
return
ExpressionVisitor
(
self
.
scope
,
self
.
generator
)
return
ExpressionVisitor
(
self
.
scope
,
self
.
generator
)
def
visit_Pass
(
self
,
node
:
ast
.
Pass
)
->
Iterable
[
str
]:
yield
";"
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
)
->
Iterable
[
str
]:
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
)
->
Iterable
[
str
]:
if
getattr
(
node
,
"is_main"
,
False
):
if
getattr
(
node
,
"is_main"
,
False
):
# Special case handling for Python's interesting way of defining an entry point.
# Special case handling for Python's interesting way of defining an entry point.
...
@@ -83,11 +86,14 @@ class BlockVisitor(NodeVisitor):
...
@@ -83,11 +86,14 @@ class BlockVisitor(NodeVisitor):
yield
"}"
yield
"}"
yield
f"}}
{
node
.
name
}
;"
yield
f"}}
{
node
.
name
}
;"
def
visit_func_new
(
self
,
node
:
ast
.
FunctionDef
)
->
Iterable
[
str
]:
def
visit_func_new
(
self
,
node
:
ast
.
FunctionDef
,
skip_first_arg
:
bool
=
False
)
->
Iterable
[
str
]:
yield
from
self
.
visit
(
node
.
type
.
return_type
)
yield
from
self
.
visit
(
node
.
type
.
return_type
)
yield
"operator()"
yield
"operator()"
yield
"("
yield
"("
for
i
,
(
arg
,
argty
)
in
enumerate
(
zip
(
node
.
args
.
args
,
node
.
type
.
parameters
)):
args_iter
=
zip
(
node
.
args
.
args
,
node
.
type
.
parameters
)
if
skip_first_arg
:
next
(
args_iter
)
for
i
,
(
arg
,
argty
)
in
enumerate
(
args_iter
):
if
i
!=
0
:
if
i
!=
0
:
yield
", "
yield
", "
yield
from
self
.
visit
(
argty
)
yield
from
self
.
visit
(
argty
)
...
@@ -241,6 +247,8 @@ class BlockVisitor(NodeVisitor):
...
@@ -241,6 +247,8 @@ class BlockVisitor(NodeVisitor):
yield
name
yield
name
elif
isinstance
(
lvalue
,
ast
.
Subscript
):
elif
isinstance
(
lvalue
,
ast
.
Subscript
):
yield
from
self
.
expr
().
visit
(
lvalue
)
yield
from
self
.
expr
().
visit
(
lvalue
)
elif
isinstance
(
lvalue
,
ast
.
Attribute
):
yield
from
self
.
expr
().
visit
(
lvalue
)
else
:
else
:
raise
NotImplementedError
(
lvalue
)
raise
NotImplementedError
(
lvalue
)
...
...
trans/transpiler/phases/emit_cpp/class_.py
0 → 100644
View file @
28ff809a
# coding: utf-8
import
ast
from
typing
import
Iterable
from
dataclasses
import
dataclass
from
transpiler.phases.typing.scope
import
Scope
from
transpiler.phases.emit_cpp
import
NodeVisitor
class
ClassVisitor
(
NodeVisitor
):
def
visit_ClassDef
(
self
,
node
:
ast
.
ClassDef
)
->
Iterable
[
str
]:
yield
"struct {"
yield
"struct type {"
inner
=
ClassInnerVisitor
(
node
.
inner_scope
)
for
stmt
in
node
.
body
:
yield
from
inner
.
visit
(
stmt
)
yield
"template<typename... T> type(T&&... args) {"
yield
"__init__(std::forward<T>(args)...);"
yield
"}"
yield
"type(const type&) = delete;"
yield
"type(type&&) = delete;"
yield
"};"
yield
"template<typename... T> auto operator()(T&&... args) {"
yield
"return std::make_shared<type>(std::forward<T>(args)...);"
yield
"}"
outer
=
ClassOuterVisitor
(
node
.
inner_scope
)
for
stmt
in
node
.
body
:
yield
from
outer
.
visit
(
stmt
)
yield
f"}}
{
node
.
name
}
;"
@
dataclass
class
ClassInnerVisitor
(
NodeVisitor
):
scope
:
Scope
def
visit_AnnAssign
(
self
,
node
:
ast
.
AnnAssign
)
->
Iterable
[
str
]:
member
=
self
.
scope
.
obj_type
.
members
[
node
.
target
.
id
]
yield
from
self
.
visit
(
member
)
yield
node
.
target
.
id
yield
";"
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
)
->
Iterable
[
str
]:
yield
"struct {"
yield
"type* self;"
from
transpiler.phases.emit_cpp.block
import
BlockVisitor
yield
from
BlockVisitor
(
self
.
scope
).
visit_func_new
(
node
,
True
)
yield
f"}}
{
node
.
name
}
{{ this }};"
@
dataclass
class
ClassOuterVisitor
(
NodeVisitor
):
scope
:
Scope
def
visit_AnnAssign
(
self
,
node
:
ast
.
AnnAssign
)
->
Iterable
[
str
]:
yield
""
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
)
->
Iterable
[
str
]:
yield
"struct {"
yield
"template<typename... T>"
yield
"auto operator()(type& self, T&&... args) {"
yield
f"return self.
{
node
.
name
}
(std::forward<T>(args)...);"
yield
"}"
yield
f"}}
{
node
.
name
}
;"
trans/transpiler/phases/emit_cpp/expr.py
View file @
28ff809a
...
@@ -3,6 +3,7 @@ import ast
...
@@ -3,6 +3,7 @@ import ast
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Iterable
from
typing
import
List
,
Iterable
from
transpiler.phases.typing.types
import
UserType
from
transpiler.utils
import
compare_ast
from
transpiler.utils
import
compare_ast
from
transpiler.consts
import
SYMBOLS
,
PRECEDENCE_LEVELS
from
transpiler.consts
import
SYMBOLS
,
PRECEDENCE_LEVELS
from
transpiler.phases.emit_cpp
import
CoroutineMode
,
join
,
NodeVisitor
from
transpiler.phases.emit_cpp
import
CoroutineMode
,
join
,
NodeVisitor
...
@@ -166,7 +167,10 @@ class ExpressionVisitor(NodeVisitor):
...
@@ -166,7 +167,10 @@ class ExpressionVisitor(NodeVisitor):
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
Iterable
[
str
]:
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
Iterable
[
str
]:
yield
from
self
.
prec
(
"."
).
visit
(
node
.
value
)
yield
from
self
.
prec
(
"."
).
visit
(
node
.
value
)
yield
"."
if
isinstance
(
node
.
value
.
type
,
UserType
):
yield
"->"
else
:
yield
"."
yield
node
.
attr
yield
node
.
attr
def
visit_List
(
self
,
node
:
ast
.
List
)
->
Iterable
[
str
]:
def
visit_List
(
self
,
node
:
ast
.
List
)
->
Iterable
[
str
]:
...
...
trans/transpiler/phases/emit_cpp/module.py
View file @
28ff809a
...
@@ -4,6 +4,7 @@ from typing import Iterable
...
@@ -4,6 +4,7 @@ from typing import Iterable
from
transpiler.phases.emit_cpp
import
CoroutineMode
from
transpiler.phases.emit_cpp
import
CoroutineMode
from
transpiler.phases.emit_cpp.block
import
BlockVisitor
from
transpiler.phases.emit_cpp.block
import
BlockVisitor
from
transpiler.phases.emit_cpp.class_
import
ClassVisitor
from
transpiler.phases.emit_cpp.function
import
FunctionVisitor
from
transpiler.phases.emit_cpp.function
import
FunctionVisitor
from
transpiler.utils
import
compare_ast
from
transpiler.utils
import
compare_ast
...
@@ -37,3 +38,6 @@ class ModuleVisitor(BlockVisitor):
...
@@ -37,3 +38,6 @@ class ModuleVisitor(BlockVisitor):
yield
f"//
{
node
.
value
.
s
}
"
yield
f"//
{
node
.
value
.
s
}
"
else
:
else
:
raise
NotImplementedError
(
node
)
raise
NotImplementedError
(
node
)
def
visit_ClassDef
(
self
,
node
:
ast
.
ClassDef
)
->
Iterable
[
str
]:
yield
from
ClassVisitor
().
visit
(
node
)
trans/transpiler/phases/typing/block.py
View file @
28ff809a
import
ast
import
ast
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
transpiler.phases.typing.annotations
import
TypeAnnotationVisitor
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.expr
import
ScoperExprVisitor
from
transpiler.phases.typing.expr
import
ScoperExprVisitor
from
transpiler.phases.typing.class_
import
ScoperClassVisitor
from
transpiler.phases.typing.scope
import
VarDecl
,
VarKind
,
ScopeKind
from
transpiler.phases.typing.scope
import
VarDecl
,
VarKind
,
ScopeKind
from
transpiler.phases.typing.types
import
BaseType
,
TypeVariable
,
FunctionType
,
IncompatibleTypesError
,
TY_MODULE
,
\
from
transpiler.phases.typing.types
import
BaseType
,
TypeVariable
,
FunctionType
,
IncompatibleTypesError
,
TY_MODULE
,
\
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
,
UserType
@
dataclass
@
dataclass
...
@@ -17,6 +16,9 @@ class ScoperBlockVisitor(ScoperVisitor):
...
@@ -17,6 +16,9 @@ class ScoperBlockVisitor(ScoperVisitor):
def
expr
(
self
)
->
ScoperExprVisitor
:
def
expr
(
self
)
->
ScoperExprVisitor
:
return
ScoperExprVisitor
(
self
.
scope
,
self
.
root_decls
)
return
ScoperExprVisitor
(
self
.
scope
,
self
.
root_decls
)
def
visit_Pass
(
self
,
node
:
ast
.
Pass
):
pass
def
visit_Import
(
self
,
node
:
ast
.
Import
):
def
visit_Import
(
self
,
node
:
ast
.
Import
):
for
alias
in
node
.
names
:
for
alias
in
node
.
names
:
self
.
scope
.
vars
[
alias
.
asname
or
alias
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
None
)
self
.
scope
.
vars
[
alias
.
asname
or
alias
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
None
)
...
@@ -51,11 +53,12 @@ class ScoperBlockVisitor(ScoperVisitor):
...
@@ -51,11 +53,12 @@ class ScoperBlockVisitor(ScoperVisitor):
raise
NotImplementedError
(
node
)
raise
NotImplementedError
(
node
)
target
=
node
.
targets
[
0
]
target
=
node
.
targets
[
0
]
ty
=
self
.
get_type
(
node
.
value
)
ty
=
self
.
get_type
(
node
.
value
)
target
.
type
=
ty
node
.
is_declare
=
self
.
visit_assign_target
(
target
,
ty
)
node
.
is_declare
=
self
.
visit_assign_target
(
target
,
ty
)
target
.
type
.
unify
(
ty
)
def
visit_assign_target
(
self
,
target
,
decl_val
:
BaseType
)
->
bool
:
def
visit_assign_target
(
self
,
target
,
decl_val
:
BaseType
)
->
bool
:
if
isinstance
(
target
,
ast
.
Name
):
if
isinstance
(
target
,
ast
.
Name
):
target
.
type
=
decl_val
if
vdecl
:
=
self
.
scope
.
get
(
target
.
id
):
if
vdecl
:
=
self
.
scope
.
get
(
target
.
id
):
vdecl
.
type
.
unify
(
decl_val
)
vdecl
.
type
.
unify
(
decl_val
)
return
False
return
False
...
@@ -68,15 +71,13 @@ class ScoperBlockVisitor(ScoperVisitor):
...
@@ -68,15 +71,13 @@ class ScoperBlockVisitor(ScoperVisitor):
if
not
(
isinstance
(
decl_val
,
TupleType
)
and
len
(
target
.
elts
)
==
len
(
decl_val
.
args
)):
if
not
(
isinstance
(
decl_val
,
TupleType
)
and
len
(
target
.
elts
)
==
len
(
decl_val
.
args
)):
raise
IncompatibleTypesError
(
f"Cannot unpack
{
decl_val
}
into
{
target
}
"
)
raise
IncompatibleTypesError
(
f"Cannot unpack
{
decl_val
}
into
{
target
}
"
)
return
any
(
self
.
visit_assign_target
(
t
,
ty
)
for
t
,
ty
in
zip
(
target
.
elts
,
decl_val
.
args
))
return
any
(
self
.
visit_assign_target
(
t
,
ty
)
for
t
,
ty
in
zip
(
target
.
elts
,
decl_val
.
args
))
elif
isinstance
(
target
,
ast
.
Attribute
):
attr_type
=
self
.
expr
().
visit
(
target
)
attr_type
.
unify
(
decl_val
)
return
False
else
:
else
:
raise
NotImplementedError
(
target
)
raise
NotImplementedError
(
target
)
def
anno
(
self
)
->
"TypeAnnotationVisitor"
:
return
TypeAnnotationVisitor
(
self
.
scope
)
def
visit_annotation
(
self
,
expr
:
Optional
[
ast
.
expr
])
->
BaseType
:
return
self
.
anno
().
visit
(
expr
)
if
expr
else
TypeVariable
()
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
):
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
):
argtypes
=
[
self
.
visit_annotation
(
arg
.
annotation
)
for
arg
in
node
.
args
.
args
]
argtypes
=
[
self
.
visit_annotation
(
arg
.
annotation
)
for
arg
in
node
.
args
.
args
]
rtype
=
Promise
(
self
.
visit_annotation
(
node
.
returns
),
PromiseKind
.
TASK
)
rtype
=
Promise
(
self
.
visit_annotation
(
node
.
returns
),
PromiseKind
.
TASK
)
...
@@ -97,6 +98,18 @@ class ScoperBlockVisitor(ScoperVisitor):
...
@@ -97,6 +98,18 @@ class ScoperBlockVisitor(ScoperVisitor):
if
not
scope
.
has_return
:
if
not
scope
.
has_return
:
rtype
.
return_type
.
unify
(
TY_NONE
)
rtype
.
return_type
.
unify
(
TY_NONE
)
def
visit_ClassDef
(
self
,
node
:
ast
.
ClassDef
):
ctype
=
UserType
(
node
.
name
)
self
.
scope
.
vars
[
node
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
ctype
)
scope
=
self
.
scope
.
child
(
ScopeKind
.
CLASS
)
scope
.
obj_type
=
ctype
scope
.
class_
=
scope
node
.
inner_scope
=
scope
node
.
type
=
ctype
visitor
=
ScoperClassVisitor
(
scope
)
for
b
in
node
.
body
:
visitor
.
visit
(
b
)
def
visit_If
(
self
,
node
:
ast
.
If
):
def
visit_If
(
self
,
node
:
ast
.
If
):
scope
=
self
.
scope
.
child
(
ScopeKind
.
FUNCTION_INNER
)
scope
=
self
.
scope
.
child
(
ScopeKind
.
FUNCTION_INNER
)
node
.
inner_scope
=
scope
node
.
inner_scope
=
scope
...
...
trans/transpiler/phases/typing/class_.py
0 → 100644
View file @
28ff809a
# coding: utf-8
import
ast
from
dataclasses
import
dataclass
from
transpiler.phases.typing
import
FunctionType
,
ScopeKind
,
VarDecl
,
VarKind
,
TY_NONE
from
transpiler.phases.typing.common
import
ScoperVisitor
@
dataclass
class
ScoperClassVisitor
(
ScoperVisitor
):
def
visit_AnnAssign
(
self
,
node
:
ast
.
AnnAssign
):
assert
node
.
value
is
None
,
"Class field should not have a value"
assert
node
.
simple
==
1
,
"Class field should be simple (identifier, not parenthesized)"
assert
isinstance
(
node
.
target
,
ast
.
Name
)
self
.
scope
.
obj_type
.
members
[
node
.
target
.
id
]
=
self
.
visit_annotation
(
node
.
annotation
)
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
):
from
transpiler.phases.typing.block
import
ScoperBlockVisitor
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef
argtypes
=
[
self
.
visit_annotation
(
arg
.
annotation
)
for
arg
in
node
.
args
.
args
]
argtypes
[
0
].
unify
(
self
.
scope
.
obj_type
)
# self parameter
rtype
=
self
.
visit_annotation
(
node
.
returns
)
ftype
=
FunctionType
(
argtypes
,
rtype
)
self
.
scope
.
obj_type
.
methods
[
node
.
name
]
=
ftype
scope
=
self
.
scope
.
child
(
ScopeKind
.
FUNCTION
)
scope
.
obj_type
=
ftype
scope
.
function
=
scope
node
.
inner_scope
=
scope
node
.
type
=
ftype
for
arg
,
ty
in
zip
(
node
.
args
.
args
,
argtypes
):
scope
.
vars
[
arg
.
arg
]
=
VarDecl
(
VarKind
.
LOCAL
,
ty
)
for
b
in
node
.
body
:
decls
=
{}
visitor
=
ScoperBlockVisitor
(
scope
,
decls
)
visitor
.
visit
(
b
)
b
.
decls
=
decls
if
not
scope
.
has_return
:
rtype
.
unify
(
TY_NONE
)
trans/transpiler/phases/typing/common.py
View file @
28ff809a
import
ast
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Dict
from
typing
import
Dict
,
Optional
from
transpiler.phases.typing.annotations
import
TypeAnnotationVisitor
from
transpiler.phases.typing.scope
import
Scope
,
ScopeKind
,
VarDecl
from
transpiler.phases.typing.scope
import
Scope
,
ScopeKind
,
VarDecl
from
transpiler.phases.typing.types
import
BaseType
,
TypeVariable
from
transpiler.phases.utils
import
NodeVisitorSeq
from
transpiler.phases.utils
import
NodeVisitorSeq
PRELUDE
=
Scope
.
make_global
()
PRELUDE
=
Scope
.
make_global
()
...
@@ -9,4 +12,10 @@ PRELUDE = Scope.make_global()
...
@@ -9,4 +12,10 @@ PRELUDE = Scope.make_global()
@
dataclass
@
dataclass
class
ScoperVisitor
(
NodeVisitorSeq
):
class
ScoperVisitor
(
NodeVisitorSeq
):
scope
:
Scope
=
field
(
default_factory
=
lambda
:
PRELUDE
.
child
(
ScopeKind
.
GLOBAL
))
scope
:
Scope
=
field
(
default_factory
=
lambda
:
PRELUDE
.
child
(
ScopeKind
.
GLOBAL
))
root_decls
:
Dict
[
str
,
VarDecl
]
=
field
(
default_factory
=
dict
)
root_decls
:
Dict
[
str
,
VarDecl
]
=
field
(
default_factory
=
dict
)
\ No newline at end of file
def
anno
(
self
)
->
"TypeAnnotationVisitor"
:
return
TypeAnnotationVisitor
(
self
.
scope
)
def
visit_annotation
(
self
,
expr
:
Optional
[
ast
.
expr
])
->
BaseType
:
return
self
.
anno
().
visit
(
expr
)
if
expr
else
TypeVariable
()
\ No newline at end of file
trans/transpiler/phases/typing/expr.py
View file @
28ff809a
...
@@ -6,7 +6,7 @@ from typing import List
...
@@ -6,7 +6,7 @@ from typing import List
from
transpiler.phases.typing
import
ScopeKind
,
VarDecl
,
VarKind
from
transpiler.phases.typing
import
ScopeKind
,
VarDecl
,
VarKind
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.types
import
IncompatibleTypesError
,
BaseType
,
TupleType
,
TY_STR
,
TY_BOOL
,
TY_INT
,
\
from
transpiler.phases.typing.types
import
IncompatibleTypesError
,
BaseType
,
TupleType
,
TY_STR
,
TY_BOOL
,
TY_INT
,
\
TY_COMPLEX
,
TY_NONE
,
FunctionType
,
PyList
,
TypeVariable
,
PySet
,
TypeType
,
PyDict
,
Promise
,
PromiseKind
TY_COMPLEX
,
TY_NONE
,
FunctionType
,
PyList
,
TypeVariable
,
PySet
,
TypeType
,
PyDict
,
Promise
,
PromiseKind
,
UserType
DUNDER
=
{
DUNDER
=
{
ast
.
Eq
:
"eq"
,
ast
.
Eq
:
"eq"
,
...
@@ -105,6 +105,10 @@ class ScoperExprVisitor(ScoperVisitor):
...
@@ -105,6 +105,10 @@ class ScoperExprVisitor(ScoperVisitor):
return
actual
return
actual
def
visit_function_call
(
self
,
ftype
:
BaseType
,
arguments
:
List
[
BaseType
]):
def
visit_function_call
(
self
,
ftype
:
BaseType
,
arguments
:
List
[
BaseType
]):
if
isinstance
(
ftype
,
UserType
):
init
:
FunctionType
=
self
.
visit_getattr
(
ftype
,
"__init__"
)
ctor
=
FunctionType
(
init
.
args
[
1
:],
ftype
)
return
self
.
visit_function_call
(
ctor
,
arguments
)
if
not
isinstance
(
ftype
,
FunctionType
):
if
not
isinstance
(
ftype
,
FunctionType
):
raise
IncompatibleTypesError
(
f"Cannot call
{
ftype
}
"
)
raise
IncompatibleTypesError
(
f"Cannot call
{
ftype
}
"
)
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
...
...
trans/transpiler/phases/typing/scope.py
View file @
28ff809a
...
@@ -53,6 +53,7 @@ class Scope:
...
@@ -53,6 +53,7 @@ class Scope:
children
:
List
[
"Scope"
]
=
field
(
default_factory
=
list
)
children
:
List
[
"Scope"
]
=
field
(
default_factory
=
list
)
obj_type
:
Optional
[
BaseType
]
=
None
obj_type
:
Optional
[
BaseType
]
=
None
has_return
:
bool
=
False
has_return
:
bool
=
False
class_
:
Optional
[
"Scope"
]
=
None
@
staticmethod
@
staticmethod
def
make_global
():
def
make_global
():
...
...
trans/transpiler/phases/typing/types.py
View file @
28ff809a
...
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
...
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
enum
import
Enum
from
itertools
import
zip_longest
from
itertools
import
zip_longest
from
typing
import
Dict
,
Optional
,
List
,
ClassVar
,
Callable
,
Any
from
typing
import
Dict
,
Optional
,
List
,
ClassVar
,
Callable
class
IncompatibleTypesError
(
Exception
):
class
IncompatibleTypesError
(
Exception
):
...
@@ -215,7 +215,6 @@ class TypeOperator(BaseType, ABC):
...
@@ -215,7 +215,6 @@ class TypeOperator(BaseType, ABC):
return
[
self
,
*
self
.
args
]
return
[
self
,
*
self
.
args
]
class
FunctionType
(
TypeOperator
):
class
FunctionType
(
TypeOperator
):
def
__init__
(
self
,
args
:
List
[
BaseType
],
ret
:
BaseType
):
def
__init__
(
self
,
args
:
List
[
BaseType
],
ret
:
BaseType
):
super
().
__init__
([
ret
,
*
args
])
super
().
__init__
([
ret
,
*
args
])
...
@@ -374,3 +373,12 @@ class Future(Promise):
...
@@ -374,3 +373,12 @@ class Future(Promise):
def
__init__
(
self
,
ret
:
BaseType
):
def
__init__
(
self
,
ret
:
BaseType
):
super
().
__init__
(
ret
,
PromiseKind
.
FUTURE
)
super
().
__init__
(
ret
,
PromiseKind
.
FUTURE
)
class
UserType
(
TypeOperator
):
def
__init__
(
self
,
name
:
str
):
super
().
__init__
([],
name
=
name
)
def
unify_internal
(
self
,
other
:
"BaseType"
):
if
type
(
self
)
!=
type
(
other
):
raise
IncompatibleTypesError
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment