Commit 4aa55569 authored by Tom Niget's avatar Tom Niget

Add a few helper functions, rework concepts

parent 73735bc8
......@@ -13,20 +13,41 @@
using namespace std::literals;
template<typename T>
concept Streamable = requires(const T &x, std::ostream &s) {
{ s << x } -> std::same_as<std::ostream &>;
};
template<Streamable T>
void print_to(const T &x, std::ostream &s) {
s << x;
}
template<typename T>
concept Printable = requires(const T &x, std::ostream &s) {
concept FunctionPointer = std::is_function_v<T>
or std::is_member_function_pointer_v<T>
or std::is_function_v<std::remove_pointer_t<T>>;
template<Streamable T>
requires (FunctionPointer<T>)
void print_to(const T &x, std::ostream &s) {
s << "<function at 0x" << std::hex << (size_t) x << ">";
}
template<typename T>
concept PyPrint = requires(const T &x, std::ostream &s) {
{ x.py_print(s) } -> std::same_as<void>;
};
template<Printable T>
template<PyPrint T>
void print_to(const T &x, std::ostream &s) {
x.py_print(s);
}
template<typename T>
concept Printable = requires(const T &x, std::ostream &s) {
{ print_to(x, s) } -> std::same_as<void>;
};
template<typename T>
concept PyIterator = requires(T t) {
{ t.py_next() } -> std::same_as<std::optional<T>>;
......@@ -56,13 +77,6 @@ void print() {
std::cout << '\n';
}
template<typename T, typename ... Args>
void print(T const &head, Args const &... args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
bool is_cpp() {
return true;
}
......@@ -74,4 +88,11 @@ bool is_cpp() {
#include "builtins/set.hpp"
#include "builtins/str.hpp"
template<Printable T, Printable ... Args>
void print(T const &head, Args const &... args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
#endif //TYPON_BUILTINS_HPP
......@@ -8,7 +8,7 @@
#include <ostream>
template<>
void print_to(const bool &x, std::ostream &s) {
void print_to<bool>(const bool &x, std::ostream &s) {
s << (x ? "True" : "False");
}
......
......@@ -57,4 +57,9 @@ public:
}
};
template<typename T>
PyList<T> list(std::initializer_list<T> &&v) {
return PyList<T>(std::move(v));
}
#endif //TYPON_LIST_HPP
......@@ -4,6 +4,7 @@
#ifndef TYPON_SET_HPP
#define TYPON_SET_HPP
#include <unordered_set>
template<typename T>
......@@ -83,4 +84,9 @@ public:
}
};
template<typename T>
PySet<T> set(std::initializer_list<T> &&s) {
return PySet<T>(std::move(s));
}
#endif //TYPON_SET_HPP
//
// Created by Tom on 09/03/2023.
//
#ifndef TYPON_SYS_HPP
#define TYPON_SYS_HPP
#include <iostream>
struct sys_t {
static constexpr auto& stdout = std::cout;
} sys;
#endif //TYPON_SYS_HPP
# coding: utf-8
from typon import is_cpp
import sys
test = (2 + 3) * 4
glob = 5
def g():
if True:
if True:
if True:
x = 5
print(x)
def f(x):
return x + 1
......@@ -14,13 +22,15 @@ def fct(param):
global glob
loc = 789
glob = 123
a = 5
b = 6
z = f(a + b) * 2
def fct2():
global glob
glob += 5
if __name__ == "__main__":
# todo: 0x55 & 7 == 5
print(is_cpp)
print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
2 + 3j)
......
......@@ -73,6 +73,8 @@ SYMBOLS = {
"""Mapping of Python AST nodes to C++ symbols."""
PRECEDENCE = [
("()", "[]", ".",),
("unary",),
("*", "/", "%",),
("+", "-"),
("<<", ">>"),
......@@ -83,6 +85,8 @@ PRECEDENCE = [
("|",),
("&&",),
("||",),
("?:",),
(",",)
]
"""Precedence of C++ operators."""
......@@ -136,12 +140,46 @@ class NodeVisitor:
return MAPPINGS.get(name, name)
class PrecedenceContext:
def __init__(self, visitor: "ExpressionVisitor", op: str):
self.visitor = visitor
self.op = op
def __enter__(self):
if self.visitor.precedence[-1:] != [self.op]:
self.visitor.precedence.append(self.op)
def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop()
# noinspection PyPep8Naming
class ExpressionVisitor(NodeVisitor):
def __init__(self, precedence: Optional[int] = None):
self._precedence = precedence
def __init__(self, precedence=None):
self.precedence = precedence or []
def prec_ctx(self, op: str) -> PrecedenceContext:
"""
Creates a context manager that sets the precedence of the next expression.
"""
return PrecedenceContext(self, op)
def prec(self, op: str) -> "ExpressionVisitor":
"""
Sets the precedence of the next expression.
"""
return ExpressionVisitor([op])
def reset(self) -> "ExpressionVisitor":
"""
Resets the precedence stack.
"""
return ExpressionVisitor()
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield f"std::make_tuple({', '.join(flatmap(self.visit, node.elts))})"
yield "std::make_tuple("
yield from join(", ", map(self.visit, node.elts))
yield ")"
def visit_Constant(self, node: ast.Constant) -> Iterable[str]:
if isinstance(node.value, str):
......@@ -161,13 +199,13 @@ class ExpressionVisitor(NodeVisitor):
yield self.fix_name(node.id)
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# TODO: operator precedence
operands = [node.left, *node.comparators]
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1])
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code
yield " && "
yield from self.visit_binary_operation(op, left, right)
with self.prec_ctx("&&"):
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1])
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code
yield " && "
yield from self.visit_binary_operation(op, left, right)
def visit_Call(self, node: ast.Call) -> Iterable[str]:
if getattr(node, "keywords", None):
......@@ -176,9 +214,9 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
yield from self.visit(node.func)
yield from self.prec("()").visit(node.func)
yield "("
yield from join(", ", map(self.visit, node.args))
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
......@@ -188,7 +226,7 @@ class ExpressionVisitor(NodeVisitor):
yield args
yield "{"
yield "return"
yield from self.visit(node.body)
yield from self.reset().visit(node.body)
yield ";"
yield "}"
......@@ -196,39 +234,38 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit_binary_operation(node.op, node.left, node.right)
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]:
# TODO: precedence
op = SYMBOLS[type(op)]
inner = ExpressionVisitor(PRECEDENCE_LEVELS[op])
prio = self._precedence is not None and inner._precedence > self._precedence
prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
if prio:
yield "("
yield from inner.visit(left)
yield op
yield from inner.visit(right)
with self.prec_ctx(op):
yield from self.visit(left)
yield op
yield from self.visit(right)
if prio:
yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.visit(node.value)
yield from self.prec(".").visit(node.value)
yield "."
yield node.attr
def visit_List(self, node: ast.List) -> Iterable[str]:
yield "PyList{"
yield from join(", ", map(self.visit, node.elts))
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
def visit_Set(self, node: ast.Set) -> Iterable[str]:
yield "PySet{"
yield from join(", ", map(self.visit, node.elts))
yield from join(", ", map(self.reset().visit, node.elts))
yield "}"
def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value):
yield "std::pair {"
yield from self.visit(key)
yield from self.reset().visit(key)
yield ", "
yield from self.visit(value)
yield from self.reset().visit(value)
yield "}"
yield "PyDict{"
......@@ -236,21 +273,22 @@ class ExpressionVisitor(NodeVisitor):
yield "}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
yield from self.visit(node.value)
yield from self.prec("[]").visit(node.value)
yield "["
yield from self.visit(node.slice)
yield from self.reset().visit(node.slice)
yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit(node.op)
yield from self.visit(node.operand)
yield from self.prec("unary").visit(node.operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
yield from self.visit(node.test)
yield " ? "
yield from self.visit(node.body)
yield " : "
yield from self.visit(node.orelse)
with self.prec_ctx("?:"):
yield from self.visit(node.test)
yield " ? "
yield from self.visit(node.body)
yield " : "
yield from self.visit(node.orelse)
@dataclass
......@@ -265,6 +303,7 @@ class Scope:
return name in self.vars or (self.parent is not None and self.parent.exists(name))
# noinspection PyPep8Naming
class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope):
self._scope = scope
......@@ -280,11 +319,12 @@ class BlockVisitor(NodeVisitor):
yield ";"
def visit_Import(self, node: ast.Import) -> Iterable[str]:
for name in node.names:
if name == "typon":
for alias in node.names:
if alias.name == "typon":
yield ""
else:
raise NotImplementedError(node)
yield f'#include "python/{alias.name}.hpp"'
#raise NotImplementedError(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon":
......@@ -377,6 +417,20 @@ class BlockVisitor(NodeVisitor):
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
if node.value is None:
raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target)
yield " = "
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "="
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name):
raise NotImplementedError(node)
......
......@@ -7,4 +7,8 @@ clang = clang_format._get_executable("clang-format") # noqa
def format_code(code: str) -> str:
return subprocess.check_output([clang, "-style=LLVM"], input=code.encode("utf-8")).decode("utf-8")
return subprocess.check_output([
clang,
"--style=LLVM",
"--assume-filename=main.cpp"
], input=code.encode("utf-8")).decode("utf-8")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment