Commit 551ba3cc authored by Tom Niget's avatar Tom Niget

Fix various things, calcbasic works

parent 759c796e
......@@ -66,12 +66,16 @@ class PyObj : public std::shared_ptr<typename RealType<T>::type> {
public:
using inner = typename RealType<T>::type;
template<typename... Args>
PyObj(Args&&... args) : std::shared_ptr<inner>(std::make_shared<inner>(std::forward<Args>(args)...)) {}
PyObj() : std::shared_ptr<inner>() {}
PyObj(std::nullptr_t) : std::shared_ptr<inner>(nullptr) {}
PyObj(inner *ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(const std::shared_ptr<inner> &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(std::shared_ptr<inner> &&ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(const PyObj &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj( PyObj &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(PyObj &&ptr) : std::shared_ptr<inner>(ptr) {}
PyObj &operator=(const PyObj &ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
PyObj &operator=(PyObj &&ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
......@@ -82,16 +86,14 @@ public:
template<typename U>
PyObj(const PyObj<U> &ptr) : std::shared_ptr<inner>(ptr) {}
template<typename U>
PyObj(PyObj<U> &&ptr) : std::shared_ptr<inner>(ptr) {}
//PyObj(PyObj<U> &&ptr) : std::shared_ptr<inner>(ptr) {}
// using make_shared
template<class U>
PyObj(U&& other) : std::shared_ptr<inner>(std::make_shared<inner>(other)) {}
/*template<class U>
PyObj(U&& other) : std::shared_ptr<inner>(std::make_shared<inner>(other)) {}*/
/*template<typename... Args>
PyObj(Args&&... args) : std::shared_ptr<inner>(std::forward<Args>(args)...) {}*/
......@@ -124,7 +126,7 @@ public:
}
};
template <typename T, typename... Args> auto pyobj(Args &&...args) -> PyObj<T> {
template <typename T, typename... Args> auto pyobj(Args &&...args) -> PyObj<typename RealType<T>::type> {
return std::make_shared<typename RealType<T>::type>(
std::forward<Args>(args)...);
}
......
......@@ -91,7 +91,7 @@ class BlockVisitor(NodeVisitor):
else:
yield from self.visit(argty)
yield arg.arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA} and default:
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = "
yield from self.expr().visit(default)
yield ")"
......
......@@ -34,7 +34,7 @@ class ClassVisitor(NodeVisitor):
else:
yield "void py_repr(std::ostream &s) const {"
yield "s << '{';"
for i, (name, memb) in enumerate(node.type.fields.items()):
for i, (name, memb) in enumerate(node.type.get_members().items()):
if i != 0:
yield 's << ", ";'
yield f's << "\\"{name}\\": ";'
......
......@@ -27,8 +27,9 @@ class ModuleVisitor(BlockVisitor):
yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.fields.items():
if obj.type.python_func_used:
yield from self.emit_python_func(alias.name, name, name, obj.type)
ty = obj.type.resolve()
if getattr(ty, "python_func_used", False):
yield from self.emit_python_func(alias.name, name, name, ty)
yield "} all;"
yield f"auto& get_all() {{ return all; }}"
......
......@@ -32,7 +32,9 @@ class ScoperBlockVisitor(ScoperVisitor):
# copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items():
if callable(obj):
fty = FunctionType([], TypeVariable())
# fty = FunctionType([], TypeVariable())
# fty.is_python_func = True
fty = TypeVariable()
fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope)
......
......@@ -7,7 +7,7 @@ from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE, TY_FLOAT
TY_SLICE, TY_FLOAT, RuntimeValue
from transpiler.utils import linenodata
DUNDER = {
......@@ -92,11 +92,12 @@ class ScoperExprVisitor(ScoperVisitor):
if not obj:
from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node.id)
if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable):
ty = obj.type.resolve()
if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
raise NameError(f"Use of type variable") # todo: when does this happen exactly?
if getattr(obj, "is_python_func", False):
obj.python_func_used = True
return obj.type
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
return ty
def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
for value in node.values:
......@@ -199,12 +200,12 @@ class ScoperExprVisitor(ScoperVisitor):
# else:
# return meth
if field := ltype.fields.get(name):
ty = field.type
ty = field.type.resolve()
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
if isinstance(ty, FunctionType):
ty = ty.gen_sub(ltype, {})
if bound and field.in_class_def:
if bound and field.in_class_def and type(field.val) != RuntimeValue:
return ty.remove_self()
return ty
......
......@@ -111,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
......
......@@ -148,6 +148,18 @@ def next_var_id():
class TypeVariable(BaseType):
name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[BaseType] = None
patch_attrs: dict = field(default_factory=dict)
def __setattr__(self, key, value):
if "patch_attrs" in self.__dict__ and key not in self.__dict__:
self.patch_attrs[key] = value
else:
super().__setattr__(key, value)
def __getattr__(self, item):
if "patch_attrs" in self.__dict__ and item in self.patch_attrs:
return self.patch_attrs[item]
raise AttributeError(item)
def __str__(self):
if self.resolved is None:
......@@ -166,6 +178,8 @@ class TypeVariable(BaseType):
from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError
raise RecursiveTypeUnificationError(self, other)
self.resolved = other
for k, v in self.patch_attrs.items():
setattr(other, k, v)
def contains_internal(self, other: BaseType) -> bool:
return self.resolve() is other.resolve()
......@@ -210,7 +224,7 @@ class TypeOperator(BaseType, ABC):
if self.name is None:
self.name = self.__class__.__name__
for name, factory in self.gen_methods.items():
self.fields[name] = MemberDef(factory(self))
self.fields[name] = MemberDef(factory(self), ())
for gp in self.gen_parents:
if not isinstance(gp, BaseType):
gp = gp(self.args)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment