Commit b738512a authored by Tom Niget's avatar Tom Niget

Builtin things work

parent 85f54768
.idea .idea
cmake-build-* cmake-build-*
typon.egg-info typon.egg-info
build
\ No newline at end of file
...@@ -20,6 +20,8 @@ public: ...@@ -20,6 +20,8 @@ public:
return sync_wrapper(std::forward<Args>(args)...); return sync_wrapper(std::forward<Args>(args)...);
} }
}; };
/* /*
struct method {}; struct method {};
......
...@@ -59,6 +59,8 @@ template <PySmartPtr T> struct RealType<T> { ...@@ -59,6 +59,8 @@ template <PySmartPtr T> struct RealType<T> {
using type = typename T::element_type; using type = typename T::element_type;
}; };
namespace typon {
//template <typename T> using TyObj = std::shared_ptr<typename RealType<T>::type>; //template <typename T> using TyObj = std::shared_ptr<typename RealType<T>::type>;
template<typename T> template<typename T>
...@@ -135,6 +137,11 @@ template <typename T, typename... Args> auto pyobj_agg(Args &&...args) -> TyObj< ...@@ -135,6 +137,11 @@ template <typename T, typename... Args> auto pyobj_agg(Args &&...args) -> TyObj<
return std::make_shared<typename RealType<T>::type>((typename RealType<T>::type) { std::forward<Args>(args)... }); return std::make_shared<typename RealType<T>::type>((typename RealType<T>::type) { std::forward<Args>(args)... });
} }
class TyNone {
};
}
// typon_len // typon_len
template <typename T> template <typename T>
...@@ -199,6 +206,7 @@ static constexpr auto PyNone = std::nullopt; ...@@ -199,6 +206,7 @@ static constexpr auto PyNone = std::nullopt;
#include "builtins/bool.hpp" #include "builtins/bool.hpp"
#include "builtins/complex.hpp" #include "builtins/complex.hpp"
#include "builtins/dict.hpp" #include "builtins/dict.hpp"
#include "builtins/int.hpp"
#include "builtins/list.hpp" #include "builtins/list.hpp"
#include "builtins/print.hpp" #include "builtins/print.hpp"
#include "builtins/range.hpp" #include "builtins/range.hpp"
...@@ -307,7 +315,7 @@ typon::Task<typon::PyFile> open(const TyStr &path, std::string_view mode) { ...@@ -307,7 +315,7 @@ typon::Task<typon::PyFile> open(const TyStr &path, std::string_view mode) {
std::cerr << path << "," << flags << std::endl; std::cerr << path << "," << flags << std::endl;
system_error(-fd, "openat()"); system_error(-fd, "openat()");
} }
co_return tyObj<typon::PyFile>(fd, len); co_return typon::tyObj<typon::PyFile>(fd, len);
} }
#include <typon/generator.hpp> #include <typon/generator.hpp>
...@@ -349,9 +357,9 @@ template<PySmartPtr T> ...@@ -349,9 +357,9 @@ template<PySmartPtr T>
auto& iter_fix_ref(T& obj) { return *obj; } auto& iter_fix_ref(T& obj) { return *obj; }
namespace std { namespace std {
template <class T> auto begin(std::shared_ptr<T> &obj) { return dotp(obj, begin)(); } template <class T> auto begin(std::shared_ptr<T> &obj) { return dot(obj, begin)(); }
template <class T> auto end(std::shared_ptr<T> &obj) { return dotp(obj, end)(); } template <class T> auto end(std::shared_ptr<T> &obj) { return dot(obj, end)(); }
} }
template <typename T> template <typename T>
...@@ -365,18 +373,33 @@ struct ValueTypeEx { ...@@ -365,18 +373,33 @@ struct ValueTypeEx {
using type = decltype(*std::begin(std::declval<Seq&>())); using type = decltype(*std::begin(std::declval<Seq&>()));
}; };
// (2) // (2)
template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>> /*template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>>
auto mapFilter(Map map, Seq seq, Filt filt = Filt()) { typon::Task< mapFilter(Map map, Seq seq, Filt filt = Filt()) {
//typedef typename Seq::value_type value_type; //typedef typename Seq::value_type value_type;
using value_type = typename ValueTypeEx<Seq>::type; using value_type = typename ValueTypeEx<Seq>::type;
using return_type = decltype(map(std::declval<value_type>())); using return_type = decltype(map(std::declval<value_type>()));
std::vector<return_type> result{}; std::vector<return_type> result{};
for (auto i : seq | std::views::filter(filt) for (auto i : seq) {
| std::views::transform(map)) result.push_back(i); if (co_await filt(i)) {
result.push_back(co_await map(i));
}
}
return typon::TyList(std::move(result)); return typon::TyList(std::move(result));
} }*/
#define MAP_FILTER(item, seq, map, filter) ({\
using value_type = typename ValueTypeEx<decltype(seq)>::type;\
value_type item;\
std::vector<decltype(map)> result{};\
for (auto item : seq) {\
if (filter) {\
result.push_back(map);\
}\
}\
typon::TyList(std::move(result));\
})
namespace PYBIND11_NAMESPACE { namespace PYBIND11_NAMESPACE {
namespace detail { namespace detail {
......
//
// Created by Tom on 08/03/2023.
//
#ifndef TYPON_INT_HPP
#define TYPON_INT_HPP
#include <sstream>
#include <string>
#include <algorithm>
using namespace std::literals;
#include "bytes.hpp"
#include "print.hpp"
#include "slice.hpp"
// #include <format>
#include <fmt/format.h>
#include <pybind11/cast.h>
namespace typon {
/*template <typename _Base0 = object>
class TyInt__oo : classtype<_Base0, Integer__oo<>> {
public:
struct : method {
auto operator()(auto self, int value) const {
self->value = value;
}
} static constexpr oo__init__oo {};
struct : method {
auto operator()(auto self, auto other) const {
return Integer(dot(self, value) + dot(other, value));
}
} static constexpr oo__add__oo {};
auto operator () (int value) const {
struct Obj : instance<Integer__oo<>, Obj> {
int value;
};
auto obj = rc(Obj{});
dot(obj, oo__init__oo)(value);
return obj;
}
constexpr TyInt(int value) : value(value) {}
constexpr TyInt() : value(0) {}
operator int() const { return value; }
// operators
template <typename T> TyInt operator+(T x) const { return value + x; }
template <typename T> TyInt operator-(T x) const { return value - x; }
template <typename T> TyInt operator*(T x) const { return value * x; }
template <typename T> TyInt operator/(T x) const { return value / x; }
template <typename T> TyInt operator%(T x) const { return value % x; }
template <typename T> TyInt operator&(T x) const { return value & x; }
template <typename T> TyInt operator|(T x) const { return value | x; }
template <typename T> TyInt operator^(T x) const { return value ^ x; }
template <typename T> TyInt operator<<(T x) const { return value << x; }
template <typename T> TyInt operator>>(T x) const { return value >> x; }
template <typename T> TyInt operator&&(T x) const { return value && x; }
template <typename T> TyInt operator||(T x) const { return value || x; }
template <typename T> TyInt operator==(T x) const { return value == x; }
template <typename T> TyInt operator<(T x) const { return value < x; }
TyInt operator-() const { return -value; }
private:
int value;
};*/
using namespace referencemodel;
template <typename _Base0 = object>
struct Integer__oo : classtype<_Base0, Integer__oo<>> {
static constexpr std::string_view name = "Integer";
struct : method {
auto operator()(auto self, int value) const {
self->value = value;
}
} static constexpr oo__init__oo {};
struct : method {
auto operator()(auto self, auto other) const {
return TyInt(dot(self, value) + dot(other, value));
}
} static constexpr oo__add__oo {};
struct Obj : value<Integer__oo<>, Obj> {
int value;
Obj(int value=0) : value(value) {}
operator int() const { return value; }
};
auto operator () (int value) const {
auto obj = rc(Obj{});
dot(obj, oo__init__oo)(value);
return obj;
}
};
static constexpr Integer__oo<> TyInt {};
}
inline auto operator ""_pi(unsigned long long int v) noexcept {
return typon::TyInt(v);
}
template <> struct std::hash<decltype(0_pi)> {
std::size_t operator()(const decltype(0_pi) &s) const noexcept {
return std::hash<int>()(s);
}
};
namespace PYBIND11_NAMESPACE {
namespace detail {
template<>
struct type_caster<decltype(0_pi)>
: type_caster<int> {};
}}
template <> void repr_to(const decltype(0_pi) &x, std::ostream &s) {
s << x;
}
template <> void print_to<decltype(0_pi)>(const decltype(0_pi) &x, std::ostream &s) { s << x; }
#endif // TYPON_INT_HPP
...@@ -86,13 +86,14 @@ typon::Task<void> print(T const &head, Args const &...args) { ...@@ -86,13 +86,14 @@ typon::Task<void> print(T const &head, Args const &...args) {
}*/ }*/
struct { struct {
void operator()() { std::cout << '\n'; } typon::TyNone operator()() { std::cout << '\n'; return {}; }
template <Printable T, Printable... Args> template <Printable T, Printable... Args>
void operator()(T const &head, Args const &...args) { typon::TyNone operator()(T const &head, Args const &...args) {
print_to(head, std::cout); print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...); (((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n'; std::cout << '\n';
return {};
} }
} print; } print;
// typon::Task<void> print() { std::cout << '\n'; co_return; } // typon::Task<void> print() { std::cout << '\n'; co_return; }
......
...@@ -18,11 +18,7 @@ auto stride = [](int n) { ...@@ -18,11 +18,7 @@ auto stride = [](int n) {
// todo: proper range support // todo: proper range support
struct range_s : TyBuiltin<range_s> struct range_s : TyBuiltin<range_s>
{ {
template <typename T> auto sync(int start, int stop, int step = 1) {
auto sync(T stop) { return sync(0, stop); }
template <typename T>
auto sync(T start, T stop, T step = 1) {
// https://www.modernescpp.com/index.php/c-20-pythons-map-function/ // https://www.modernescpp.com/index.php/c-20-pythons-map-function/
if(step == 0) { if(step == 0) {
throw std::invalid_argument("Step cannot be 0"); throw std::invalid_argument("Step cannot be 0");
...@@ -36,6 +32,10 @@ struct range_s : TyBuiltin<range_s> ...@@ -36,6 +32,10 @@ struct range_s : TyBuiltin<range_s>
return start < stop ? i : stop - (i - start); return start < stop ? i : stop - (i - start);
}); });
} }
auto sync(int stop) { return sync(0, stop); }
} range; } range;
#endif // TYPON_RANGE_HPP #endif // TYPON_RANGE_HPP
...@@ -23,6 +23,9 @@ struct object; ...@@ -23,6 +23,9 @@ struct object;
template <typename T, typename O> template <typename T, typename O>
struct instance; struct instance;
template <typename T, typename O>
struct value;
template <typename B, typename T> template <typename B, typename T>
struct classtype; struct classtype;
...@@ -168,6 +171,10 @@ using unwrap_pack = typename unwrap_pack_s<std::remove_cvref_t<T>>::type; ...@@ -168,6 +171,10 @@ using unwrap_pack = typename unwrap_pack_s<std::remove_cvref_t<T>>::type;
/* Meta-programming utilities for object model */ /* Meta-programming utilities for object model */
template <typename T>
concept object = std::derived_from<unwrap_all<T>, referencemodel::object>;
template <typename T> template <typename T>
concept instance = std::derived_from< concept instance = std::derived_from<
unwrap_all<T>, unwrap_all<T>,
...@@ -215,7 +222,26 @@ concept boundmethod = boundmethod_s<std::remove_cvref_t<T>>::value; ...@@ -215,7 +222,26 @@ concept boundmethod = boundmethod_s<std::remove_cvref_t<T>>::value;
template <typename T> template <typename T>
concept value = !instance<T> && !boundmethod<T>; struct value_s {
static constexpr bool value = true;
};
template <instance T>
struct value_s<T> {
static constexpr bool value = std::derived_from<
unwrap_all<T>,
referencemodel::value<typename unwrap_all<T>::type, unwrap_all<T>>
>;
};
template <typename S, typename F>
struct value_s<referencemodel::boundmethod<S, F>> {
static constexpr bool value = value_s<S>::value;
};
template <typename T>
concept value = value_s<std::remove_cvref_t<T>>::value;
/* Meta-programming utilities: wrapped and unwrapped */ /* Meta-programming utilities: wrapped and unwrapped */
...@@ -963,6 +989,10 @@ struct instance : T { ...@@ -963,6 +989,10 @@ struct instance : T {
}; };
template <typename T, typename O>
struct value : instance<T, O> {};
template <typename B, typename T> template <typename B, typename T>
struct classtype : B { struct classtype : B {
using base = B; using base = B;
...@@ -1193,14 +1223,58 @@ decltype(auto) bind(S &&, const A & attr, T...) { ...@@ -1193,14 +1223,58 @@ decltype(auto) bind(S &&, const A & attr, T...) {
return attr; return attr;
} }
} // namespace referencemodel
#define dot(OBJ, NAME)\ #define dot(OBJ, NAME)\
[](auto && obj) -> decltype(auto) {\ [](auto && obj) -> decltype(auto) {\
return referencemodel::bind(std::forward<decltype(obj)>(obj), obj->NAME);\ return referencemodel::bind(std::forward<decltype(obj)>(obj), obj->NAME);\
}(OBJ) }(OBJ)
/* Operators */
namespace meta {
/* + */
template <typename Left, typename Right>
concept LeftAddable = requires (Left left, Right right) {
// note: using dot here would cause hard failure instead of invalid constraint
left->oo__add__oo(left, right);
};
template <typename Left, typename Right>
concept RightAddable = requires (Left left, Right right) {
// note: using dot here would cause hard failure instead of invalid constraint
right->oo__radd__oo(right, left);
};
template <typename Left, typename Right>
concept Addable = LeftAddable<Left, Right> || RightAddable<Left, Right>;
}
/* + */
template <meta::object Left, meta::object Right>
requires meta::Addable<Left, Right>
auto operator + (Left && left, Right && right) {
if constexpr (meta::LeftAddable<Left, Right>) {
return dot(std::forward<Left>(left), oo__add__oo)(
std::forward<Right>(right));
}
else {
if constexpr (meta::RightAddable<Left, Right>) {
return dot(std::forward<Right>(right), oo__radd__oo)(
std::forward<Left>(left));
}
}
}
} // namespace referencemodel
#endif // REFERENCEMODEL_H #endif // REFERENCEMODEL_H
Subproject commit 79677d125f915f7c61492d8d1d8cde9fc6a11875 Subproject commit 26c320d77d368d0f482684ebd653d1702c75a7f2
...@@ -3,8 +3,8 @@ from typing import Self, Protocol, Optional ...@@ -3,8 +3,8 @@ from typing import Self, Protocol, Optional
assert 5 assert 5
class object: class object:
def __eq__(self, other: Self) -> bool: ... def __eq__[T](self, other: T) -> bool: ...
def __ne__(self, other: Self) -> bool: ... def __ne__[T](self, other: T) -> bool: ...
class int: class int:
def __add__(self, other: Self) -> Self: ... def __add__(self, other: Self) -> Self: ...
...@@ -96,6 +96,10 @@ assert [].__getitem__ ...@@ -96,6 +96,10 @@ assert [].__getitem__
assert [4].__getitem__ assert [4].__getitem__
assert [1, 2, 3][1] assert [1, 2, 3][1]
class set[U]:
def __len__(self) -> int: ...
def __contains__(self, item: U) -> bool: ...
def iter[U](x: Iterable[U]) -> Iterator[U]: def iter[U](x: Iterable[U]) -> Iterator[U]:
... ...
......
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
Protocol = BuiltinFeature["Protocol"] Protocol = BuiltinFeature["Protocol"]
Self = BuiltinFeature["Self"] Self = BuiltinFeature["Self"]
Optional = BuiltinFeature["Optional"] Optional = BuiltinFeature["Optional"]
Callable = BuiltinFeature["Callable"]
\ No newline at end of file
...@@ -88,15 +88,15 @@ def run_test(path, quiet=True): ...@@ -88,15 +88,15 @@ def run_test(path, quiet=True):
if args.compile: if args.compile:
return TestStatus.SUCCESS return TestStatus.SUCCESS
execute_str = "true" if (execute and not args.generate) else "false" execute_str = "true" if (execute and not args.generate) else "false"
name_bin = path.with_suffix("").as_posix() + ("$(python3-config --extension-suffix)" if extension else ".exe") name_bin = path.with_suffix("").as_posix() + ("$(python3.12-config --extension-suffix)" if extension else ".exe")
if exec_cmd(f'bash -c "export PYTHONPATH=stdlib; if {execute_str}; then python3 ./{path.as_posix()}; fi"') != 0: if exec_cmd(f'bash -c "export PYTHONPATH=stdlib; if {execute_str}; then echo python3.12 ./{path.as_posix()}; fi"') != 0:
return TestStatus.PYTHON_ERROR return TestStatus.PYTHON_ERROR
if compile and (alt := environ.get("ALT_RUNNER")): if compile and (alt := environ.get("ALT_RUNNER")):
if (code := exec_cmd(alt.format( if (code := exec_cmd(alt.format(
name_bin=name_bin, name_bin=name_bin,
name_cpp_posix=name_cpp.as_posix(), name_cpp_posix=name_cpp.as_posix(),
run_file=execute_str, run_file=execute_str,
test_exec=f"python3 {path.with_suffix('.post.py').as_posix()}" if extension else name_bin, test_exec=f"python3.12 {path.with_suffix('.post.py').as_posix()}" if extension else name_bin,
bonus_flags="-e" if extension else "" bonus_flags="-e" if extension else ""
))) != 0: ))) != 0:
return TestStatus(code) return TestStatus(code)
......
...@@ -4,9 +4,9 @@ from typon import is_cpp ...@@ -4,9 +4,9 @@ from typon import is_cpp
import sys as sis import sys as sis
from sys import stdout as truc from sys import stdout as truc
foo = 123 # foo = 123
test = (2 + 3) * 4 # test = (2 + 3) * 4
glob = 5 # glob = 5
# def g(): # def g():
# a = 8 # a = 8
...@@ -20,32 +20,37 @@ glob = 5 ...@@ -20,32 +20,37 @@ glob = 5
# e = d + 1 # e = d + 1
# print(e) # print(e)
def f(x): # def f(x):
return x + 1 # return x + 1
#
#
def fct(param: int): # def fct(param: int):
loc = f(456) # loc = f(456)
global glob # global glob
loc = 789 # loc = 789
glob = 123 # glob = 123
#
def fct2(): # def fct2():
global glob # global glob
glob += 5 # glob += 5
if __name__ == "__main__": if __name__ == "__main__":
print(is_cpp) print("is c++:", is_cpp())
# TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15 # TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056 # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056
sum = 0 sum = 0
for i in range(15): for i in range(15):
sum += i sum = sum + i
a = [n for n in range(10)] a = [n for n in range(10)]
b = [x for x in a if x % 2 == 0] #b = [x for x in a if x % 2 == 0]
c = [y * y for y in b] #c = [y * y for y in b]
print("C++ " if is_cpp() else "Python", print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5, "res=", 5, ".", True, [4, 5, 6], {7, 8, 9},
3j, sum, a, b, c) #[1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3},
#0x55 & 7 == 5,
#3j,
sum,
a)
print("Typon")
print() print()
...@@ -2,10 +2,15 @@ ...@@ -2,10 +2,15 @@
import ast import ast
import builtins import builtins
import importlib import importlib
import inspect
import sys import sys
import traceback import traceback
import colorful as cf import colorful as cf
from transpiler.exceptions import CompileError
from transpiler.utils import highlight
def exception_hook(exc_type, exc_value, tb): def exception_hook(exc_type, exc_value, tb):
print = lambda *args, **kwargs: builtins.print(*args, **kwargs, file=sys.stderr) print = lambda *args, **kwargs: builtins.print(*args, **kwargs, file=sys.stderr)
last_node = None last_node = None
...@@ -49,8 +54,12 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -49,8 +54,12 @@ def exception_hook(exc_type, exc_value, tb):
return return
print(f"In file {cf.white(last_file)}:{last_node.lineno}") print(f"In file {cf.white(last_file)}:{last_node.lineno}")
#print(f"From {last_node.lineno}:{last_node.col_offset} to {last_node.end_lineno}:{last_node.end_col_offset}") #print(f"From {last_node.lineno}:{last_node.col_offset} to {last_node.end_lineno}:{last_node.end_col_offset}")
try:
with open(last_file, "r", encoding="utf-8") as f: with open(last_file, "r", encoding="utf-8") as f:
code = f.read() code = f.read()
except Exception:
pass
else:
hg = (str(highlight(code, True)) hg = (str(highlight(code, True))
.replace("\x1b[04m", "") .replace("\x1b[04m", "")
.replace("\x1b[24m", "") .replace("\x1b[24m", "")
......
...@@ -24,7 +24,7 @@ DUNDER = { ...@@ -24,7 +24,7 @@ DUNDER = {
class DesugarOp(ast.NodeTransformer): class DesugarOp(ast.NodeTransformer):
def visit_BinOp(self, node: ast.BinOp): def visit_BinOp(self, node: ast.BinOp):
lnd = linenodata(node) lnd = linenodata(node)
return ast.Call( res = ast.Call(
func=ast.Attribute( func=ast.Attribute(
value=self.visit(node.left), value=self.visit(node.left),
attr=f"__{DUNDER[type(node.op)]}__", attr=f"__{DUNDER[type(node.op)]}__",
...@@ -35,16 +35,19 @@ class DesugarOp(ast.NodeTransformer): ...@@ -35,16 +35,19 @@ class DesugarOp(ast.NodeTransformer):
keywords={}, keywords={},
**lnd **lnd
) )
res.orig_node = node
return res
def visit_UnaryOp(self, node: ast.UnaryOp): def visit_UnaryOp(self, node: ast.UnaryOp):
lnd = linenodata(node) lnd = linenodata(node)
if type(node.op) == ast.Not: if type(node.op) == ast.Not:
return ast.UnaryOp( res = ast.UnaryOp(
operand=self.visit(node.operand), operand=self.visit(node.operand),
op=node.op, op=node.op,
**lnd **lnd
) )
return ast.Call( else:
res = ast.Call(
func=ast.Attribute( func=ast.Attribute(
value=self.visit(node.operand), value=self.visit(node.operand),
attr=f"__{DUNDER[type(node.op)]}__", attr=f"__{DUNDER[type(node.op)]}__",
...@@ -55,6 +58,8 @@ class DesugarOp(ast.NodeTransformer): ...@@ -55,6 +58,8 @@ class DesugarOp(ast.NodeTransformer):
keywords={}, keywords={},
**lnd **lnd
) )
res.orig_node = node
return res
# def visit_AugAssign(self, node: ast.AugAssign): # def visit_AugAssign(self, node: ast.AugAssign):
# return # return
import ast
from typing import Iterable from typing import Iterable
def emit_class(clazz) -> Iterable[str]: def emit_class(node: ast.ClassDef) -> Iterable[str]:
yield f"template <typename _Base0 = referencemodel::object>" yield f"template <typename _Base0 = referencemodel::object>"
yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{" yield f"struct {node.name}__oo : referencemodel::classtype<_Base0, {node.name}__oo<>> {{"
yield f"static constexpr std::string_view name = \"{node.name}\";" yield f"static constexpr std::string_view name = \"{node.name}\";"
......
...@@ -5,8 +5,35 @@ from typing import Iterable ...@@ -5,8 +5,35 @@ from typing import Iterable
from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join from transpiler.phases.emit_cpp.visitors import NodeVisitor, CoroutineMode, join
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.utils import make_lnd
from transpiler.utils import linenodata
SYMBOLS = {
ast.Eq: "==",
ast.NotEq: '!=',
ast.Pass: '/* pass */',
ast.Mult: '*',
ast.Add: '+',
ast.Sub: '-',
ast.Div: '/',
ast.FloorDiv: '/', # TODO
ast.Mod: '%',
ast.Lt: '<',
ast.Gt: '>',
ast.GtE: '>=',
ast.LtE: '<=',
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
ast.BitOr: '|',
ast.BitAnd: '&',
ast.Not: '!',
ast.IsNot: '!=',
ast.USub: '-',
ast.And: '&&',
ast.Or: '||'
}
"""Mapping of Python AST nodes to C++ symbols."""
# noinspection PyPep8Naming # noinspection PyPep8Naming
@dataclass @dataclass
...@@ -33,7 +60,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -33,7 +60,7 @@ class ExpressionVisitor(NodeVisitor):
yield str(node.value).lower() yield str(node.value).lower()
elif isinstance(node.value, int): elif isinstance(node.value, int):
# TODO: bigints # TODO: bigints
yield str(node.value) yield str(node.value) + "_pi"
elif isinstance(node.value, float): elif isinstance(node.value, float):
yield repr(node.value) yield repr(node.value)
elif isinstance(node.value, complex): elif isinstance(node.value, complex):
...@@ -44,7 +71,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -44,7 +71,7 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, type(node)) raise NotImplementedError(node, type(node))
def visit_Slice(self, node: ast.Slice) -> Iterable[str]: def visit_Slice(self, node: ast.Slice) -> Iterable[str]:
yield "TySlice(" yield "typon::TySlice("
yield from join(", ", (self.visit(x or ast.Constant(value=None)) for x in (node.lower, node.upper, node.step))) yield from join(", ", (self.visit(x or ast.Constant(value=None)) for x in (node.lower, node.upper, node.step)))
yield ")" yield ")"
...@@ -79,22 +106,22 @@ class ExpressionVisitor(NodeVisitor): ...@@ -79,22 +106,22 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right)) # yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]: def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
raise NotImplementedError() if len(node.values) == 1:
yield from self.visit(node.values[0])
# if len(node.values) == 1: return
# yield from self.visit(node.values[0]) cpp_op = {
# return ast.And: "&&",
# cpp_op = { ast.Or: "||"
# ast.And: "&&", }[type(node.op)]
# ast.Or: "||" yield "("
# }[type(node.op)] yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
# with self.prec_ctx(cpp_op): for left, right in zip(node.values[1:], node.values[2:]):
# yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1])) yield f" {cpp_op} "
# for left, right in zip(node.values[1:], node.values[2:]): yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
# yield f" {cpp_op} " yield ")"
# yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
yield "co_await"
yield "(" yield "("
yield from self.visit(node.func) yield from self.visit(node.func)
yield ")(" yield ")("
...@@ -158,6 +185,8 @@ class ExpressionVisitor(NodeVisitor): ...@@ -158,6 +185,8 @@ class ExpressionVisitor(NodeVisitor):
templ, args, _ = self.process_args(node.args) templ, args, _ = self.process_args(node.args)
yield templ yield templ
yield args yield args
yield "->"
yield from self.visit(node.type.deref().return_type)
yield "{" yield "{"
yield "return" yield "return"
yield from self.reset().visit(node.body) yield from self.reset().visit(node.body)
...@@ -165,14 +194,20 @@ class ExpressionVisitor(NodeVisitor): ...@@ -165,14 +194,20 @@ class ExpressionVisitor(NodeVisitor):
yield "}" yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]: def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node)) yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node)) yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
yield "(co_await ("
yield from self.visit(left)
yield " "
yield SYMBOLS[type(op)]
yield " "
yield from self.visit(right)
yield "))"
return
raise NotImplementedError() raise NotImplementedError()
# if type(op) == ast.In: # if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd) # call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
...@@ -205,7 +240,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -205,7 +240,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_List(self, node: ast.List) -> Iterable[str]: def visit_List(self, node: ast.List) -> Iterable[str]:
if node.elts: if node.elts:
yield "typon::TyList{" yield "typon::TyList{"
yield from join(", ", map(self.reset().visit, node.elts)) yield from join(", ", map(self.visit, node.elts))
yield "}" yield "}"
else: else:
yield from self.visit(node.type) yield from self.visit(node.type)
...@@ -214,7 +249,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -214,7 +249,7 @@ class ExpressionVisitor(NodeVisitor):
def visit_Set(self, node: ast.Set) -> Iterable[str]: def visit_Set(self, node: ast.Set) -> Iterable[str]:
if node.elts: if node.elts:
yield "typon::TySet{" yield "typon::TySet{"
yield from join(", ", map(self.reset().visit, node.elts)) yield from join(", ", map(self.visit, node.elts))
yield "}" yield "}"
else: else:
yield from self.visit(node.type) yield from self.visit(node.type)
...@@ -223,9 +258,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -223,9 +258,9 @@ class ExpressionVisitor(NodeVisitor):
def visit_Dict(self, node: ast.Dict) -> Iterable[str]: def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value): def visit_item(key, value):
yield "std::pair {" yield "std::pair {"
yield from self.reset().visit(key) yield from self.visit(key)
yield ", " yield ", "
yield from self.reset().visit(value) yield from self.visit(value)
yield "}" yield "}"
if node.keys: if node.keys:
...@@ -256,12 +291,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -256,12 +291,13 @@ class ExpressionVisitor(NodeVisitor):
yield from self.prec("unary").visit(operand) yield from self.prec("unary").visit(operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]: def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"): yield "("
yield from self.visit(node.test) yield from self.visit(node.test)
yield " ? " yield " ? "
yield from self.visit(node.body) yield from self.visit(node.body)
yield " : " yield " : "
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
yield ")"
def visit_Yield(self, node: ast.Yield) -> Iterable[str]: def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
#if CoroutineMode.GENERATOR in self.generator: #if CoroutineMode.GENERATOR in self.generator:
...@@ -279,8 +315,23 @@ class ExpressionVisitor(NodeVisitor): ...@@ -279,8 +315,23 @@ class ExpressionVisitor(NodeVisitor):
if len(node.generators) != 1: if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet") raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0] gen: ast.comprehension = node.generators[0]
yield "MAP_FILTER("
yield from self.visit(gen.target)
yield ","
yield from self.visit(gen.iter)
yield ", "
yield from self.visit(node.elt)
yield ", "
if gen.ifs:
yield from self.visit(gen.ifs_node)
else:
yield "true"
yield ")"
return
yield "mapFilter([](" yield "mapFilter([]("
yield from self.visit(node.input_item_type) #yield from self.visit(node.input_item_type)
yield "auto"
yield from self.visit(gen.target) yield from self.visit(gen.target)
yield ") { return " yield ") { return "
yield from self.visit(node.elt) yield from self.visit(node.elt)
...@@ -289,9 +340,12 @@ class ExpressionVisitor(NodeVisitor): ...@@ -289,9 +340,12 @@ class ExpressionVisitor(NodeVisitor):
if gen.ifs: if gen.ifs:
yield ", " yield ", "
yield "[](" yield "[]("
yield from self.visit(node.input_item_type) #yield from self.visit(node.input_item_type)
yield "auto"
yield from self.visit(gen.target) yield from self.visit(gen.target)
yield ") { return " yield ") -> typon::Task<"
yield from self.visit(gen.ifs_node.type)
yield "> { return "
yield from self.visit(gen.ifs_node) yield from self.visit(gen.ifs_node)
yield "; }" yield "; }"
yield ")" yield ")"
......
...@@ -3,6 +3,7 @@ from dataclasses import dataclass, field ...@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from typing import Iterable, Optional from typing import Iterable, Optional
from transpiler.phases.emit_cpp.expr import ExpressionVisitor from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.typing.common import IsDeclare
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode from transpiler.phases.emit_cpp.visitors import NodeVisitor, flatmap, CoroutineMode
from transpiler.phases.typing.types import CallableInstanceType, BaseType from transpiler.phases.typing.types import CallableInstanceType, BaseType
...@@ -10,7 +11,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType ...@@ -10,7 +11,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]: def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield f"struct : referencemodel::function {{" yield f"struct : referencemodel::function {{"
yield "typon::Task<void> operator()(" yield "typon::Task<typon::TyNone> operator()("
for arg, ty in zip(func.block_data.node.args.args, func.parameters): for arg, ty in zip(func.block_data.node.args.args, func.parameters):
yield "auto " yield "auto "
...@@ -151,32 +152,36 @@ class BlockVisitor(NodeVisitor): ...@@ -151,32 +152,36 @@ class BlockVisitor(NodeVisitor):
# #
# yield "}" # yield "}"
# #
# def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr, declare: IsDeclare) -> Iterable[str]:
# if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
raise NotImplementedError()
# for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args): # for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
# if decl: # if decl:
# yield from self.visit_lvalue(name, True) # yield from self.visit_lvalue(name, True)
# yield ";" # yield ";"
# yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})" # yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
# elif isinstance(lvalue, ast.Name): elif isinstance(lvalue, ast.Name):
# if lvalue.id == "_": if lvalue.id == "_":
# if not declare: if not declare:
# yield "std::ignore" yield "std::ignore"
# return return
# name = self.fix_name(lvalue.id) name = self.fix_name(lvalue.id)
# # if name not in self._scope.vars: # if name not in self._scope.vars:
# # if not self.scope.exists_local(name): # if not self.scope.exists_local(name):
# # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None, # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# # getattr(val, "is_future", False)) # getattr(val, "is_future", False))
# if declare: if declare:
# yield from self.visit(lvalue.type) yield "decltype("
# yield name yield from self.expr().visit(declare.initial_value)
# elif isinstance(lvalue, ast.Subscript): yield ")"
# yield from self.expr().visit(lvalue) #yield from self.visit(lvalue.type)
# elif isinstance(lvalue, ast.Attribute): yield name
# yield from self.expr().visit(lvalue) elif isinstance(lvalue, ast.Subscript):
# else: yield from self.expr().visit(lvalue)
# raise NotImplementedError(lvalue) elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]: def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1: if len(node.targets) != 1:
...@@ -192,3 +197,27 @@ class BlockVisitor(NodeVisitor): ...@@ -192,3 +197,27 @@ class BlockVisitor(NodeVisitor):
yield " = " yield " = "
yield from self.expr().visit(node.value) yield from self.expr().visit(node.value)
yield ";" yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
if node.orelse:
yield "auto"
yield node.orelse_variable
yield "= true;"
yield f"for (auto {node.target.id} : "
yield from self.expr().visit(node.iter)
yield ")"
yield from self.emit_block(node.inner_scope, node.body) # TODO: why not reuse the scope used for analysis? same in while
if node.orelse:
yield "if ("
yield node.orelse_variable
yield ")"
yield from self.emit_block(node.inner_scope, node.orelse)
def emit_block(self, scope: Scope, items: Iterable[ast.stmt]) -> Iterable[str]:
yield "{"
for child in items:
yield from BlockVisitor(scope, generator=self.generator).visit(child)
yield "}"
...@@ -15,6 +15,9 @@ class UniversalVisitor: ...@@ -15,6 +15,9 @@ class UniversalVisitor:
__TB__ = f"emitting C++ code for {highlight(node)}" __TB__ = f"emitting C++ code for {highlight(node)}"
# __TB_SKIP__ = True # __TB_SKIP__ = True
if orig := getattr(node, "orig_node", None):
node = orig
if type(node) == list: if type(node) == list:
for n in node: for n in node:
yield from self.visit(n) yield from self.visit(n)
...@@ -56,17 +59,28 @@ class NodeVisitor(UniversalVisitor): ...@@ -56,17 +59,28 @@ class NodeVisitor(UniversalVisitor):
match node: match node:
case types.TY_INT: case types.TY_INT:
yield "int" yield "decltype(0_pi)"
case types.TY_FLOAT: case types.TY_FLOAT:
yield "double" yield "double"
case types.TY_BOOL: case types.TY_BOOL:
yield "bool" yield "bool"
case types.TY_NONE: case types.TY_NONE:
yield "void" yield "typon::TyNone"
case types.TY_STR: case types.TY_STR:
yield "TyStr" yield "typon::TyStr"
case types.TypeVariable(name): case types.TypeVariable(name):
raise UnresolvedTypeVariableError(node) raise UnresolvedTypeVariableError(node)
case types.GenericInstanceType():
yield from self.visit(node.generic_parent)
yield "<"
yield from join(",", map(self.visit, node.generic_args))
yield ">"
case types.TY_LIST:
yield "typon::TyList"
case types.TY_DICT:
yield "typon::TyDict"
case types.TY_SET:
yield "typon::TySet"
case _: case _:
raise NotImplementedError(node) raise NotImplementedError(node)
......
# coding: utf-8 # coding: utf-8
import ast # import ast
from dataclasses import dataclass, field # from dataclasses import dataclass, field
#
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE # from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor # from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef # from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef
#
#
@dataclass # @dataclass
class ScoperClassVisitor(ScoperVisitor): # class ScoperClassVisitor(ScoperVisitor):
fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list) # fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list)
#
def visit_AnnAssign(self, node: ast.AnnAssign): # def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value" # assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)" # assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name) # assert isinstance(node.target, ast.Name)
self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation)) # self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation))
#
def visit_Assign(self, node: ast.Assign): # def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Can't use destructuring in class static member" # assert len(node.targets) == 1, "Can't use destructuring in class static member"
assert isinstance(node.targets[0], ast.Name) # assert isinstance(node.targets[0], ast.Name)
node.is_declare = True # node.is_declare = True
valtype = self.expr().visit(node.value) # valtype = self.expr().visit(node.value)
node.targets[0].type = valtype # node.targets[0].type = valtype
self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value) # self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value)
#
def visit_FunctionDef(self, node: ast.FunctionDef): # def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node) # ftype = self.parse_function(node)
ftype.parameters[0].unify(self.scope.obj_type) # ftype.parameters[0].unify(self.scope.obj_type)
inner = ftype.return_type # inner = ftype.return_type
if node.name != "__init__": # if node.name != "__init__":
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK) # ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True # ftype.is_method = True
self.scope.obj_type.fields[node.name] = MemberDef(ftype, node) # self.scope.obj_type.fields[node.name] = MemberDef(ftype, node)
return (node, inner) # return (node, inner)
import ast import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.utils import highlight from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
...@@ -10,6 +12,7 @@ from transpiler.phases.utils import NodeVisitorSeq, AnnotationName ...@@ -10,6 +12,7 @@ from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global() PRELUDE = Scope.make_global()
@dataclass @dataclass
class ScoperVisitor(NodeVisitorSeq): class ScoperVisitor(NodeVisitorSeq):
scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL)) scope: Scope = field(default_factory=lambda: PRELUDE.child(ScopeKind.GLOBAL))
...@@ -93,10 +96,11 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -93,10 +96,11 @@ class ScoperVisitor(NodeVisitorSeq):
elif len(visitor.fdecls) == 1: elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0] fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype) self.visit_function_definition(fnode, frtype)
#del node.inner_scope.vars[fnode.name] # del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type) visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls b.decls = decls
if not node.inner_scope.diverges and not (isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR): if not node.inner_scope.diverges and not (
isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR):
from transpiler.phases.typing.exceptions import TypeMismatchError from transpiler.phases.typing.exceptions import TypeMismatchError
try: try:
rtype.unify(TY_NONE) rtype.unify(TY_NONE)
...@@ -104,21 +108,28 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -104,21 +108,28 @@ class ScoperVisitor(NodeVisitorSeq):
from transpiler.phases.typing.exceptions import MissingReturnError from transpiler.phases.typing.exceptions import MissingReturnError
raise MissingReturnError(node) from e raise MissingReturnError(node) from e
def get_iter(seq_type): def get_iter(self, seq_type):
try: try:
iter_type = seq_type.fields["__iter__"].type.return_type return self.expr().visit_function_call(self.expr().visit_getattr(seq_type, "__iter__"), [])
except: except:
from transpiler.phases.typing.exceptions import NotIterableError from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type) raise NotIterableError(seq_type)
return iter_type
def get_next(iter_type): def get_next(self, iter_type):
try: try:
next_type = iter_type.fields["__next__"].type.return_type return self.expr().visit_function_call(self.expr().visit_getattr(iter_type, "__next__"), [])
except: except:
from transpiler.phases.typing.exceptions import NotIteratorError from transpiler.phases.typing.exceptions import NotIterableError
raise NotIteratorError(iter_type) raise NotIterableError(iter_type)
return next_type
def is_builtin(x, feature): def is_builtin(x, feature):
return isinstance(x, BuiltinFeatureType) and x.feature() == feature return isinstance(x, BuiltinFeatureType) and x.feature() == feature
@dataclass
class DeclareInfo:
initial_value: Optional[ast.expr] = None
IsDeclare = None | DeclareInfo
...@@ -4,7 +4,7 @@ import inspect ...@@ -4,7 +4,7 @@ import inspect
from itertools import zip_longest from itertools import zip_longest
from typing import List from typing import List
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin from transpiler.phases.typing.common import ScoperVisitor, is_builtin
from transpiler.phases.typing.exceptions import ArgumentCountMismatchError, TypeMismatchKind, TypeMismatchError from transpiler.phases.typing.exceptions import ArgumentCountMismatchError, TypeMismatchKind, TypeMismatchError
from transpiler.phases.typing.types import BaseType, TY_STR, TY_BOOL, TY_INT, TY_COMPLEX, TY_FLOAT, TY_NONE, \ from transpiler.phases.typing.types import BaseType, TY_STR, TY_BOOL, TY_INT, TY_COMPLEX, TY_FLOAT, TY_NONE, \
ClassTypeType, ResolvedConcreteType, GenericType, CallableInstanceType, TY_LIST, TY_SET, TY_DICT, RuntimeValue, \ ClassTypeType, ResolvedConcreteType, GenericType, CallableInstanceType, TY_LIST, TY_SET, TY_DICT, RuntimeValue, \
...@@ -141,7 +141,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -141,7 +141,12 @@ class ScoperExprVisitor(ScoperVisitor):
if not a.try_assign(b): if not a.try_assign(b):
raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE)
return ftype.return_type if not ftype.is_native:
from transpiler.phases.typing.block import ScoperBlockVisitor
vis = ScoperBlockVisitor(ftype.block_data.scope)
for stmt in ftype.block_data.node.body:
vis.visit(stmt)
return ftype.return_type.resolve()
# if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType): # if isinstance(ftype, TypeType):# and isinstance(ftype.type_object, UserType):
# init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self() # init: FunctionType = self.visit_getattr(ftype, "__init__").remove_self()
# init.return_type = ftype.type_object # init.return_type = ftype.type_object
...@@ -176,6 +181,10 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -176,6 +181,10 @@ class ScoperExprVisitor(ScoperVisitor):
node.body.decls = decls node.body.decls = decls
return ftype return ftype
def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right))
return TypeVariable() # TODO
# def visit_BinOp(self, node: ast.BinOp) -> BaseType: # def visit_BinOp(self, node: ast.BinOp) -> BaseType:
# left, right = map(self.visit, (node.left, node.right)) # left, right = map(self.visit, (node.left, node.right))
# return self.make_dunder([left, right], DUNDER[type(node.op)]) # return self.make_dunder([left, right], DUNDER[type(node.op)])
...@@ -300,14 +309,15 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -300,14 +309,15 @@ class ScoperExprVisitor(ScoperVisitor):
if len(node.generators) != 1: if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet") raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0] gen: ast.comprehension = node.generators[0]
iter_type = get_iter(self.visit(gen.iter)) iter_type = self.get_iter(self.visit(gen.iter))
node.input_item_type = get_next(iter_type) node.input_item_type = self.get_next(iter_type)
virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER) virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
from transpiler import ScoperBlockVisitor from transpiler.phases.typing.block import ScoperBlockVisitor
visitor = ScoperBlockVisitor(virt_scope) visitor = ScoperBlockVisitor(virt_scope)
visitor.visit_assign_target(gen.target, node.input_item_type) visitor.visit_assign_target(gen.target, node.input_item_type)
node.item_type = visitor.expr().visit(node.elt) node.item_type = visitor.expr().visit(node.elt)
for if_ in gen.ifs: # for if_ in gen.ifs:
visitor.expr().visit(if_) # visitor.expr().visit(if_)
gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node)) gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node))
return TyList(node.item_type) visitor.expr().visit(gen.ifs_node)
\ No newline at end of file return TY_LIST.instantiate([node.item_type])
\ No newline at end of file
...@@ -4,7 +4,7 @@ from logging import debug ...@@ -4,7 +4,7 @@ from logging import debug
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind from transpiler.phases.typing.scope import Scope, VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin from transpiler.phases.typing.types import MemberDef, ResolvedConcreteType, UniqueTypeMixin, BlockData
class ModuleType(UniqueTypeMixin, ResolvedConcreteType): class ModuleType(UniqueTypeMixin, ResolvedConcreteType):
...@@ -21,8 +21,9 @@ def make_module(name: str, scope: Scope) -> ModuleType: ...@@ -21,8 +21,9 @@ def make_module(name: str, scope: Scope) -> ModuleType:
visited_modules = {} visited_modules = {}
def parse_module(mod_name: str, python_path: Path, scope=None, preprocess=None): def parse_module(mod_name: str, python_path: list[Path], scope=None, preprocess=None) -> ModuleType:
path = python_path / mod_name for path in python_path:
path = path / mod_name
if not path.exists(): if not path.exists():
path = path.with_suffix(".py") path = path.with_suffix(".py")
...@@ -31,29 +32,33 @@ def parse_module(mod_name: str, python_path: Path, scope=None, preprocess=None): ...@@ -31,29 +32,33 @@ def parse_module(mod_name: str, python_path: Path, scope=None, preprocess=None):
path = path.with_stem(mod_name + "_") path = path.with_stem(mod_name + "_")
if not path.exists(): if not path.exists():
raise FileNotFoundError(f"Could not find {path}") continue
if path.is_dir(): break
real_path = path / "__init__.py"
else: else:
real_path = path raise FileNotFoundError(f"Could not find {mod_name}")
if mod := visited_modules.get(real_path.as_posix()): if path.is_dir():
return mod path = path / "__init__.py"
if mod := visited_modules.get(path.as_posix()):
return mod.type
mod_scope = scope or PRELUDE.child(ScopeKind.GLOBAL) mod_scope = scope or PRELUDE.child(ScopeKind.GLOBAL)
if real_path.suffix == ".py": if path.suffix != ".py":
raise NotImplementedError(f"Unsupported file type {path.suffix}")
from transpiler.phases.typing.stdlib import StdlibVisitor from transpiler.phases.typing.stdlib import StdlibVisitor
node = ast.parse(real_path.read_text()) node = ast.parse(path.read_text())
if preprocess: if preprocess:
node = preprocess(node) node = preprocess(node)
StdlibVisitor(python_path, mod_scope).visit(node) from transpiler.transpiler import TYPON_STD
else: StdlibVisitor(python_path, mod_scope, is_native=TYPON_STD in path.parents).visit(node)
raise NotImplementedError(f"Unsupported file type {path.suffix}")
mod = make_module(mod_name, mod_scope) mod = make_module(mod_name, mod_scope)
visited_modules[real_path.as_posix()] = VarDecl(VarKind.LOCAL, mod, {k: v.type for k, v in mod_scope.vars.items()}) mod.block_data = BlockData(node, mod_scope)
visited_modules[path.as_posix()] = VarDecl(VarKind.LOCAL, mod)
return mod return mod
# def process_module(mod_path: Path, scope): # def process_module(mod_path: Path, scope):
......
...@@ -75,9 +75,10 @@ def visit_generic_item( ...@@ -75,9 +75,10 @@ def visit_generic_item(
@dataclass @dataclass
class StdlibVisitor(NodeVisitorSeq): class StdlibVisitor(NodeVisitorSeq):
python_path: Path python_path: list[Path]
scope: Scope = field(default_factory=lambda: PRELUDE) scope: Scope = field(default_factory=lambda: PRELUDE)
cur_class: Optional[ResolvedConcreteType] = None cur_class: Optional[ResolvedConcreteType] = None
is_native: bool = False
def resolve_module_import(self, name: str): def resolve_module_import(self, name: str):
# tries = [ # tries = [
...@@ -88,7 +89,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -88,7 +89,7 @@ class StdlibVisitor(NodeVisitorSeq):
# if path.exists(): # if path.exists():
# return path # return path
# raise FileNotFoundError(f"Could not find module {name}") # raise FileNotFoundError(f"Could not find module {name}")
return parse_module(name, self.python_path) return parse_module(name, self.python_path, self.scope.child(ScopeKind.GLOBAL))
def expr(self) -> ScoperExprVisitor: def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope) return ScoperExprVisitor(self.scope)
...@@ -122,7 +123,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -122,7 +123,7 @@ class StdlibVisitor(NodeVisitorSeq):
for alias in node.names: for alias in node.names:
mod = self.resolve_module_import(alias.name) mod = self.resolve_module_import(alias.name)
alias.module_obj = mod alias.module_obj = mod
self.scope.vars[alias.asname or alias.name] = mod self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, mod)
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
if existing := self.scope.get(node.name): if existing := self.scope.get(node.name):
...@@ -137,7 +138,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -137,7 +138,7 @@ class StdlibVisitor(NodeVisitorSeq):
cl_scope = scope.child(ScopeKind.CLASS) cl_scope = scope.child(ScopeKind.CLASS)
cl_scope.declare_local("Self", output.type_type()) cl_scope.declare_local("Self", output.type_type())
output.block_data = BlockData(node, scope) output.block_data = BlockData(node, scope)
visitor = StdlibVisitor(self.python_path, cl_scope, output) visitor = StdlibVisitor(self.python_path, cl_scope, output, self.is_native)
bases = [self.anno().visit(base) for base in node.bases] bases = [self.anno().visit(base) for base in node.bases]
match bases: match bases:
case []: case []:
...@@ -163,6 +164,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -163,6 +164,7 @@ class StdlibVisitor(NodeVisitorSeq):
output.return_type = arg_visitor.visit(node.returns) output.return_type = arg_visitor.visit(node.returns)
output.optional_at = len(node.args.args) - len(node.args.defaults) output.optional_at = len(node.args.args) - len(node.args.defaults)
output.is_variadic = args.vararg is not None output.is_variadic = args.vararg is not None
output.is_native = self.is_native
@dataclass(eq=False, init=False) @dataclass(eq=False, init=False)
class InstanceType(CallableInstanceType): class InstanceType(CallableInstanceType):
......
...@@ -78,7 +78,10 @@ class BaseType(ABC): ...@@ -78,7 +78,10 @@ class BaseType(ABC):
return (needle is haystack) or haystack.contains_internal(needle) return (needle is haystack) or haystack.contains_internal(needle)
def try_assign(self, other: "BaseType") -> bool: def try_assign(self, other: "BaseType") -> bool:
return self.resolve().try_assign_internal(other.resolve()) target, value = self.resolve(), other.resolve()
if type(value) == TypeVariable:
return BaseType.try_assign_internal(target, other)
return target.try_assign_internal(other)
def try_assign_internal(self, other: "BaseType") -> bool: def try_assign_internal(self, other: "BaseType") -> bool:
...@@ -489,6 +492,7 @@ class CallableInstanceType(GenericInstanceType, MethodType): ...@@ -489,6 +492,7 @@ class CallableInstanceType(GenericInstanceType, MethodType):
return_type: ConcreteType return_type: ConcreteType
optional_at: int = None optional_at: int = None
is_variadic: bool = False is_variadic: bool = False
is_native: bool = False
def __post_init__(self): def __post_init__(self):
if self.optional_at is None and self.parameters is not None: if self.optional_at is None and self.parameters is not None:
...@@ -560,6 +564,8 @@ def make_builtin_feature(name: str): ...@@ -560,6 +564,8 @@ def make_builtin_feature(name: str):
return TY_OPTIONAL return TY_OPTIONAL
case "Union": case "Union":
return TY_UNION return TY_UNION
case "Callable":
return TY_CALLABLE
case _: case _:
class CreatedType(BuiltinFeatureType): class CreatedType(BuiltinFeatureType):
def name(self): def name(self):
......
# coding: utf-8 # coding: utf-8
import ast
from pathlib import Path from pathlib import Path
import colorama import colorama
...@@ -14,14 +15,16 @@ from transpiler.phases.emit_cpp.module import emit_module ...@@ -14,14 +15,16 @@ from transpiler.phases.emit_cpp.module import emit_module
from transpiler.phases.if_main import IfMainVisitor from transpiler.phases.if_main import IfMainVisitor
from transpiler.phases.typing import PRELUDE from transpiler.phases.typing import PRELUDE
from transpiler.phases.typing.modules import parse_module from transpiler.phases.typing.modules import parse_module
from transpiler.phases.typing.stdlib import StdlibVisitor
TYPON_STD = Path(__file__).parent.parent / "stdlib"
def init(): def init():
error_display.init() error_display.init()
colorama.init() colorama.init()
typon_std = Path(__file__).parent.parent / "stdlib"
#discover_module(typon_std, PRELUDE.child(ScopeKind.GLOBAL)) #discover_module(typon_std, PRELUDE.child(ScopeKind.GLOBAL))
parse_module("builtins", typon_std, PRELUDE) parse_module("builtins", [TYPON_STD], PRELUDE)
...@@ -35,7 +38,7 @@ def transpile(source, name: str, path: Path): ...@@ -35,7 +38,7 @@ def transpile(source, name: str, path: Path):
node = DesugarOp().visit(node) node = DesugarOp().visit(node)
return node return node
module = parse_module(path.stem, path.parent, preprocess=preprocess) module = parse_module(path.stem, [path.parent, TYPON_STD], preprocess=preprocess)
def disp_scope(scope, indent=0): def disp_scope(scope, indent=0):
debug(" " * indent, scope.kind) debug(" " * indent, scope.kind)
...@@ -44,6 +47,8 @@ def transpile(source, name: str, path: Path): ...@@ -44,6 +47,8 @@ def transpile(source, name: str, path: Path):
for var in scope.vars.items(): for var in scope.vars.items():
debug(" " * (indent + 1), var) debug(" " * (indent + 1), var)
StdlibVisitor([], module.block_data.scope).expr().visit(ast.parse("main()", mode="eval").body)
def main_module(): def main_module():
yield from emit_module(module) yield from emit_module(module)
yield "#ifdef TYPON_EXTENSION" yield "#ifdef TYPON_EXTENSION"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment