Commit 213b5302 authored by Tom Niget's avatar Tom Niget

Rename type to py_type, add support for static members

parent 20578e6d
......@@ -50,7 +50,7 @@ auto dot_bind(Obj, Attr attr) {
#define dot(OBJ, NAME) [](auto && obj) -> auto { return dot_bind(obj, obj.NAME); }(OBJ)
#define dotp(OBJ, NAME) [](auto && obj) -> auto { return dot_bind(obj, obj->NAME); }(OBJ)
#define dots(OBJ, NAME) [](auto && obj) -> auto { return std::remove_reference<decltype(obj)>::type::type::NAME; }(OBJ)
#define dots(OBJ, NAME) [](auto && obj) -> auto { return std::remove_reference<decltype(obj)>::type::py_type::NAME; }(OBJ)
#endif // TYPON_BASEDEF_HPP
......@@ -45,14 +45,14 @@ template <typename T>
concept PySmartPtr = requires { typename T::element_type; };
template <typename T>
concept PyUserType = requires { typename T::type; };
concept PyUserType = requires { typename T::py_type; };
template <typename T> struct RealType {
using type = T;
};
template <PyUserType T> struct RealType<T> {
using type = typename T::type;
using type = typename T::py_type;
};
template <PySmartPtr T> struct RealType<T> {
......@@ -149,7 +149,7 @@ public:
#include "builtins/str.hpp"
struct file_s {
struct type {
struct py_type {
METHOD(
typon::Task<PyStr>, read, (Self self, size_t size = -1), {
if (size == -1) {
......@@ -178,9 +178,9 @@ struct file_s {
typon::Task<void>, flush, (Self self),
{ co_await typon::io::fsync(self->fd); })
type(int fd = -1, size_t len = 0) : fd(fd), len(len) {}
py_type(int fd = -1, size_t len = 0) : fd(fd), len(len) {}
type(const type &other)
py_type(const py_type &other)
: fd(other.fd), len(other.len) {}
METHOD(
......@@ -198,7 +198,7 @@ struct file_s {
} file;
namespace typon {
using PyFile = PyObj<decltype(file)::type>;
using PyFile = PyObj<decltype(file)::py_type>;
}
typon::Task<typon::PyFile> open(const PyStr &path, std::string_view mode) {
......@@ -280,6 +280,8 @@ struct lvalue_or_rvalue {
namespace typon {
template< class... Types >
using PyTuple = std::tuple<Types...>;
}
template<typename T>
......
//
// Created by Tom on 18/08/2023.
//
#ifndef TYPON_EXCEPTION_HPP
#define TYPON_EXCEPTION_HPP
#include "str.hpp"
struct PyException_s {
struct py_type {
PyStr message;
};
auto operator()(const PyStr &message) const {
return py_type{message};
}
};
namespace typon {
using PyException = PyObj<PyException_s>;
}
#endif // TYPON_EXCEPTION_HPP
......@@ -7,7 +7,7 @@
#include <iostream>
#include <ostream>
#include <functional>
#include "str.hpp"
#include <typon/typon.hpp>
......@@ -56,6 +56,12 @@ void repr_to(const T &x, std::ostream &s) {
s << "<function at 0x" << std::hex << (size_t)x << std::dec << ">";
}
template <typename T>
void repr_to(const std::function<T> &x, std::ostream &s) {
s << "<function at 0x" << std::hex << (size_t)x.template target<T*>() << std::dec
<< ">";
}
template <> void repr_to(const PyStr &x, std::ostream &s) {
s << '"' << x << '"';
}
......@@ -96,4 +102,14 @@ struct {
}
} print;
// typon::Task<void> print() { std::cout << '\n'; co_return; }
struct {
PyStr operator()(const PyStr& s = ""_ps) {
std::cout << s;
PyStr input;
std::getline(std::cin, input);
return input;
}
} input;
#endif // TYPON_PRINT_HPP
......@@ -7,6 +7,7 @@
#include <sstream>
#include <string>
#include <algorithm>
using namespace std::literals;
......@@ -47,6 +48,14 @@ public:
return pos == std::string::npos ? -1 : pos;
})
METHOD(bool, isspace, (Self self), {
return std::all_of(self.begin(), self.end(), isspace);
})
METHOD(auto, py_contains, (Self self, const std::string &x), {
return self.std::string::find(x) != std::string::npos;
})
PyStr operator[](PySlice slice) const {
auto [len, new_slice] = slice.adjust_indices(this->size());
......
......@@ -20,16 +20,16 @@ struct hashlib_t {
typedef int (*openssl_final)(unsigned char *md, void *context);
struct _Hash_s {
struct type {
struct py_type {
type(PyObj<void> context, openssl_update update, openssl_final final,
py_type(PyObj<void> context, openssl_update update, openssl_final final,
int diglen)
: _context(context), _update(update), _final(final), _diglen(diglen) {
}
type() {}
py_type() {}
type(const type &other)
py_type(const py_type &other)
: _context(other._context), _update(other._update),
_final(other._final), _diglen(other._diglen) {}
......
......@@ -25,7 +25,7 @@ struct os_t {
auto, fsdecode, (std::string s) { return s; })
struct Stat_Result_s {
struct type {
struct py_type {
int st_mode;
unsigned long long st_ino;
dev_t st_dev;
......@@ -45,14 +45,14 @@ struct os_t {
} Stat_Result;
struct DirEntry_s {
struct type {
struct py_type {
PyStr name;
PyStr path;
};
} DirEntry;
struct _Scandiriterator_s {
struct type {
struct py_type {
using value_type = PyObj<DirEntry_s>;
using reference = PyObj<DirEntry_s>;
......@@ -65,14 +65,14 @@ struct os_t {
METHOD(auto, begin, (Self self), { return *self; })
METHOD(auto, end, (Self self),
{ return type(self->basepath, self->namelist, self->n, self->n); })
{ return py_type(self->basepath, self->namelist, self->n, self->n); })
auto operator*() {
auto name = PyStr(this->namelist[this->current]->d_name);
return pyobj_agg<DirEntry_s>(name, this->basepath + name);
}
type(const PyStr &basepath, struct dirent **namelist, int n,
py_type(const PyStr &basepath, struct dirent **namelist, int n,
int current = 0)
: basepath(basepath), namelist(namelist), n(n), current(current) {
if (this->basepath[this->basepath.size() - 1] != '/') {
......@@ -80,9 +80,9 @@ struct os_t {
}
}
type() {}
py_type() {}
bool operator!=(const type &other) {
bool operator!=(const py_type &other) {
return this->current != other.current;
}
......@@ -103,7 +103,7 @@ struct os_t {
STATX_SIZE, &statxbuf)) {
system_error(-err, "statx()");
}
co_return PyObj<Stat_Result_s>(new Stat_Result_s::type{
co_return PyObj<Stat_Result_s>(new Stat_Result_s::py_type{
statxbuf.stx_mode,
statxbuf.stx_ino,
makedev(statxbuf.stx_dev_major, statxbuf.stx_dev_minor),
......
......@@ -23,13 +23,13 @@ struct socket_t {
static constexpr int AF_UNIX = 1;
struct socket_s {
struct type {
METHOD(typon::Task<std::tuple<PyObj<type> COMMA() std::string>>, accept, (Self self), {
struct py_type {
METHOD(typon::Task<std::tuple<PyObj<py_type> COMMA() std::string>>, accept, (Self self), {
int connfd = co_await typon::io::accept(self->fd, NULL, NULL);
if (connfd < 0) {
system_error(-connfd, "accept()");
}
co_return std::make_tuple(pyobj<type>(connfd), std::string("")); // TODO
co_return std::make_tuple(pyobj<py_type>(connfd), std::string("")); // TODO
})
METHOD(typon::Task<void>, close, (Self self),
......@@ -74,9 +74,9 @@ struct socket_t {
}
})
type(int fd = -1) : fd(fd) {}
py_type(int fd = -1) : fd(fd) {}
type(const type &other)
py_type(const py_type &other)
: fd(other.fd) {}
int fd;
......@@ -84,7 +84,7 @@ struct socket_t {
auto operator()(int family, int type_) {
if (int fd = ::socket(family, type_, 0); fd >= 0) {
return pyobj<type>(fd);
return pyobj<py_type>(fd);
} else {
system_error(errno, "socket()");
}
......
......@@ -14,11 +14,14 @@ class int:
def __and__(self, other: Self) -> Self: ...
def __neg__(self) -> Self: ...
def __init__(self, x: str) -> None: ...
def __init__(self, x: object) -> None: ...
def __lt__(self, other: Self) -> bool: ...
def __gt__(self, other: Self) -> bool: ...
def __mod__(self, other: Self) -> Self: ...
def __ge__(self, other: Self) -> bool: ...
class float:
def __init__(self, x: object) -> None: ...
assert int.__add__
U = TypeVar("U")
......@@ -51,6 +54,8 @@ class str:
def __mul__(self, other: int) -> Self: ...
def startswith(self, prefix: Self) -> bool: ...
def __getitem__(self, item: int | slice) -> Self: ...
def isspace(self) -> bool: ...
def __contains__(self, item: Self) -> bool: ...
assert len("a")
......@@ -120,6 +125,8 @@ def print(*args) -> None: ...
assert print
def input(prompt: str = "") -> str:
...
def range(*args) -> Iterator[int]: ...
......@@ -158,4 +165,9 @@ class __test_type:
assert __test_type().test_opt(5)
assert __test_type().test_opt(5, 6)
assert not __test_type().test_opt(5, 6, 7)
assert not __test_type().test_opt()
\ No newline at end of file
assert not __test_type().test_opt()
def exit(code: int | None = None) -> None: ...
class Exception:
def __init__(self, message: str) -> None: ...
\ No newline at end of file
# coding: utf-8
Enum: BuiltinFeature["Enum"]
\ No newline at end of file
import sys
import math
from typing import Callable
# def f(x: Callable[[int], int]):
# return x(5)
x = 4
def f():
if True:
next = 5
x = 7
def g():
nonlocal x
x = 6
return 123
# def h(x):
# y = len(x)
# z = x[4]
if __name__ == "__main__":
a = [n for n in range(10)]
b = [x for x in a if x % 2 == 0]
c = [y * y for y in b]
print(a, b, c)
\ No newline at end of file
#print(f(lambda n: n + 1)) # todo
print(f())
for i in range(10):
if i > 4:
break
else:
print("else")
\ No newline at end of file
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any, Callable, Optional
from enum import Enum
from itertools import groupby
import operator
import string
@dataclass
class BinOperator:
symbol: str
priority: int
perform: Callable[[float, float], float]
symbol: str
priority: int
perform: Callable[[float, float], float]
OPERATORS = [
BinOperator("+", 0, operator.add),
BinOperator("-", 0, operator.sub),
BinOperator("*", 1, operator.mul),
BinOperator("/", 1, operator.truediv)
BinOperator("+", 0, operator.add),
BinOperator("-", 0, operator.sub),
BinOperator("*", 1, operator.mul),
BinOperator("/", 1, operator.truediv)
]
ops_by_priority = [list(it) for _, it in groupby(OPERATORS, lambda op: op.priority)]
# ops_by_priority = [list(it) for _, it in groupby(OPERATORS, lambda op: op.priority)]
ops_by_priority = [
[OPERATORS[0], OPERATORS[1]],
[OPERATORS[2], OPERATORS[3]]
]
MAX_PRIORITY = len(ops_by_priority)
ops_syms = [op.symbol for op in OPERATORS]
class TokenType(Enum):
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
@dataclass
class Token:
type: TokenType
val: Any
type: TokenType
val: str
num: float
def tokenize(inp: str):
tokens = []
tokens = []
index = 0
def skip_spaces():
nonlocal index
while inp[index].isspace():
index += 1
def has():
return index < len(inp)
def peek():
return inp[index]
def read():
nonlocal index
index += 1
return inp[index - 1]
def read_number():
res = ""
while True:
res += read()
if not has() or peek() not in "0123456789.":
break
index = 0
def skip_spaces():
nonlocal index
while inp[index].isspace():
index += 1
def has():
return index < len(inp)
def peek():
return inp[index]
return Token(TokenType.NUMBER, res, float(res)) # if "." in res else int(res))
def read():
nonlocal index
index += 1
return inp[index - 1]
while has():
skip_spaces()
def read_number():
res = ""
next = peek()
while True:
res += read()
if not has() or peek() not in "0123456789.":
break
tok: Token
return Token(TokenType.NUMBER, float(res) if "." in res else int(res))
if next in ops_syms:
tok = Token(TokenType.OPERATION, read(), 0)
elif next in "()":
tok = Token(TokenType.PARENTHESIS, read(), 0)
elif next in "0123456789.":
tok = read_number()
else:
raise Exception("invalid character '{}' at {}".format(next, index))
while has():
skip_spaces()
tokens.append(tok)
next = peek()
return tokens
if next in ops_syms:
tok = Token(TokenType.OPERATION, read())
elif next in "()":
tok = Token(TokenType.PARENTHESIS, read())
elif next in "0123456789.":
tok = read_number()
else:
raise Exception(f"invalid character '{next}'", index)
tokens.append(tok)
def parse(tokens: list[Token]):
index = 0
return tokens
def has():
return index < len(tokens)
def current():
if not has():
raise Exception("expected token, got EOL")
return tokens[index]
def parse(tokens):
index = 0
def match(type: TokenType, val: Optional[str] = None):
return has() and tokens[index].type == type and (val is None or tokens[index].val == val)
def has():
return index < len(tokens)
def accept(type: TokenType, val: Optional[str] = None):
nonlocal index
if match(type, val):
index += 1
return True
return False
def current():
if not has():
raise Exception("expected token, got EOL")
return tokens[index]
def expect(type: TokenType, val: Optional[str] = None):
nonlocal index
if match(type, val):
index += 1
return tokens[index - 1]
if not has():
raise Exception("expected {}, got EOL".format(type))
else:
raise Exception("expected {}, got {}".format(type, current().type))
def match(type: TokenType, val: Any = None):
return has() and tokens[index].type == type and (val is None or tokens[index].val == val)
parse_term: Callable[[], None]
def accept(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return True
return False
def parse_bin(priority=0):
if priority >= MAX_PRIORITY:
return parse_term()
def expect(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return tokens[index - 1]
if not has():
raise Exception(f"expected {type}, got EOL")
else:
raise Exception(f"expected {type}, got {current().type}")
left = parse_bin(priority + 1)
ops = ops_by_priority[priority]
def parse_bin(priority=0):
if priority >= MAX_PRIORITY:
return parse_term()
left = parse_bin(priority + 1)
ops = ops_by_priority[priority]
while has() and current().type == TokenType.OPERATION:
for op in ops:
if accept(TokenType.OPERATION, op.symbol):
right = parse_bin(priority + 1)
left = op.perform(left, right)
break
else:
break
while has() and current().type == TokenType.OPERATION:
for op in ops:
if accept(TokenType.OPERATION, op.symbol):
right = parse_bin(priority + 1)
left = op.perform(left, right)
break
else:
break
return left
return left
def parse_expr():
return parse_bin()
def parse_term():
token = current()
def parse_term():
token = current()
if token.type == TokenType.NUMBER:
return expect(TokenType.NUMBER).val
elif accept(TokenType.PARENTHESIS, "("):
val = parse_expr()
expect(TokenType.PARENTHESIS, ")")
return val
else:
raise Exception(f"expected term, got {token.type}")
if token.type == TokenType.NUMBER:
return expect(TokenType.NUMBER).num
elif accept(TokenType.PARENTHESIS, "("):
val = parse_expr()
expect(TokenType.PARENTHESIS, ")")
return val
else:
raise Exception("expected term, got {}".format(token.type))
def parse_expr():
return parse_bin()
return parse_expr()
return parse_expr()
if __name__ == "__main__":
while True:
inp = input("> ")
try:
tok = tokenize(inp)
res = parse(tok)
print(res)
except Exception as e:
print(e)
print()
while True:
# inp = input("> ")
inp = "2 + 3 * 4"
try:
tok = tokenize(inp)
res = parse(tok)
print(res)
except Exception as e:
# print(e)
pass
print()
break
......@@ -42,7 +42,10 @@ if __name__ == "__main__":
sum = 0
for i in range(15):
sum += i
a = [n for n in range(10)]
b = [x for x in a if x % 2 == 0]
c = [y * y for y in b]
print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
3j, sum)
3j, sum, a, b, c)
print()
......@@ -9,7 +9,7 @@ from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind, TY_STR, UserType, \
TypeType, TypeOperator, TY_FLOAT
TypeType, TypeOperator, TY_FLOAT, FunctionType
from transpiler.utils import UnsupportedNodeError, highlight
......@@ -70,6 +70,12 @@ class NodeVisitor(UniversalVisitor):
yield f"PyObj<decltype({node.name})>"
elif isinstance(node, TypeType):
yield "auto" # TODO
elif isinstance(node, FunctionType):
yield "std::function<"
yield from self.visit(node.return_type)
yield "("
yield from join(", ", map(self.visit, node.parameters))
yield ")>"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
......@@ -88,7 +94,7 @@ class NodeVisitor(UniversalVisitor):
yield from self.visit(node.return_type)
yield ">"
elif isinstance(node, TypeVariable):
#yield f"TYPEVAR_{node.name}";return
# yield f"TYPEVAR_{node.name}";return
raise UnresolvedTypeVariableError(node)
elif isinstance(node, TypeOperator):
yield "typon::Py" + node.name.title()
......
......@@ -13,17 +13,17 @@ class ClassVisitor(NodeVisitor):
yield f"extern {node.name}_s {node.name};"
yield f"struct {node.name}_s {{"
yield "struct type {"
yield "struct py_type {"
inner = ClassInnerVisitor(node.inner_scope)
for stmt in node.body:
yield from inner.visit(stmt)
yield "template<typename... T> type(T&&... args) {"
yield "template<typename... T> py_type(T&&... args) {"
yield "__init__(this, std::forward<T>(args)...);"
yield "}"
yield "type() {}"
yield "type(const type&) = delete;"
yield "type(type&&) = delete;"
yield "py_type() {}"
yield "py_type(const py_type&) = delete;"
yield "py_type(py_type&&) = delete;"
yield "void py_repr(std::ostream &s) const {"
yield "s << '{';"
......@@ -42,7 +42,7 @@ class ClassVisitor(NodeVisitor):
yield "};"
yield "template<typename... T> auto operator()(T&&... args) {"
yield "return pyobj<type>(std::forward<T>(args)...);"
yield "return pyobj<py_type>(std::forward<T>(args)...);"
yield "}"
# outer = ClassOuterVisitor(node.inner_scope)
......@@ -61,6 +61,11 @@ class ClassInnerVisitor(NodeVisitor):
yield node.target.id
yield ";"
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
yield "static constexpr"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_Assign(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield "struct {"
# yield "type* self;"
......
......@@ -74,6 +74,7 @@ PRECEDENCE_LEVELS = {op: i for i, ops in enumerate(PRECEDENCE) for op in ops}
MAPPINGS = {
"True": "true",
"False": "false",
"None": "nullptr"
"None": "nullptr",
"operator": "operator_",
}
"""Mapping of Python builtin constants to C++ equivalents."""
......@@ -3,7 +3,7 @@ import ast
from dataclasses import dataclass, field
from typing import List, Iterable
from transpiler.phases.typing.types import UserType, FunctionType, Promise
from transpiler.phases.typing.types import UserType, FunctionType, Promise, TypeType
from transpiler.phases.utils import make_lnd
from transpiler.utils import compare_ast, linenodata
from transpiler.phases.emit_cpp.consts import SYMBOLS, PRECEDENCE_LEVELS, DUNDER_SYMBOLS
......@@ -215,11 +215,16 @@ class ExpressionVisitor(NodeVisitor):
yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
if isinstance(node.type, FunctionType) and not isinstance(node.value.type, Promise):
use_dot = None
if type(node.value.type) == TypeType:
use_dot = "dots"
elif isinstance(node.type, FunctionType) and not isinstance(node.value.type, Promise):
if node.value.type.resolve().is_reference:
yield "dotp"
use_dot = "dotp"
else:
yield "dot"
use_dot = "dot"
if use_dot:
yield use_dot
yield "(("
yield from self.visit(node.value)
yield "), "
......
......@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Iterable
from transpiler.phases.emit_cpp.consts import SYMBOLS
from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.typing.scope import Scope
from transpiler.phases.utils import PlainBlock
......
......@@ -12,6 +12,7 @@ from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor
from transpiler.utils import compare_ast, highlight
IGNORED_IMPORTS = {"typon", "typing", "__future__", "dataclasses", "enum"}
# noinspection PyPep8Naming
@dataclass
......@@ -20,7 +21,7 @@ class ModuleVisitor(BlockVisitor):
def visit_Import(self, node: ast.Import) -> Iterable[str]:
TB = f"emitting C++ code for {highlight(node)}"
for alias in node.names:
concrete = alias.asname or alias.name
concrete = self.fix_name(alias.asname or alias.name)
if alias.module_obj.is_python:
yield f"namespace py_{concrete} {{"
yield f"struct {concrete}_t {{"
......@@ -33,7 +34,7 @@ class ModuleVisitor(BlockVisitor):
yield f"auto& get_all() {{ return all; }}"
yield "}"
yield f'auto& {concrete} = py_{concrete}::get_all();'
elif alias.name in {"typon", "typing", "__future__"}:
elif alias.name in IGNORED_IMPORTS:
yield ""
else:
yield from self.import_module(alias.name)
......@@ -70,7 +71,7 @@ class ModuleVisitor(BlockVisitor):
yield f"}} {alias};"
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module in {"typon", "typing", "__future__"}:
if node.module in IGNORED_IMPORTS:
yield ""
elif node.module_obj.is_python:
for alias in node.names:
......
......@@ -6,7 +6,7 @@ from dataclasses import dataclass
from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next, is_builtin
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
......@@ -19,9 +19,6 @@ from transpiler.phases.utils import PlainBlock, AnnotationName
class ScoperBlockVisitor(ScoperVisitor):
stdlib: bool = False
def expr(self) -> ScoperExprVisitor:
return ScoperExprVisitor(self.scope, self.root_decls)
def visit_Pass(self, node: ast.Pass):
pass
......@@ -107,7 +104,7 @@ class ScoperBlockVisitor(ScoperVisitor):
if target.id == "_":
return False
target.type = decl_val
if vdecl := self.scope.get(target.id):
if vdecl := self.scope.get(target.id, {VarKind.LOCAL, VarKind.GLOBAL, VarKind.NONLOCAL}, restrict_function=True):
TB = f"unifying existing variable {highlight(target.id)} of type {highlight(vdecl.type)} with assigned value {highlight(decl_val)}"
vdecl.type.unify(decl_val)
return False
......@@ -115,6 +112,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.vars[target.id] = VarDecl(VarKind.LOCAL, decl_val)
if self.scope.kind == ScopeKind.FUNCTION_INNER:
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return False
return True
elif isinstance(target, ast.Tuple):
if not isinstance(decl_val, TupleType):
......@@ -180,7 +178,7 @@ class ScoperBlockVisitor(ScoperVisitor):
visitor.visit_block(node.body)
for deco in node.decorator_list:
deco = self.expr().visit(deco)
if isinstance(deco, BuiltinFeature) and deco.val == "dataclass":
if is_builtin(deco, "dataclass"):
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
# cttype.methods["__init__"] = init_type
lnd = linenodata(node)
......@@ -210,6 +208,13 @@ class ScoperBlockVisitor(ScoperVisitor):
visitor.visit_function_definition(init_method, rtype)
else:
raise NotImplementedError(deco)
for base in node.bases:
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
for k in ctype.members:
ctype.members[k] = ctype
else:
raise NotImplementedError(base)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -270,8 +275,8 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_Return(self, node: ast.Return):
fct = self.scope.function
if fct is None:
from transpiler.phases.typing.exceptions import ReturnOutsideFunctionError
raise ReturnOutsideFunctionError()
from transpiler.phases.typing.exceptions import OutsideFunctionError
raise OutsideFunctionError()
ftype = fct.obj_type
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else TY_NONE
......@@ -284,6 +289,16 @@ class ScoperBlockVisitor(ScoperVisitor):
if name not in self.scope.global_scope.vars:
self.scope.global_scope.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_Nonlocal(self, node: ast.Global):
fct = self.scope.function
if fct is None:
from transpiler.phases.typing.exceptions import OutsideFunctionError
raise OutsideFunctionError()
for name in node.names:
fct.vars[name] = VarDecl(VarKind.NONLOCAL, None)
if name not in fct.parent.vars:
fct.parent.vars[name] = VarDecl(VarKind.LOCAL, None)
def visit_AugAssign(self, node: ast.AugAssign):
target, value = map(self.get_type, (node.target, node.value))
try:
......
......@@ -17,6 +17,14 @@ class ScoperClassVisitor(ScoperVisitor):
assert isinstance(node.target, ast.Name)
self.scope.obj_type.members[node.target.id] = self.visit_annotation(node.annotation)
def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Class field should be assigned to only once"
assert isinstance(node.targets[0], ast.Name)
node.is_declare = True
valtype = self.expr().visit(node.value)
node.targets[0].type = valtype
self.scope.obj_type.members[node.targets[0].id] = valtype
def visit_FunctionDef(self, node: ast.FunctionDef):
from transpiler.phases.typing.block import ScoperBlockVisitor
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef
......
......@@ -271,9 +271,9 @@ class NotIteratorError(CompileError):
"""
@dataclass
class ReturnOutsideFunctionError(CompileError):
class OutsideFunctionError(CompileError):
def __str__(self) -> str:
return f"{highlight('return')} cannot be used outside of a function"
return f"{highlight('return')} and {highlight('nonlocal')} cannot be used outside of a function"
def detail(self, last_node: ast.AST = None) -> str:
return ""
......
......@@ -105,6 +105,11 @@ class ScoperExprVisitor(ScoperVisitor):
ftype = self.visit(node.func)
if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
from transpiler.exceptions import CompileError
try:
argtypes = [self.visit(arg) for arg in node.args]
except CompileError as e:
pass
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
......@@ -193,14 +198,20 @@ class ScoperExprVisitor(ScoperVisitor):
return self.visit_getattr(p, name)
except MissingAttributeError as e:
pass
# class MemberProtocol(TypeOperator):
# pass
raise MissingAttributeError(ltype, name)
def visit_List(self, node: ast.List) -> BaseType:
if not node.elts:
return PyList(TypeVariable())
elems = [self.visit(e) for e in node.elts]
if len(set(elems)) != 1:
raise NotImplementedError("List with different types not handled yet")
first, *rest = elems
for e in rest:
try:
first.unify(e)
except:
raise NotImplementedError(f"List with different types not handled yet: {', '.join(map(str, elems))}")
return PyList(elems[0])
def visit_Set(self, node: ast.Set) -> BaseType:
......
......@@ -73,6 +73,8 @@ class Scope:
def child(self, kind: ScopeKind):
res = Scope(self, kind, self.function, self.global_scope)
if kind == ScopeKind.GLOBAL:
res.global_scope = res
self.children.append(res)
return res
......@@ -80,12 +82,18 @@ class Scope:
"""Declares a local variable"""
self.vars[name] = VarDecl(VarKind.LOCAL, type)
def get(self, name: str, kind: VarKind = VarKind.LOCAL) -> Optional[VarDecl]:
def get(self, name: str, kind: VarKind | set[VarKind] = VarKind.LOCAL, restrict_function: bool = False) -> Optional[VarDecl]:
"""
Gets the variable declaration of a variable in the current scope or any parent scope.
"""
if (res := self.vars.get(name)) and res.kind == kind:
if type(kind) is VarKind:
kind = {kind}
if (res := self.vars.get(name)) and res.kind in kind:
if res.kind == VarKind.GLOBAL:
return self.global_scope.get(name, kind)
elif res.kind == VarKind.NONLOCAL:
return self.function.parent.get(name, VarKind.LOCAL, True)
return res
if self.parent is not None:
return self.parent.get(name, kind)
if self.parent is not None and not (self.kind == ScopeKind.FUNCTION and restrict_function):
return self.parent.get(name, kind, restrict_function)
return None
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment