Commit b738512a authored by Tom Niget's avatar Tom Niget

Builtin things work

parent 85f54768
.idea .idea
cmake-build-* cmake-build-*
typon.egg-info typon.egg-info
\ No newline at end of file build
\ No newline at end of file
...@@ -20,6 +20,8 @@ public: ...@@ -20,6 +20,8 @@ public:
return sync_wrapper(std::forward<Args>(args)...); return sync_wrapper(std::forward<Args>(args)...);
} }
}; };
/* /*
struct method {}; struct method {};
......
...@@ -59,6 +59,8 @@ template <PySmartPtr T> struct RealType<T> { ...@@ -59,6 +59,8 @@ template <PySmartPtr T> struct RealType<T> {
using type = typename T::element_type; using type = typename T::element_type;
}; };
namespace typon {
//template <typename T> using TyObj = std::shared_ptr<typename RealType<T>::type>; //template <typename T> using TyObj = std::shared_ptr<typename RealType<T>::type>;
template<typename T> template<typename T>
...@@ -135,6 +137,11 @@ template <typename T, typename... Args> auto pyobj_agg(Args &&...args) -> TyObj< ...@@ -135,6 +137,11 @@ template <typename T, typename... Args> auto pyobj_agg(Args &&...args) -> TyObj<
return std::make_shared<typename RealType<T>::type>((typename RealType<T>::type) { std::forward<Args>(args)... }); return std::make_shared<typename RealType<T>::type>((typename RealType<T>::type) { std::forward<Args>(args)... });
} }
class TyNone {
};
}
// typon_len // typon_len
template <typename T> template <typename T>
...@@ -199,6 +206,7 @@ static constexpr auto PyNone = std::nullopt; ...@@ -199,6 +206,7 @@ static constexpr auto PyNone = std::nullopt;
#include "builtins/bool.hpp" #include "builtins/bool.hpp"
#include "builtins/complex.hpp" #include "builtins/complex.hpp"
#include "builtins/dict.hpp" #include "builtins/dict.hpp"
#include "builtins/int.hpp"
#include "builtins/list.hpp" #include "builtins/list.hpp"
#include "builtins/print.hpp" #include "builtins/print.hpp"
#include "builtins/range.hpp" #include "builtins/range.hpp"
...@@ -307,7 +315,7 @@ typon::Task<typon::PyFile> open(const TyStr &path, std::string_view mode) { ...@@ -307,7 +315,7 @@ typon::Task<typon::PyFile> open(const TyStr &path, std::string_view mode) {
std::cerr << path << "," << flags << std::endl; std::cerr << path << "," << flags << std::endl;
system_error(-fd, "openat()"); system_error(-fd, "openat()");
} }
co_return tyObj<typon::PyFile>(fd, len); co_return typon::tyObj<typon::PyFile>(fd, len);
} }
#include <typon/generator.hpp> #include <typon/generator.hpp>
...@@ -349,9 +357,9 @@ template<PySmartPtr T> ...@@ -349,9 +357,9 @@ template<PySmartPtr T>
auto& iter_fix_ref(T& obj) { return *obj; } auto& iter_fix_ref(T& obj) { return *obj; }
namespace std { namespace std {
template <class T> auto begin(std::shared_ptr<T> &obj) { return dotp(obj, begin)(); } template <class T> auto begin(std::shared_ptr<T> &obj) { return dot(obj, begin)(); }
template <class T> auto end(std::shared_ptr<T> &obj) { return dotp(obj, end)(); } template <class T> auto end(std::shared_ptr<T> &obj) { return dot(obj, end)(); }
} }
template <typename T> template <typename T>
...@@ -365,18 +373,33 @@ struct ValueTypeEx { ...@@ -365,18 +373,33 @@ struct ValueTypeEx {
using type = decltype(*std::begin(std::declval<Seq&>())); using type = decltype(*std::begin(std::declval<Seq&>()));
}; };
// (2) // (2)
template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>> /*template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>>
auto mapFilter(Map map, Seq seq, Filt filt = Filt()) { typon::Task< mapFilter(Map map, Seq seq, Filt filt = Filt()) {
//typedef typename Seq::value_type value_type; //typedef typename Seq::value_type value_type;
using value_type = typename ValueTypeEx<Seq>::type; using value_type = typename ValueTypeEx<Seq>::type;
using return_type = decltype(map(std::declval<value_type>())); using return_type = decltype(map(std::declval<value_type>()));
std::vector<return_type> result{}; std::vector<return_type> result{};
for (auto i : seq | std::views::filter(filt) for (auto i : seq) {
| std::views::transform(map)) result.push_back(i); if (co_await filt(i)) {
result.push_back(co_await map(i));
}
}
return typon::TyList(std::move(result)); return typon::TyList(std::move(result));
} }*/
#define MAP_FILTER(item, seq, map, filter) ({\
using value_type = typename ValueTypeEx<decltype(seq)>::type;\
value_type item;\
std::vector<decltype(map)> result{};\
for (auto item : seq) {\
if (filter) {\
result.push_back(map);\
}\
}\
typon::TyList(std::move(result));\
})
namespace PYBIND11_NAMESPACE { namespace PYBIND11_NAMESPACE {
namespace detail { namespace detail {
......
//
// Created by Tom on 08/03/2023.
//
#ifndef TYPON_INT_HPP
#define TYPON_INT_HPP
#include <sstream>
#include <string>
#include <algorithm>
using namespace std::literals;
#include "bytes.hpp"
#include "print.hpp"
#include "slice.hpp"
// #include <format>
#include <fmt/format.h>
#include <pybind11/cast.h>
namespace typon {
/*template <typename _Base0 = object>
class TyInt__oo : classtype<_Base0, Integer__oo<>> {
public:
struct : method {
auto operator()(auto self, int value) const {
self->value = value;
}
} static constexpr oo__init__oo {};
struct : method {
auto operator()(auto self, auto other) const {
return Integer(dot(self, value) + dot(other, value));
}
} static constexpr oo__add__oo {};
auto operator () (int value) const {
struct Obj : instance<Integer__oo<>, Obj> {
int value;
};
auto obj = rc(Obj{});
dot(obj, oo__init__oo)(value);
return obj;
}
constexpr TyInt(int value) : value(value) {}
constexpr TyInt() : value(0) {}
operator int() const { return value; }
// operators
template <typename T> TyInt operator+(T x) const { return value + x; }
template <typename T> TyInt operator-(T x) const { return value - x; }
template <typename T> TyInt operator*(T x) const { return value * x; }
template <typename T> TyInt operator/(T x) const { return value / x; }
template <typename T> TyInt operator%(T x) const { return value % x; }
template <typename T> TyInt operator&(T x) const { return value & x; }
template <typename T> TyInt operator|(T x) const { return value | x; }
template <typename T> TyInt operator^(T x) const { return value ^ x; }
template <typename T> TyInt operator<<(T x) const { return value << x; }
template <typename T> TyInt operator>>(T x) const { return value >> x; }
template <typename T> TyInt operator&&(T x) const { return value && x; }
template <typename T> TyInt operator||(T x) const { return value || x; }
template <typename T> TyInt operator==(T x) const { return value == x; }
template <typename T> TyInt operator<(T x) const { return value < x; }
TyInt operator-() const { return -value; }
private:
int value;
};*/
using namespace referencemodel;
template <typename _Base0 = object>
struct Integer__oo : classtype<_Base0, Integer__oo<>> {
static constexpr std::string_view name = "Integer";
struct : method {
auto operator()(auto self, int value) const {
self->value = value;
}
} static constexpr oo__init__oo {};
struct : method {
auto operator()(auto self, auto other) const {
return TyInt(dot(self, value) + dot(other, value));
}
} static constexpr oo__add__oo {};
struct Obj : value<Integer__oo<>, Obj> {
int value;
Obj(int value=0) : value(value) {}
operator int() const { return value; }
};
auto operator () (int value) const {
auto obj = rc(Obj{});
dot(obj, oo__init__oo)(value);
return obj;
}
};
static constexpr Integer__oo<> TyInt {};
}
inline auto operator ""_pi(unsigned long long int v) noexcept {
return typon::TyInt(v);
}
template <> struct std::hash<decltype(0_pi)> {
std::size_t operator()(const decltype(0_pi) &s) const noexcept {
return std::hash<int>()(s);
}
};
namespace PYBIND11_NAMESPACE {
namespace detail {
template<>
struct type_caster<decltype(0_pi)>
: type_caster<int> {};
}}
template <> void repr_to(const decltype(0_pi) &x, std::ostream &s) {
s << x;
}
template <> void print_to<decltype(0_pi)>(const decltype(0_pi) &x, std::ostream &s) { s << x; }
#endif // TYPON_INT_HPP
...@@ -86,13 +86,14 @@ typon::Task<void> print(T const &head, Args const &...args) { ...@@ -86,13 +86,14 @@ typon::Task<void> print(T const &head, Args const &...args) {
}*/ }*/
struct { struct {
void operator()() { std::cout << '\n'; } typon::TyNone operator()() { std::cout << '\n'; return {}; }
template <Printable T, Printable... Args> template <Printable T, Printable... Args>
void operator()(T const &head, Args const &...args) { typon::TyNone operator()(T const &head, Args const &...args) {
print_to(head, std::cout); print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...); (((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n'; std::cout << '\n';
return {};
} }
} print; } print;
// typon::Task<void> print() { std::cout << '\n'; co_return; } // typon::Task<void> print() { std::cout << '\n'; co_return; }
......
...@@ -18,11 +18,7 @@ auto stride = [](int n) { ...@@ -18,11 +18,7 @@ auto stride = [](int n) {
// todo: proper range support // todo: proper range support
struct range_s : TyBuiltin<range_s> struct range_s : TyBuiltin<range_s>
{ {
template <typename T> auto sync(int start, int stop, int step = 1) {
auto sync(T stop) { return sync(0, stop); }
template <typename T>
auto sync(T start, T stop, T step = 1) {
// https://www.modernescpp.com/index.php/c-20-pythons-map-function/ // https://www.modernescpp.com/index.php/c-20-pythons-map-function/
if(step == 0) { if(step == 0) {
throw std::invalid_argument("Step cannot be 0"); throw std::invalid_argument("Step cannot be 0");
...@@ -36,6 +32,10 @@ struct range_s : TyBuiltin<range_s> ...@@ -36,6 +32,10 @@ struct range_s : TyBuiltin<range_s>
return start < stop ? i : stop - (i - start); return start < stop ? i : stop - (i - start);
}); });
} }
auto sync(int stop) { return sync(0, stop); }
} range; } range;
#endif // TYPON_RANGE_HPP #endif // TYPON_RANGE_HPP
...@@ -23,6 +23,9 @@ struct object; ...@@ -23,6 +23,9 @@ struct object;
template <typename T, typename O> template <typename T, typename O>
struct instance; struct instance;
template <typename T, typename O>
struct value;
template <typename B, typename T> template <typename B, typename T>
struct classtype; struct classtype;
...@@ -168,6 +171,10 @@ using unwrap_pack = typename unwrap_pack_s<std::remove_cvref_t<T>>::type; ...@@ -168,6 +171,10 @@ using unwrap_pack = typename unwrap_pack_s<std::remove_cvref_t<T>>::type;
/* Meta-programming utilities for object model */ /* Meta-programming utilities for object model */
template <typename T>
concept object = std::derived_from<unwrap_all<T>, referencemodel::object>;
template <typename T> template <typename T>
concept instance = std::derived_from< concept instance = std::derived_from<
unwrap_all<T>, unwrap_all<T>,
...@@ -215,7 +222,26 @@ concept boundmethod = boundmethod_s<std::remove_cvref_t<T>>::value; ...@@ -215,7 +222,26 @@ concept boundmethod = boundmethod_s<std::remove_cvref_t<T>>::value;
template <typename T> template <typename T>
concept value = !instance<T> && !boundmethod<T>; struct value_s {
static constexpr bool value = true;
};
template <instance T>
struct value_s<T> {
static constexpr bool value = std::derived_from<
unwrap_all<T>,
referencemodel::value<typename unwrap_all<T>::type, unwrap_all<T>>
>;
};
template <typename S, typename F>
struct value_s<referencemodel::boundmethod<S, F>> {
static constexpr bool value = value_s<S>::value;
};
template <typename T>
concept value = value_s<std::remove_cvref_t<T>>::value;
/* Meta-programming utilities: wrapped and unwrapped */ /* Meta-programming utilities: wrapped and unwrapped */
...@@ -963,6 +989,10 @@ struct instance : T { ...@@ -963,6 +989,10 @@ struct instance : T {
}; };
template <typename T, typename O>
struct value : instance<T, O> {};
template <typename B, typename T> template <typename B, typename T>
struct classtype : B { struct classtype : B {
using base = B; using base = B;
...@@ -1193,14 +1223,58 @@ decltype(auto) bind(S &&, const A & attr, T...) { ...@@ -1193,14 +1223,58 @@ decltype(auto) bind(S &&, const A & attr, T...) {
return attr; return attr;
} }
} // namespace referencemodel
#define dot(OBJ, NAME)\ #define dot(OBJ, NAME)\
[](auto && obj) -> decltype(auto) {\ [](auto && obj) -> decltype(auto) {\
return referencemodel::bind(std::forward<decltype(obj)>(obj), obj->NAME);\ return referencemodel::bind(std::forward<decltype(obj)>(obj), obj->NAME);\
}(OBJ) }(OBJ)
/* Operators */
namespace meta {
/* + */
template <typename Left, typename Right>
concept LeftAddable = requires (Left left, Right right) {
// note: using dot here would cause hard failure instead of invalid constraint
left->oo__add__oo(left, right);
};
template <typename Left, typename Right>
concept RightAddable = requires (Left left, Right right) {
// note: using dot here would cause hard failure instead of invalid constraint
right->oo__radd__oo(right, left);
};
template <typename Left, typename Right>
concept Addable = LeftAddable<Left, Right> || RightAddable<Left, Right>;
}
/* + */
template <meta::object Left, meta::object Right>
requires meta::Addable<Left, Right>
auto operator + (Left && left, Right && right) {
if constexpr (meta::LeftAddable<Left, Right>) {
return dot(std::forward<Left>(left), oo__add__oo)(
std::forward<Right>(right));
}
else {
if constexpr (meta::RightAddable<Left, Right>) {
return dot(std::forward<Right>(right), oo__radd__oo)(
std::forward<Left>(left));
}
}
}
} // namespace referencemodel
#endif // REFERENCEMODEL_H #endif // REFERENCEMODEL_H
Subproject commit 79677d125f915f7c61492d8d1d8cde9fc6a11875 Subproject commit 26c320d77d368d0f482684ebd653d1702c75a7f2
...@@ -3,8 +3,8 @@ from typing import Self, Protocol, Optional ...@@ -3,8 +3,8 @@ from typing import Self, Protocol, Optional
assert 5 assert 5
class object: class object:
def __eq__(self, other: Self) -> bool: ... def __eq__[T](self, other: T) -> bool: ...
def __ne__(self, other: Self) -> bool: ... def __ne__[T](self, other: T) -> bool: ...
class int: class int:
def __add__(self, other: Self) -> Self: ... def __add__(self, other: Self) -> Self: ...
...@@ -96,6 +96,10 @@ assert [].__getitem__ ...@@ -96,6 +96,10 @@ assert [].__getitem__
assert [4].__getitem__ assert [4].__getitem__
assert [1, 2, 3][1] assert [1, 2, 3][1]
class set[U]:
def __len__(self) -> int: ...
def __contains__(self, item: U) -> bool: ...
def iter[U](x: Iterable[U]) -> Iterator[U]: def iter[U](x: Iterable[U]) -> Iterator[U]:
... ...
......
...@@ -2,4 +2,5 @@ ...@@ -2,4 +2,5 @@
Protocol = BuiltinFeature["Protocol"] Protocol = BuiltinFeature["Protocol"]
Self = BuiltinFeature["Self"] Self = BuiltinFeature["Self"]
Optional = BuiltinFeature["Optional"] Optional = BuiltinFeature["Optional"]
\ No newline at end of file Callable = BuiltinFeature["Callable"]
\ No newline at end of file
...@@ -88,15 +88,15 @@ def run_test(path, quiet=True): ...@@ -88,15 +88,15 @@ def run_test(path, quiet=True):
if args.compile: if args.compile:
return TestStatus.SUCCESS return TestStatus.SUCCESS
execute_str = "true" if (execute and not args.generate) else "false" execute_str = "true" if (execute and not args.generate) else "false"
name_bin = path.with_suffix("").as_posix() + ("$(python3-config --extension-suffix)" if extension else ".exe") name_bin = path.with_suffix("").as_posix() + ("$(python3.12-config --extension-suffix)" if extension else ".exe")
if exec_cmd(f'bash -c "export PYTHONPATH=stdlib; if {execute_str}; then python3 ./{path.as_posix()}; fi"') != 0: if exec_cmd(f'bash -c "export PYTHONPATH=stdlib; if {execute_str}; then echo python3.12 ./{path.as_posix()}; fi"') != 0:
return TestStatus.PYTHON_ERROR return TestStatus.PYTHON_ERROR
if compile and (alt := environ.get("ALT_RUNNER")): if compile and (alt := environ.get("ALT_RUNNER")):
if (code := exec_cmd(alt.format( if (code := exec_cmd(alt.format(
name_bin=name_bin, name_bin=name_bin,
name_cpp_posix=name_cpp.as_posix(), name_cpp_posix=name_cpp.as_posix(),
run_file=execute_str, run_file=execute_str,
test_exec=f"python3 {path.with_suffix('.post.py').as_posix()}" if extension else name_bin, test_exec=f"python3.12 {path.with_suffix('.post.py').as_posix()}" if extension else name_bin,
bonus_flags="-e" if extension else "" bonus_flags="-e" if extension else ""
))) != 0: ))) != 0:
return TestStatus(code) return TestStatus(code)
......
...@@ -4,9 +4,9 @@ from typon import is_cpp ...@@ -4,9 +4,9 @@ from typon import is_cpp
import sys as sis import sys as sis
from sys import stdout as truc from sys import stdout as truc
foo = 123 # foo = 123
test = (2 + 3) * 4 # test = (2 + 3) * 4
glob = 5 # glob = 5
# def g(): # def g():
# a = 8 # a = 8
...@@ -20,32 +20,37 @@ glob = 5 ...@@ -20,32 +20,37 @@ glob = 5
# e = d + 1 # e = d + 1
# print(e) # print(e)
def f(x): # def f(x):
return x + 1 # return x + 1
#
#
def fct(param: int): # def fct(param: int):
loc = f(456) # loc = f(456)
global glob # global glob
loc = 789 # loc = 789
glob = 123 # glob = 123
#
def fct2(): # def fct2():
global glob # global glob
glob += 5 # glob += 5
if __name__ == "__main__": if __name__ == "__main__":
print(is_cpp) print("is c++:", is_cpp())
# TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15 # TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056 # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056
sum = 0 sum = 0
for i in range(15): for i in range(15):
sum += i sum = sum + i
a = [n for n in range(10)] a = [n for n in range(10)]
b = [x for x in a if x % 2 == 0] #b = [x for x in a if x % 2 == 0]
c = [y * y for y in b] #c = [y * y for y in b]
print("C++ " if is_cpp() else "Python", print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5, "res=", 5, ".", True, [4, 5, 6], {7, 8, 9},
3j, sum, a, b, c) #[1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3},
#0x55 & 7 == 5,
#3j,
sum,
a)
print("Typon")
print() print()
...@@ -2,10 +2,15 @@ ...@@ -2,10 +2,15 @@
import ast import ast
import builtins import builtins
import importlib import importlib
import inspect
import sys import sys
import traceback import traceback
import colorful as cf import colorful as cf
from transpiler.exceptions import CompileError
from transpiler.utils import highlight
def exception_hook(exc_type, exc_value, tb): def exception_hook(exc_type, exc_value, tb):
print = lambda *args, **kwargs: builtins.print(*args, **kwargs, file=sys.stderr) print = lambda *args, **kwargs: builtins.print(*args, **kwargs, file=sys.stderr)
last_node = None last_node = None
...@@ -49,40 +54,44 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -49,40 +54,44 @@ def exception_hook(exc_type, exc_value, tb):
return return
print(f"In file {cf.white(last_file)}:{last_node.lineno}") print(f"In file {cf.white(last_file)}:{last_node.lineno}")
#print(f"From {last_node.lineno}:{last_node.col_offset} to {last_node.end_lineno}:{last_node.end_col_offset}") #print(f"From {last_node.lineno}:{last_node.col_offset} to {last_node.end_lineno}:{last_node.end_col_offset}")
with open(last_file, "r", encoding="utf-8") as f: try:
code = f.read() with open(last_file, "r", encoding="utf-8") as f:
hg = (str(highlight(code, True)) code = f.read()
.replace("\x1b[04m", "") except Exception:
.replace("\x1b[24m", "") pass
.replace("\x1b[39;24m", "\x1b[39m")
.splitlines())
if last_node.lineno == last_node.end_lineno:
old = hg[last_node.lineno - 1]
start, end = find_indices(old, [last_node.col_offset, last_node.end_col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:end] + "\x1b[24m" + old[end:]
else: else:
old = hg[last_node.lineno - 1] hg = (str(highlight(code, True))
[start] = find_indices(old, [last_node.col_offset]) .replace("\x1b[04m", "")
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:] .replace("\x1b[24m", "")
for lineid in range(last_node.lineno, last_node.end_lineno - 1): .replace("\x1b[39;24m", "\x1b[39m")
old = hg[lineid] .splitlines())
if last_node.lineno == last_node.end_lineno:
old = hg[last_node.lineno - 1]
start, end = find_indices(old, [last_node.col_offset, last_node.end_col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:end] + "\x1b[24m" + old[end:]
else:
old = hg[last_node.lineno - 1]
[start] = find_indices(old, [last_node.col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:]
for lineid in range(last_node.lineno, last_node.end_lineno - 1):
old = hg[lineid]
first_nonspace = len(old) - len(old.lstrip())
hg[lineid] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:] + "\x1b[24m"
old = hg[last_node.end_lineno - 1]
first_nonspace = len(old) - len(old.lstrip()) first_nonspace = len(old) - len(old.lstrip())
hg[lineid] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:] + "\x1b[24m" [end] = find_indices(old, [last_node.end_col_offset])
old = hg[last_node.end_lineno - 1] hg[last_node.end_lineno - 1] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:end] + "\x1b[24m" + old[end:]
first_nonspace = len(old) - len(old.lstrip()) CONTEXT_SIZE = 2
[end] = find_indices(old, [last_node.end_col_offset]) start = max(0, last_node.lineno - CONTEXT_SIZE - 1)
hg[last_node.end_lineno - 1] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:end] + "\x1b[24m" + old[end:] offset = start + 1
CONTEXT_SIZE = 2 for i, line in enumerate(hg[start:last_node.end_lineno + CONTEXT_SIZE]):
start = max(0, last_node.lineno - CONTEXT_SIZE - 1) erroneous = last_node.lineno <= offset + i <= last_node.end_lineno
offset = start + 1 indicator = cf.white(" →") if erroneous else " "
for i, line in enumerate(hg[start:last_node.end_lineno + CONTEXT_SIZE]): bar = " ▎"
erroneous = last_node.lineno <= offset + i <= last_node.end_lineno # bar = "│" if erroneous else "┊"
indicator = cf.white(" →") if erroneous else " " disp = f"\x1b[24m{indicator}{cf.white}{(offset + i):>4}{cf.red if erroneous else cf.reset}{bar}{cf.reset} {line}\x1b[24m"
bar = " ▎" print(disp)
# bar = "│" if erroneous else "┊" # print(repr(disp))
disp = f"\x1b[24m{indicator}{cf.white}{(offset + i):>4}{cf.red if erroneous else cf.reset}{bar}{cf.reset} {line}\x1b[24m"
print(disp)
# print(repr(disp))
print() print()
if isinstance(exc_value, CompileError): if isinstance(exc_value, CompileError):
print(cf.red("Error:"), exc_value) print(cf.red("Error:"), exc_value)
......
...@@ -24,7 +24,7 @@ DUNDER = { ...@@ -24,7 +24,7 @@ DUNDER = {
class DesugarOp(ast.NodeTransformer): class DesugarOp(ast.NodeTransformer):
def visit_BinOp(self, node: ast.BinOp): def visit_BinOp(self, node: ast.BinOp):
lnd = linenodata(node) lnd = linenodata(node)
return ast.Call( res = ast.Call(
func=ast.Attribute( func=ast.Attribute(
value=self.visit(node.left), value=self.visit(node.left),
attr=f"__{DUNDER[type(node.op)]}__", attr=f"__{DUNDER[type(node.op)]}__",
...@@ -35,26 +35,31 @@ class DesugarOp(ast.NodeTransformer): ...@@ -35,26 +35,31 @@ class DesugarOp(ast.NodeTransformer):
keywords={}, keywords={},
**lnd **lnd
) )
res.orig_node = node
return res
def visit_UnaryOp(self, node: ast.UnaryOp): def visit_UnaryOp(self, node: ast.UnaryOp):
lnd = linenodata(node) lnd = linenodata(node)
if type(node.op) == ast.Not: if type(node.op) == ast.Not:
return ast.UnaryOp( res = ast.UnaryOp(
operand=self.visit(node.operand), operand=self.visit(node.operand),
op=node.op, op=node.op,
**lnd **lnd
) )
return ast.Call( else:
func=ast.Attribute( res = ast.Call(
value=self.visit(node.operand), func=ast.Attribute(
attr=f"__{DUNDER[type(node.op)]}__", value=self.visit(node.operand),
ctx=ast.Load(), attr=f"__{DUNDER[type(node.op)]}__",
ctx=ast.Load(),
**lnd
),
args=[],
keywords={},
**lnd **lnd
), )
args=[], res.orig_node = node
keywords={}, return res
**lnd
)
# def visit_AugAssign(self, node: ast.AugAssign): # def visit_AugAssign(self, node: ast.AugAssign):
# return # return
import ast
from typing import Iterable from typing import Iterable
def emit_class(clazz) -> Iterable[str]: def emit_class(node: ast.ClassDef) -> Iterable[str]:
yield f"template <typename _Base0 = referencemodel::object>" yield f"template <typename _Base0 = referencemodel::object>"
yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{" yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{"
yield f"static constexpr std::string_view name = \"{node.name}\";" yield f"static constexpr std::string_view name = \"{node.name}\";"
......
...@@ -5,8 +5,35 @@ from typing import Iterable ...@@ -5,8 +5,35 @@ from typing import Iterable
from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
SYMBOLS = {
ast.Eq: "==",
ast.NotEq: '!=',
ast.Pass: '/* pass */',
ast.Mult: '*',
ast.Add: '+',
ast.Sub: '-',
ast.Div: '/',
ast.FloorDiv: '/', # TODO
ast.Mod: '%',
ast.Lt: '<',
ast.Gt: '>',
ast.GtE: '>=',
ast.LtE: '<=',
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
ast.BitOr: '|',
ast.BitAnd: '&',
ast.Not: '!',
ast.IsNot: '!=',
ast.USub: '-',
ast.And: '&&',
ast.Or: '||'
}
"""Mapping of Python AST nodes to C++ symbols."""
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass @dataclass
...@@ -33,7 +60,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -33,7 +60,7 @@ class ExpressionVisitor(NodeVisitor):
yield str(node.value).lower() yield str(node.value).lower()
elif isinstance(node.value, int): elif isinstance(node.value, int):
# TODO: bigints # TODO: bigints
yield str(node.value) yield str(node.value) + "_pi"
elif isinstance(node.value, float): elif isinstance(node.value, float):
yield repr(node.value) yield repr(node.value)
elif isinstance(node.value, complex): elif isinstance(node.value, complex):
...@@ -44,7 +71,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -44,7 +71,7 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node)) raise NotImplementedError(node, type(node))
def visit_Slice(self, node: ast.Slice) -> Iterable[str]: def visit_Slice(self, node: ast.Slice) -> Iterable[str]:
yield "TySlice(" yield "typon::TySlice("
yield from join(", ", (self.visit(x or ast.Constant(value=None)) for x in (node.lower, node.upper, node.step))) yield from join(", ", (self.visit(x or ast.Constant(value=None)) for x in (node.lower, node.upper, node.step)))
yield ")" yield ")"
...@@ -79,22 +106,22 @@ class ExpressionVisitor(NodeVisitor): ...@@ -79,22 +106,22 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right)) # yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]: def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
raise NotImplementedError() if len(node.values) == 1:
yield from self.visit(node.values[0])
# if len(node.values) == 1: return
# yield from self.visit(node.values[0]) cpp_op = {
# return ast.And: "&&",
# cpp_op = { ast.Or: "||"
# ast.And: "&&", }[type(node.op)]
# ast.Or: "||" yield "("
# }[type(node.op)] yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
# with self.prec_ctx(cpp_op): for left, right in zip(node.values[1:], node.values[2:]):
# yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1])) yield f" {cpp_op} "
# for left, right in zip(node.values[1:], node.values[2:]): yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
# yield f" {cpp_op} " yield ")"
# yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
yield "co_await"
yield "(" yield "("
yield from self.visit(node.func) yield from self.visit(node.func)
yield ")(" yield ")("
...@@ -158,6 +185,8 @@ class ExpressionVisitor(NodeVisitor): ...@@ -158,6 +185,8 @@ class ExpressionVisitor(NodeVisitor):
templ, args, _ = self.process_args(node.args) templ, args, _ = self.process_args(node.args)
yield templ yield templ
yield args yield args
yield "->"
yield from self.visit(node.type.deref().return_type)
yield "{" yield "{"
yield "return" yield "return"
yield from self.reset().visit(node.body) yield from self.reset().visit(node.body)
...@@ -165,14 +194,20 @@ class ExpressionVisitor(NodeVisitor): ...@@ -165,14 +194,20 @@ class ExpressionVisitor(NodeVisitor):
yield "}" yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]: def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node)) yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node)) yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
yield "(co_await ("
yield from self.visit(left)
yield " "
yield SYMBOLS[type(op)]
yield " "
yield from self.visit(right)
yield "))"
return
raise NotImplementedError() raise NotImplementedError()
# if type(op) == ast.In: # if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd) # call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
...@@ -205,7 +240,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -205,7 +240,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_List(self, node: ast.List) -> Iterable[str]: def visit_List(self, node: ast.List) -> Iterable[str]:
if node.elts: if node.elts:
yield "typon::TyList{" yield "typon::TyList{"
yield from join(", ", map(self.reset().visit, node.elts)) yield from join(", ", map(self.visit, node.elts))
yield "}" yield "}"
else: else:
yield from self.visit(node.type) yield from self.visit(node.type)
...@@ -214,7 +249,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -214,7 +249,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_Set(self, node: ast.Set) -> Iterable[str]: def visit_Set(self, node: ast.Set) -> Iterable[str]:
if node.elts: if node.elts:
yield "typon::TySet{" yield "typon::TySet{"
yield from join(", ", map(self.reset().visit, node.elts)) yield from join(", ", map(self.visit, node.elts))
yield "}" yield "}"
else: else:
yield from self.visit(node.type) yield from self.visit(node.type)
...@@ -223,9 +258,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -223,9 +258,9 @@ class ExpressionVisitor(NodeVisitor):
def visit_Dict(self, node: ast.Dict) -> Iterable[str]: def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value): def visit_item(key, value):
yield "std::pair {" yield "std::pair {"
yield from self.reset().visit(key) yield from self.visit(key)
yield ", " yield ", "
yield from self.reset().visit(value) yield from self.visit(value)
yield "}" yield "}"
if node.keys: if node.keys:
...@@ -256,12 +291,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -256,12 +291,13 @@ class ExpressionVisitor(NodeVisitor):
yield from self.prec("unary").visit(operand) yield from self.prec("unary").visit(operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]: def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"): yield "("
yield from self.visit(node.test) yield from self.visit(node.test)
yield " ? " yield " ? "
yield from self.visit(node.body) yield from self.visit(node.body)
yield " : " yield " : "
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
yield ")"
def visit_Yield(self, node: ast.Yield) -> Iterable[str]: def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
#if CoroutineMode.GENERATOR in self.generator: #if CoroutineMode.GENERATOR in self.generator:
...@@ -279,8 +315,23 @@ class ExpressionVisitor(NodeVisitor): ...@@ -279,8 +315,23 @@ class ExpressionVisitor(NodeVisitor):
if len(node.generators) != 1: if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet") raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0] gen: ast.comprehension = node.generators[0]
yield "MAP_FILTER("
yield from self.visit(gen.target)
yield ","
yield from self.visit(gen.iter)
yield ", "
yield from self.visit(node.elt)
yield ", "
if gen.ifs:
yield from self.visit(gen.ifs_node)
else:
yield "true"
yield ")"
return
yield "mapFilter([](" yield "mapFilter([]("
yield from self.visit(node.input_item_type) #yield from self.visit(node.input_item_type)
yield "auto"
yield from self.visit(gen.target) yield from self.visit(gen.target)
yield ") { return " yield ") { return "
yield from self.visit(node.elt) yield from self.visit(node.elt)
...@@ -289,9 +340,12 @@ class ExpressionVisitor(NodeVisitor): ...@@ -289,9 +340,12 @@ class ExpressionVisitor(NodeVisitor):
if gen.ifs: if gen.ifs:
yield ", " yield ", "
yield "[](" yield "[]("
yield from self.visit(node.input_item_type) #yield from self.visit(node.input_item_type)
yield "auto"
yield from self.visit(gen.target) yield from self.visit(gen.target)
yield ") { return " yield ") -> typon::Task<"
yield from self.visit(gen.ifs_node.type)
yield "> { return "
yield from self.visit(gen.ifs_node) yield from self.visit(gen.ifs_node)
yield "; }" yield "; }"
yield ")" yield ")"
......
...@@ -3,6 +3,7 @@ from dataclasses import dataclass, field ...@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.phases.emit_cpp.expr import ExpressionVisitor from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.typing.common import IsDeclare
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode
from transpiler.phases.typing.types import CallableInstanceType, BaseType from transpiler.phases.typing.types import CallableInstanceType, BaseType
...@@ -10,7 +11,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType ...@@ -10,7 +11,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]: def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield f"struct : referencemodel::function {{" yield f"struct : referencemodel::function {{"
yield "typon::Task<void> operator()(" yield "typon::Task<typon::TyNone> operator()("
for arg, ty in zip(func.block_data.node.args.args, func.parameters): for arg, ty in zip(func.block_data.node.args.args, func.parameters):
yield "auto " yield "auto "
...@@ -151,32 +152,36 @@ class BlockVisitor(NodeVisitor): ...@@ -151,32 +152,36 @@ class BlockVisitor(NodeVisitor):
# #
# yield "}" # yield "}"
# #
# def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr, declare: IsDeclare) -> Iterable[str]:
# if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
# for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args): raise NotImplementedError()
# if decl: # for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
# yield from self.visit_lvalue(name, True) # if decl:
# yield ";" # yield from self.visit_lvalue(name, True)
# yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})" # yield ";"
# elif isinstance(lvalue, ast.Name): # yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
# if lvalue.id == "_": elif isinstance(lvalue, ast.Name):
# if not declare: if lvalue.id == "_":
# yield "std::ignore" if not declare:
# return yield "std::ignore"
# name = self.fix_name(lvalue.id) return
# # if name not in self._scope.vars: name = self.fix_name(lvalue.id)
# # if not self.scope.exists_local(name): # if name not in self._scope.vars:
# # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None, # if not self.scope.exists_local(name):
# # getattr(val, "is_future", False)) # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# if declare: # getattr(val, "is_future", False))
# yield from self.visit(lvalue.type) if declare:
# yield name yield "decltype("
# elif isinstance(lvalue, ast.Subscript): yield from self.expr().visit(declare.initial_value)
# yield from self.expr().visit(lvalue) yield ")"
# elif isinstance(lvalue, ast.Attribute): #yield from self.visit(lvalue.type)
# yield from self.expr().visit(lvalue) yield name
# else: elif isinstance(lvalue, ast.Subscript):
# raise NotImplementedError(lvalue) yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
...@@ -192,3 +197,27 @@ class BlockVisitor(NodeVisitor): ...@@ -192,3 +197,27 @@ class BlockVisitor(NodeVisitor):
yield " = " yield " = "
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
if node.orelse:
yield "auto"
yield node.orelse_variable
yield "= true;"
yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.inner_scope, node.body) # TODO: why not reuse the scope used for analysis? same in while
if node.orelse:
yield "if ("
yield node.orelse_variable
yield ")"
yield from self.emit_block(node.inner_scope, node.orelse)
def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{"
for child in items:
yield from BlockVisitor(scope, generator=self.generator).visit(child)
yield "}"
...@@ -15,6 +15,9 @@ class UniversalVisitor: ...@@ -15,6 +15,9 @@ class UniversalVisitor:
__TB__ = f"emitting C++ code for {highlight(node)}" __TB__ = f"emitting C++ code for {highlight(node)}"
# __TB_SKIP__ = True # __TB_SKIP__ = True
if orig := getattr(node, "orig_node", None):
node = orig
if type(node) == list: if type(node) == list:
for n in node: for n in node:
yield from self.visit(n) yield from self.visit(n)
...@@ -56,17 +59,28 @@ class NodeVisitor(UniversalVisitor): ...@@ -56,17 +59,28 @@ class NodeVisitor(UniversalVisitor):
match node: match node:
case types.TY_INT: case types.TY_INT:
yield "int" yield "decltype(0_pi)"
case types.TY_FLOAT: case types.TY_FLOAT:
yield "double" yield "double"
case types.TY_BOOL: case types.TY_BOOL:
yield "bool" yield "bool"
case types.TY_NONE: case types.TY_NONE:
yield "void" yield "typon::TyNone"
case types.TY_STR: case types.TY_STR:
yield "TyStr" yield "typon::TyStr"
case types.TypeVariable(name): case types.TypeVariable(name):
raise UnresolvedTypeVariableError(node) raise UnresolvedTypeVariableError(node)
case types.GenericInstanceType():
yield from self.visit(node.generic_parent)
yield "<"
yield from join(",", map(self.visit, node.generic_args))
yield ">"
case types.TY_LIST:
yield "typon::TyList"
case types.TY_DICT:
yield "typon::TyDict"
case types.TY_SET:
yield "typon::TySet"
case _: case _:
raise NotImplementedError(node) raise NotImplementedError(node)
......
...@@ -5,15 +5,11 @@ import importlib ...@@ -5,15 +5,11 @@ import importlib
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.exceptions import CompileError from transpiler.exceptions import CompileError
from transpiler.phases.typing.types import BaseType, TypeVariable
from transpiler.utils import highlight, linenodata from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl from transpiler.phases.typing.common import ScoperVisitor, is_builtin, DeclareInfo
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue, GenericUserType, MonomorphizedUserType
from transpiler.phases.utils import PlainBlock, AnnotationName from transpiler.phases.utils import PlainBlock, AnnotationName
...@@ -24,51 +20,51 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -24,51 +20,51 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_Pass(self, node: ast.Pass): def visit_Pass(self, node: ast.Pass):
pass pass
def get_module(self, name: str) -> VarDecl: # def get_module(self, name: str) -> VarDecl:
mod = self.scope.get(name, VarKind.MODULE) # mod = self.scope.get(name, VarKind.MODULE)
if mod is None: # if mod is None:
# try lookup with importlib # # try lookup with importlib
py_mod = importlib.import_module(name) # py_mod = importlib.import_module(name)
mod_scope = Scope() # mod_scope = Scope()
# copy all functions to mod_scope # # copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items(): # for fname, obj in py_mod.__dict__.items():
if callable(obj): # if callable(obj):
# fty = FunctionType([], TypeVariable()) # # fty = FunctionType([], TypeVariable())
# fty.is_python_func = True # # fty.is_python_func = True
fty = TypeVariable() # fty = TypeVariable()
fty.is_python_func = True # fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty) # mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope) # mod = make_mod_decl(name, mod_scope)
mod.type.is_python = True # mod.type.is_python = True
self.scope.vars[name] = mod # self.scope.vars[name] = mod
if mod is None: # if mod is None:
from transpiler.phases.typing.exceptions import UnknownNameError # from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(name) # raise UnknownNameError(name)
assert isinstance(mod, VarDecl), mod # assert isinstance(mod, VarDecl), mod
assert isinstance(mod.type, ModuleType), mod.type # assert isinstance(mod.type, ModuleType), mod.type
return mod # return mod
#
def visit_Import(self, node: ast.Import): # def visit_Import(self, node: ast.Import):
for alias in node.names: # for alias in node.names:
mod = self.get_module(alias.name) # mod = self.get_module(alias.name)
alias.module_obj = mod.type # alias.module_obj = mod.type
self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL) # self.scope.vars[alias.asname or alias.name] = dataclasses.replace(mod, kind=VarKind.LOCAL)
#
def visit_ImportFrom(self, node: ast.ImportFrom): # def visit_ImportFrom(self, node: ast.ImportFrom):
if node.module in {"typing2", "__future__"}: # if node.module in {"typing2", "__future__"}:
return # return
module = self.get_module(node.module) # module = self.get_module(node.module)
node.module_obj = module.type # node.module_obj = module.type
for alias in node.names: # for alias in node.names:
thing = module.val.get(alias.name) # thing = module.val.get(alias.name)
if not thing: # if not thing:
from transpiler.phases.typing.exceptions import UnknownModuleMemberError # from transpiler.phases.typing.exceptions import UnknownModuleMemberError
raise UnknownModuleMemberError(node.module, alias.name) # raise UnknownModuleMemberError(node.module, alias.name)
alias.item_obj = thing # alias.item_obj = thing
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing) # self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
#
def visit_Module(self, node: ast.Module): # def visit_Module(self, node: ast.Module):
self.visit_block(node.body) # self.visit_block(node.body)
def get_type(self, node: ast.expr) -> BaseType: def get_type(self, node: ast.expr) -> BaseType:
if type := getattr(node, "type", None): if type := getattr(node, "type", None):
...@@ -85,8 +81,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -85,8 +81,7 @@ class ScoperBlockVisitor(ScoperVisitor):
target = node.targets[0] target = node.targets[0]
ty = self.get_type(node.value) ty = self.get_type(node.value)
decl = self.visit_assign_target(target, ty) decl = self.visit_assign_target(target, ty)
if not hasattr(node, "is_declare"): node.is_declare = DeclareInfo(node.value) if decl else None
node.is_declare = decl
def visit_AnnAssign(self, node: ast.AnnAssign): def visit_AnnAssign(self, node: ast.AnnAssign):
if node.simple != 1: if node.simple != 1:
...@@ -95,8 +90,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -95,8 +90,7 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node) raise NotImplementedError(node)
ty = self.visit_annotation(node.annotation) ty = self.visit_annotation(node.annotation)
decl = self.visit_assign_target(node.target, ty) decl = self.visit_assign_target(node.target, ty)
if not hasattr(node, "is_declare"): node.is_declare = DeclareInfo(node.value) if decl else None
node.is_declare = decl
if node.value is not None: if node.value is not None:
ty_val = self.get_type(node.value) ty_val = self.get_type(node.value)
__TB__ = f"unifying annotation {highlight(node.annotation)} with value {highlight(node.value)} of type {highlight(ty_val)}" __TB__ = f"unifying annotation {highlight(node.annotation)} with value {highlight(node.value)} of type {highlight(ty_val)}"
...@@ -144,146 +138,6 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -144,146 +138,6 @@ class ScoperBlockVisitor(ScoperVisitor):
else: else:
raise NotImplementedError(ast.unparse(target)) raise NotImplementedError(ast.unparse(target))
def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node)
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
def process_class_ast(self, ctype: BaseType, node: ast.ClassDef, bases_after: list[ast.expr]):
scope = self.scope.child(ScopeKind.CLASS)
scope.obj_type = ctype
scope.class_ = scope
node.inner_scope = scope
node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=TypeType(ctype))
visitor.visit_block(node.body)
for base in bases_after:
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT)
for k, m in ctype.fields.items():
m.type = ctype
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), ast.arg(arg="value")],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")],
value=ast.Name(id="value"),
**lnd
)
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
ctype.is_enum = True
else:
raise NotImplementedError(base)
for deco in node.decorator_list:
deco = self.expr().visit(deco)
if is_builtin(deco, "dataclass"):
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
# cttype.methods["__init__"] = init_type
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in ctype.get_members()
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
else:
raise NotImplementedError(deco)
return ctype
def visit_ClassDef(self, node: ast.ClassDef):
copied = copy.deepcopy(node)
class NewUserType(UserType):
def __init__(self):
super().__init__(node.name)
#ctype = UserType(node.name)
typevars = []
bases_after = []
for base in node.bases:
if isinstance(base, ast.Subscript):
if isinstance(base.slice, ast.Name):
sliceval = [base.slice.id]
elif isinstance(base.slice, ast.Tuple):
sliceval = [n.id for n in base.slice.elts]
if is_builtin(self.expr().visit(base.value), "Generic"):
typevars = sliceval
else:
bases_after.append(base)
if typevars:
# generic
#ctype = GenericUserType(node.name, typevars, node)
var_scope = self.scope.child(ScopeKind.GLOBAL)
var_visitor = ScoperBlockVisitor(var_scope, self.root_decls)
node.gen_instances = {}
class OurGenericType(GenericUserType):
# def __init__(self, *args):
# super().__init__(node.name)
# for tv, arg in zip(typevars, args):
# var_scope.declare_local(tv, arg)
# var_visitor.process_class_ast(self, node, bases_after)
def __new__(cls, *args, **kwargs):
res = MonomorphizedUserType(node.name + "$$" + "__".join(map(str, args)) + "$$")
for tv, arg in zip(typevars, args):
var_scope.declare_local(tv, arg)
new_node = copy.deepcopy(copied)
new_node.name = res.name
var_visitor.process_class_ast(res, new_node, bases_after)
node.gen_instances[tuple(args)] = new_node
return res
ctype = OurGenericType
else:
# not generic
ctype = self.process_class_ast(UserType(node.name), node, bases_after)
cttype = TypeType(ctype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, cttype)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope node.inner_scope = scope
...@@ -327,8 +181,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -327,8 +181,8 @@ class ScoperBlockVisitor(ScoperVisitor):
var_var = TypeVariable() var_var = TypeVariable()
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var) scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var)
seq_type = self.expr().visit(node.iter) seq_type = self.expr().visit(node.iter)
iter_type = get_iter(seq_type) iter_type = self.get_iter(seq_type)
next_type = get_next(iter_type) next_type = self.get_next(iter_type)
var_var.unify(next_type) var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
......
# coding: utf-8 # coding: utf-8
import ast # import ast
from dataclasses import dataclass, field # from dataclasses import dataclass, field
#
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE # from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor # from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef # from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef
#
#
@dataclass # @dataclass
class ScoperClassVisitor(ScoperVisitor): # class ScoperClassVisitor(ScoperVisitor):
fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list) # fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list)
#
def visit_AnnAssign(self, node: ast.AnnAssign): # def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value" # assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)" # assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name) # assert isinstance(node.target, ast.Name)
self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation)) # self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation))
#
def visit_Assign(self, node: ast.Assign): # def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Can't use destructuring in class static member" # assert len(node.targets) == 1, "Can't use destructuring in class static member"
assert isinstance(node.targets[0], ast.Name) # assert isinstance(node.targets[0], ast.Name)
node.is_declare = True # node.is_declare = True
valtype = self.expr().visit(node.value) # valtype = self.expr().visit(node.value)
node.targets[0].type = valtype # node.targets[0].type = valtype
self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value) # self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value)
#
def visit_FunctionDef(self, node: ast.FunctionDef): # def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node) # ftype = self.parse_function(node)
ftype.parameters[0].unify(self.scope.obj_type) # ftype.parameters[0].unify(self.scope.obj_type)
inner = ftype.return_type # inner = ftype.return_type
if node.name != "__init__": # if node.name != "__init__":
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK) # ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True # ftype.is_method = True
self.scope.obj_type.fields[node.name] = MemberDef(ftype, node) # self.scope.obj_type.fields[node.name] = MemberDef(ftype, node)
return (node, inner) # return (node, inner)
import ast import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.utils import highlight from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
...@@ -10,6 +12,7 @@ from transpiler.phases.utils import NodeVisitorSeq, AnnotationName ...@@ -10,6 +12,7 @@ from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
@dataclass @dataclass
class ScoperVisitor(NodeVisitorSeq): class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL)) scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
...@@ -93,10 +96,11 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -93,10 +96,11 @@ class ScoperVisitor(NodeVisitorSeq):
elif len(visitor.fdecls) == 1: elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0] fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype) self.visit_function_definition(fnode, frtype)
#del node.inner_scope.vars[fnode.name] # del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type) visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls b.decls = decls
if not node.inner_scope.diverges and not (isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR): if not node.inner_scope.diverges and not (
isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR):
from transpiler.phases.typing.exceptions import TypeMismatchError from transpiler.phases.typing.exceptions import TypeMismatchError
try: try:
rtype.unify(TY_NONE) rtype.unify(TY_NONE)
...@@ -104,21 +108,28 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -104,21 +108,28 @@ class ScoperVisitor(NodeVisitorSeq):
from transpiler.phases.typing.exceptions import MissingReturnError from transpiler.phases.typing.exceptions import MissingReturnError
raise MissingReturnError(node) from e raise MissingReturnError(node) from e
def get_iter(seq_type): def get_iter(self, seq_type):
try: try:
iter_type = seq_type.fields["__iter__"].type.return_type return self.expr().visit_function_call(self.expr().visit_getattr(seq_type, "__iter__"), [])
except: except:
from transpiler.phases.typing.exceptions import NotIterableError from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type) raise NotIterableError(seq_type)
return iter_type
def get_next(self, iter_type):
def get_next(iter_type): try:
try: return self.expr().visit_function_call(self.expr().visit_getattr(iter_type, "__next__"), [])
next_type = iter_type.fields["__next__"].type.return_type except:
except: from transpiler.phases.typing.exceptions import NotIterableError
from transpiler.phases.typing.exceptions import NotIteratorError raise NotIterableError(iter_type)
raise NotIteratorError(iter_type)
return next_type
def is_builtin(x, feature): def is_builtin(x, feature):
return isinstance(x, BuiltinFeatureType) and x.feature() == feature return isinstance(x, BuiltinFeatureType) and x.feature() == feature
\ No newline at end of file
@dataclass
class DeclareInfo:
initial_value: Optional[ast.expr] = None
IsDeclare = None | DeclareInfo
...@@ -4,7 +4,7 @@ import inspect ...@@ -4,7 +4,7 @@ import inspect
from itertools import zip_longest from itertools import zip_longest
from typing import List from typing import List
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin from transpiler.phases.typing.common import ScoperVisitor, is_builtin
from transpiler.phases.typing.exceptions import ArgumentCountMismatchError, TypeMismatchKind, TypeMismatchError from transpiler.phases.typing.exceptions import ArgumentCountMismatchError, TypeMismatchKind, TypeMismatchError
from transpiler.phases.typing.types import BaseType, TY_STR, TY_BOOL, TY_INT, TY_COMPLEX, TY_FLOAT, TY_NONE, \ from transpiler.phases.typing.types import BaseType, TY_STR, TY_BOOL, TY_INT, TY_COMPLEX, TY_FLOAT, TY_NONE, \
ClassTypeType, ResolvedConcreteType, GenericType, CallableInstanceType, TY_LIST, TY_SET, TY_DICT, RuntimeValue, \ ClassTypeType, ResolvedConcreteType, GenericType, CallableInstanceType, TY_LIST, TY_SET, TY_DICT, RuntimeValue, \
...@@ -141,7 +141,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -141,7 +141,12 @@ class ScoperExprVisitor(ScoperVisitor):
if not a.try_assign(b): if not a.try_assign(b):
raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE)
return ftype.return_type if not ftype.is_native:
from transpiler.phases.typing.block import ScoperBlockVisitor
vis = ScoperBlockVisitor(ftype.block_data.scope)
for stmt in ftype.block_data.node.body:
vis.visit(stmt)
return ftype.return_type.resolve()
# if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType): # if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
# init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self() # init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self()
# init.return_type = ftype.type_object # init.return_type = ftype.type_object
...@@ -176,6 +181,10 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -176,6 +181,10 @@ class ScoperExprVisitor(ScoperVisitor):
node.body.decls = decls node.body.decls = decls
return ftype return ftype
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right))
return TypeVariable() # TODO
# def visit_BinOp(self, node: ast.BinOp) -> BaseType: # def visit_BinOp(self, node: ast.BinOp) -> BaseType:
# left, right = map(self.visit, (node.left, node.right)) # left, right = map(self.visit, (node.left, node.right))
# return self.make_dunder([left, right], DUNDER[type(node.op)]) # return self.make_dunder([left, right], DUNDER[type(node.op)])
...@@ -300,14 +309,15 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -300,14 +309,15 @@ class ScoperExprVisitor(ScoperVisitor):
if len(node.generators) != 1: if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet") raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0] gen: ast.comprehension = node.generators[0]
iter_type = get_iter(self.visit(gen.iter)) iter_type = self.get_iter(self.visit(gen.iter))
node.input_item_type = get_next(iter_type) node.input_item_type = self.get_next(iter_type)
virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER) virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
from transpiler import ScoperBlockVisitor from transpiler.phases.typing.block import ScoperBlockVisitor
visitor = ScoperBlockVisitor(virt_scope) visitor = ScoperBlockVisitor(virt_scope)
visitor.visit_assign_target(gen.target, node.input_item_type) visitor.visit_assign_target(gen.target, node.input_item_type)
node.item_type = visitor.expr().visit(node.elt) node.item_type = visitor.expr().visit(node.elt)
for if_ in gen.ifs: # for if_ in gen.ifs:
visitor.expr().visit(if_) # visitor.expr().visit(if_)
gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node)) gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node))
return TyList(node.item_type) visitor.expr().visit(gen.ifs_node)
\ No newline at end of file return TY_LIST.instantiate([node.item_type])
\ No newline at end of file
...@@ -4,7 +4,7 @@ from logging import debug ...@@ -4,7 +4,7 @@ from logging import debug
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin, BlockData
class ModuleType(UniqueTypeMixin, ResolvedConcreteType): class ModuleType(UniqueTypeMixin, ResolvedConcreteType):
...@@ -21,39 +21,44 @@ def make_module(name: str, scope: Scope) -> ModuleType: ...@@ -21,39 +21,44 @@ def make_module(name: str, scope: Scope) -> ModuleType:
visited_modules = {} visited_modules = {}
def parse_module(mod_name: str, python_path: Path, scope=None, preprocess=None): def parse_module(mod_name: str, python_path: list[Path], scope=None, preprocess=None) -> ModuleType:
path = python_path / mod_name for path in python_path:
path = path / mod_name
if not path.exists(): if not path.exists():
path = path.with_suffix(".py") path = path.with_suffix(".py")
if not path.exists(): if not path.exists():
path = path.with_stem(mod_name + "_") path = path.with_stem(mod_name + "_")
if not path.exists(): if not path.exists():
raise FileNotFoundError(f"Could not find {path}") continue
if path.is_dir(): break
real_path = path / "__init__.py"
else: else:
real_path = path raise FileNotFoundError(f"Could not find {mod_name}")
if path.is_dir():
path = path / "__init__.py"
if mod := visited_modules.get(real_path.as_posix()): if mod := visited_modules.get(path.as_posix()):
return mod return mod.type
mod_scope = scope or PRELUDE.child(ScopeKind.GLOBAL) mod_scope = scope or PRELUDE.child(ScopeKind.GLOBAL)
if real_path.suffix == ".py": if path.suffix != ".py":
from transpiler.phases.typing.stdlib import StdlibVisitor
node = ast.parse(real_path.read_text())
if preprocess:
node = preprocess(node)
StdlibVisitor(python_path, mod_scope).visit(node)
else:
raise NotImplementedError(f"Unsupported file type {path.suffix}") raise NotImplementedError(f"Unsupported file type {path.suffix}")
from transpiler.phases.typing.stdlib import StdlibVisitor
node = ast.parse(path.read_text())
if preprocess:
node = preprocess(node)
from transpiler.transpiler import TYPON_STD
StdlibVisitor(python_path, mod_scope, is_native=TYPON_STD in path.parents).visit(node)
mod = make_module(mod_name, mod_scope) mod = make_module(mod_name, mod_scope)
visited_modules[real_path.as_posix()] = VarDecl(VarKind.LOCAL, mod, {k: v.type for k, v in mod_scope.vars.items()}) mod.block_data = BlockData(node, mod_scope)
visited_modules[path.as_posix()] = VarDecl(VarKind.LOCAL, mod)
return mod return mod
# def process_module(mod_path: Path, scope): # def process_module(mod_path: Path, scope):
......
...@@ -75,9 +75,10 @@ def visit_generic_item( ...@@ -75,9 +75,10 @@ def visit_generic_item(
@dataclass @dataclass
class StdlibVisitor(NodeVisitorSeq): class StdlibVisitor(NodeVisitorSeq):
python_path: Path python_path: list[Path]
scope: Scope = field(default_factory=lambda: PRELUDE) scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[ResolvedConcreteType] = None cur_class: Optional[ResolvedConcreteType] = None
is_native: bool = False
def resolve_module_import(self, name: str): def resolve_module_import(self, name: str):
# tries = [ # tries = [
...@@ -88,7 +89,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -88,7 +89,7 @@ class StdlibVisitor(NodeVisitorSeq):
# if path.exists(): # if path.exists():
# return path # return path
# raise FileNotFoundError(f"Could not find module {name}") # raise FileNotFoundError(f"Could not find module {name}")
return parse_module(name, self.python_path) return parse_module(name, self.python_path, self.scope.child(ScopeKind.GLOBAL))
def expr(self) -> ScoperExprVisitor: def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope) return ScoperExprVisitor(self.scope)
...@@ -122,7 +123,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -122,7 +123,7 @@ class StdlibVisitor(NodeVisitorSeq):
for alias in node.names: for alias in node.names:
mod = self.resolve_module_import(alias.name) mod = self.resolve_module_import(alias.name)
alias.module_obj = mod alias.module_obj = mod
self.scope.vars[alias.asname or alias.name] = mod self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, mod)
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
if existing := self.scope.get(node.name): if existing := self.scope.get(node.name):
...@@ -137,7 +138,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -137,7 +138,7 @@ class StdlibVisitor(NodeVisitorSeq):
cl_scope = scope.child(ScopeKind.CLASS) cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type()) cl_scope.declare_local("Self", output.type_type())
output.block_data = BlockData(node, scope) output.block_data = BlockData(node, scope)
visitor = StdlibVisitor(self.python_path, cl_scope, output) visitor = StdlibVisitor(self.python_path, cl_scope, output, self.is_native)
bases = [self.anno().visit(base) for base in node.bases] bases = [self.anno().visit(base) for base in node.bases]
match bases: match bases:
case []: case []:
...@@ -163,6 +164,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -163,6 +164,7 @@ class StdlibVisitor(NodeVisitorSeq):
output.return_type = arg_visitor.visit(node.returns) output.return_type = arg_visitor.visit(node.returns)
output.optional_at = len(node.args.args) - len(node.args.defaults) output.optional_at = len(node.args.args) - len(node.args.defaults)
output.is_variadic = args.vararg is not None output.is_variadic = args.vararg is not None
output.is_native = self.is_native
@dataclass(eq=False, init=False) @dataclass(eq=False, init=False)
class InstanceType(CallableInstanceType): class InstanceType(CallableInstanceType):
......
...@@ -78,7 +78,10 @@ class BaseType(ABC): ...@@ -78,7 +78,10 @@ class BaseType(ABC):
return (needle is haystack) or haystack.contains_internal(needle) return (needle is haystack) or haystack.contains_internal(needle)
def try_assign(self, other: "BaseType") -> bool: def try_assign(self, other: "BaseType") -> bool:
return self.resolve().try_assign_internal(other.resolve()) target, value = self.resolve(), other.resolve()
if type(value) == TypeVariable:
return BaseType.try_assign_internal(target, other)
return target.try_assign_internal(other)
def try_assign_internal(self, other: "BaseType") -> bool: def try_assign_internal(self, other: "BaseType") -> bool:
...@@ -489,6 +492,7 @@ class CallableInstanceType(GenericInstanceType, MethodType): ...@@ -489,6 +492,7 @@ class CallableInstanceType(GenericInstanceType, MethodType):
return_type: ConcreteType return_type: ConcreteType
optional_at: int = None optional_at: int = None
is_variadic: bool = False is_variadic: bool = False
is_native: bool = False
def __post_init__(self): def __post_init__(self):
if self.optional_at is None and self.parameters is not None: if self.optional_at is None and self.parameters is not None:
...@@ -560,6 +564,8 @@ def make_builtin_feature(name: str): ...@@ -560,6 +564,8 @@ def make_builtin_feature(name: str):
return TY_OPTIONAL return TY_OPTIONAL
case "Union": case "Union":
return TY_UNION return TY_UNION
case "Callable":
return TY_CALLABLE
case _: case _:
class CreatedType(BuiltinFeatureType): class CreatedType(BuiltinFeatureType):
def name(self): def name(self):
......
# coding: utf-8 # coding: utf-8
import ast
from pathlib import Path from pathlib import Path
import colorama import colorama
...@@ -14,14 +15,16 @@ from transpiler.phases.emit_cpp.module import emit_module ...@@ -14,14 +15,16 @@ from transpiler.phases.emit_cpp.module import emit_module
from transpiler.phases.if_main import IfMainVisitor from transpiler.phases.if_main import IfMainVisitor
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.modules import parse_module from transpiler.phases.typing.modules import parse_module
from transpiler.phases.typing.stdlib import StdlibVisitor
TYPON_STD = Path(__file__).parent.parent / "stdlib"
def init(): def init():
error_display.init() error_display.init()
colorama.init() colorama.init()
typon_std = Path(__file__).parent.parent / "stdlib"
#discover_module(typon_std, PRELUDE.child(ScopeKind.GLOBAL)) #discover_module(typon_std, PRELUDE.child(ScopeKind.GLOBAL))
parse_module("builtins", typon_std, PRELUDE) parse_module("builtins", [TYPON_STD], PRELUDE)
...@@ -35,7 +38,7 @@ def transpile(source, name: str, path: Path): ...@@ -35,7 +38,7 @@ def transpile(source, name: str, path: Path):
node = DesugarOp().visit(node) node = DesugarOp().visit(node)
return node return node
module = parse_module(path.stem, path.parent, preprocess=preprocess) module = parse_module(path.stem, [path.parent, TYPON_STD], preprocess=preprocess)
def disp_scope(scope, indent=0): def disp_scope(scope, indent=0):
debug(" " * indent, scope.kind) debug(" " * indent, scope.kind)
...@@ -44,6 +47,8 @@ def transpile(source, name: str, path: Path): ...@@ -44,6 +47,8 @@ def transpile(source, name: str, path: Path):
for var in scope.vars.items(): for var in scope.vars.items():
debug(" " * (indent + 1), var) debug(" " * (indent + 1), var)
StdlibVisitor([], module.block_data.scope).expr().visit(ast.parse("main()", mode="eval").body)
def main_module(): def main_module():
yield from emit_module(module) yield from emit_module(module)
yield "#ifdef TYPON_EXTENSION" yield "#ifdef TYPON_EXTENSION"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment