Commit 99a533ea authored by Xavier Thompson's avatar Xavier Thompson

Ensure cypclass references in templated code are correctly refcounted using smart pointers

parent d59a9c3a
......@@ -5490,4 +5490,7 @@ def template_parameter_code(T, for_display = 0, pyrex = 0):
"""
Return the code string for a template parameter in a template instanciation.
"""
return T.declaration_code('', for_display, None, pyrex)
if T.is_cyp_class and not for_display:
return "Cy_Ref<%s>" % T.empty_declaration_code()
else:
return T.declaration_code('', for_display, None, pyrex)
......@@ -45,48 +45,14 @@ cdef extern from * nogil:
public:
using iterator = iterator_t;
friend void swap(dict_view & first, dict_view & second)
{
using std::swap;
swap(first.urange, second.urange);
}
dict_view() = default;
dict_view(dict_view const & rhs) : urange(rhs.urange)
{
if (urange != NULL)
{
urange->CyObject_INCREF();
}
}
dict_view(const dict_view & rhs) = default;
dict_view(dict_view && rhs) = default;
dict_view & operator=(const dict_view& rhs) = default;
dict_view & operator=(dict_view&& rhs) = default;
~dict_view() = default;
dict_view(dict_view && rhs) : dict_view()
{
swap(*this, rhs);
}
dict_view & operator=(dict_view rhs)
{
swap(*this, rhs);
return *this;
}
~dict_view()
{
if (urange != NULL)
{
urange->CyObject_DECREF();
urange = NULL;
}
}
dict_view(dict_t urange) : urange(urange)
{
if (urange != NULL)
{
urange->CyObject_INCREF();
}
}
dict_view(dict_t urange) : urange(urange) {}
iterator begin() const
{
......@@ -168,11 +134,6 @@ cdef cypclass cypdict[K, V]:
__init__(self):
self._active_iterators.store(0)
__dealloc__(self):
for item in self._items:
Cy_DECREF(item.first)
Cy_DECREF(item.second)
V __getitem__(self, const key_type key) except ~ const:
it = self._indices.find(key)
if it != self._indices.end():
......@@ -184,12 +145,8 @@ cdef cypclass cypdict[K, V]:
it = self._indices.find(key)
if it != self._indices.end():
index = dereference(it).second
Cy_INCREF(value)
Cy_DECREF(self._items[index].second)
self._items[index].second = value
elif self._active_iterators == 0:
Cy_INCREF(key)
Cy_INCREF(value)
self._indices[key] = self._items.size()
self._items.push_back(item_type(key, value))
else:
......@@ -205,8 +162,6 @@ cdef cypclass cypdict[K, V]:
with gil:
raise RuntimeError("Modifying a dictionary with active iterators")
index = dereference(it).second
Cy_DECREF(self._items[index].first)
Cy_DECREF(self._items[index].second)
self._indices.erase(it)
if index < self._items.size() - 1:
self._items[index] = self._items[self._indices.size() - 1]
......@@ -218,9 +173,6 @@ cdef cypclass cypdict[K, V]:
void clear(self) except ~:
if self._active_iterators == 0:
for item in self._items:
Cy_DECREF(item.first)
Cy_DECREF(item.second)
self._items.clear()
self._indices.clear()
else:
......
......@@ -20,7 +20,6 @@ cdef extern from * nogil:
{
if (urange != NULL)
{
urange->CyObject_INCREF();
urange->_active_iterators++;
}
}
......@@ -41,7 +40,6 @@ cdef extern from * nogil:
swap(static_cast<base&>(*this), rhs);
if (urange != NULL) {
urange->_active_iterators--;
urange->CyObject_DECREF();
urange = NULL;
}
return *this;
......@@ -51,7 +49,6 @@ cdef extern from * nogil:
{
if (urange != NULL) {
urange->_active_iterators--;
urange->CyObject_DECREF();
urange = NULL;
}
}
......@@ -61,7 +58,6 @@ cdef extern from * nogil:
cy_iterator_t(base const & b, urng_t urange) : base{b}, urange{urange}
{
if (urange != NULL) {
urange->CyObject_INCREF();
urange->_active_iterators++;
}
}
......
......@@ -34,10 +34,6 @@ cdef cypclass cyplist[V]:
__init__(self):
self._active_iterators.store(0)
__dealloc__(self):
for value in self._elements:
Cy_DECREF(value)
V __getitem__(self, const size_type index) except ~ const:
if index < self._elements.size():
return self._elements[index]
......@@ -47,8 +43,6 @@ cdef cypclass cyplist[V]:
void __setitem__(self, size_type index, const value_type value) except ~:
if index < self._elements.size():
Cy_INCREF(value)
Cy_DECREF(self._elements[index])
self._elements[index] = value
else:
with gil:
......@@ -58,7 +52,6 @@ cdef cypclass cyplist[V]:
if index < self._elements.size():
if self._active_iterators == 0:
it = self._elements.begin() + index
Cy_DECREF(dereference(it))
self._elements.erase(it)
else:
with gil:
......@@ -69,7 +62,6 @@ cdef cypclass cyplist[V]:
void append(self, const value_type value) except ~:
if self._active_iterators == 0:
Cy_INCREF(value)
self._elements.push_back(value)
else:
with gil:
......@@ -79,7 +71,6 @@ cdef cypclass cyplist[V]:
if self._active_iterators == 0:
if index <= self._elements.size():
it = self._elements.begin() + index
Cy_INCREF(value)
self._elements.insert(it, value)
else:
with gil:
......@@ -90,8 +81,6 @@ cdef cypclass cyplist[V]:
void clear(self) except ~:
if self._active_iterators == 0:
for value in self._elements:
Cy_DECREF(value)
self._elements.clear()
else:
with gil:
......@@ -100,20 +89,13 @@ cdef cypclass cyplist[V]:
cyplist[V] __add__(self, const cyplist[V] other) const:
result = cyplist[V]()
result._elements.reserve(self._elements.size() + other._elements.size())
for value in self._elements:
Cy_INCREF(value)
result._elements.push_back(value)
for value in other._elements:
Cy_INCREF(value)
result._elements.push_back(value)
result._elements.insert(result._elements.end(), self._elements.begin(), self._elements.end())
result._elements.insert(result._elements.end(), other._elements.begin(), other._elements.end())
return result
cyplist[V] __iadd__(self, const cyplist[V] other):
if self._active_iterators == 0:
self._elements.reserve(self._elements.size() + other._elements.size())
for value in other._elements:
Cy_INCREF(value)
self._elements.push_back(value)
self._elements.insert(self._elements.end(), other._elements.begin(), other._elements.end())
return self
else:
with gil:
......@@ -123,9 +105,7 @@ cdef cypclass cyplist[V]:
result = cyplist[V]()
result._elements.reserve(self._elements.size() * n)
for i in range(n):
for value in self._elements:
Cy_INCREF(value)
result._elements.push_back(value)
result._elements.insert(result._elements.end(), self._elements.begin(), self._elements.end())
return result
cyplist[V] __imul__(self, size_type n):
......@@ -134,15 +114,11 @@ cdef cypclass cyplist[V]:
elements = self._elements
self._elements.reserve(elements.size() * n)
for i in range(1, n):
for value in elements:
Cy_INCREF(value)
self._elements.push_back(value)
self._elements.insert(self._elements.end(), elements.begin(), elements.end())
return self
elif n == 1:
return self
else:
for value in self._elements:
Cy_DECREF(value)
self._elements.clear()
return self
else:
......
......@@ -77,31 +77,28 @@ template <typename T, bool = std::is_void<T>::value>
struct CheckedResult {};
template <typename T>
class CheckedResult<T, false> {
enum Status { Ok, Err };
private:
enum Status status;
public:
T result;
CheckedResult(const T& value) : status(Ok), result(value) {}
CheckedResult() : status(Err) {}
operator T() { return result; }
void set_error() { status = Err; }
bool is_error() { return status == Err; }
struct CheckedResult<T, false> {
bool error;
T result;
CheckedResult() : error(true) {}
CheckedResult(const T& value) : error(false), result(value) {}
template<typename U, typename std::enable_if<std::is_convertible<U, T>::value, int>::type = 0>
CheckedResult(const CheckedResult<U>& rhs) : error(rhs.error), result(rhs.result) {}
// CheckedResult& operator=(const T& value) {
// result = value;
// status = Ok;
// return *this;
// }
operator T() { return result; }
void set_error() { error = true; }
bool is_error() { return error; }
};
template <typename T>
class CheckedResult<T, true> {
enum Status { Ok, Err };
public:
CheckedResult() : status(Ok) {}
void set_error() { status = Err; }
bool is_error() { return status == Err; }
private:
enum Status status;
struct CheckedResult<T, true> {
bool error;
CheckedResult() : error(false) {}
void set_error() { error = true; }
bool is_error() { return error; }
};
......@@ -88,6 +88,152 @@
int CyObject_TRYWLOCK();
};
template <typename T>
struct Cy_Ref_impl {
T* uobj = nullptr;
constexpr Cy_Ref_impl() noexcept = default;
// constexpr Cy_Ref_impl(std::nullptr_t null) noexcept : uobj(null) {}
Cy_Ref_impl(T* uobj) : uobj(uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
Cy_Ref_impl(const Cy_Ref_impl& rhs) : uobj(rhs.uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
template<typename U, typename std::enable_if<std::is_convertible<U*, T*>::value, int>::type = 0>
Cy_Ref_impl(const Cy_Ref_impl<U>& rhs) : uobj(rhs.uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
Cy_Ref_impl(Cy_Ref_impl&& rhs) noexcept : uobj(rhs.uobj) {
rhs.uobj = nullptr;
}
template<typename U, typename std::enable_if<std::is_convertible<U*, T*>::value, int>::type = 0>
Cy_Ref_impl(Cy_Ref_impl<U>&& rhs) noexcept : uobj(rhs.uobj) {
rhs.uobj = nullptr;
}
Cy_Ref_impl& operator=(Cy_Ref_impl rhs) noexcept {
std::swap(uobj, rhs.uobj);
return *this;
}
~Cy_Ref_impl() {
if (uobj != nullptr) {
uobj->CyObject_DECREF();
uobj = nullptr;
}
}
constexpr T& operator*() const noexcept{
return *uobj;
}
constexpr T* operator->() const noexcept {
return uobj;
}
explicit operator bool() const noexcept {
return uobj;
}
constexpr operator T*() const noexcept {
return uobj;
}
template <typename U>
bool operator==(const Cy_Ref_impl<U>& rhs) const noexcept {
return uobj == rhs.uobj;
}
template <typename U>
friend bool operator==(const Cy_Ref_impl<U>& lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs.uobj == rhs.uobj;
}
template <typename U>
bool operator==(U* rhs) const noexcept {
return uobj == rhs;
}
template <typename U>
friend bool operator==(U* lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs == rhs.uobj;
}
bool operator==(std::nullptr_t) const noexcept {
return uobj == nullptr;
}
friend bool operator==(std::nullptr_t, const Cy_Ref_impl<T>& rhs) noexcept {
return rhs.uobj == nullptr;
}
template <typename U>
bool operator!=(const Cy_Ref_impl<U>& rhs) const noexcept {
return uobj != rhs.uobj;
}
template <typename U>
friend bool operator!=(const Cy_Ref_impl<U>& lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs.uobj != rhs.uobj;
}
template <typename U>
bool operator!=(U* rhs) const noexcept {
return uobj != rhs;
}
template <typename U>
friend bool operator!=(U* lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs != rhs.uobj;
}
bool operator!=(std::nullptr_t) const noexcept {
return uobj != nullptr;
}
friend bool operator!=(std::nullptr_t, const Cy_Ref_impl<T>& rhs) noexcept {
return rhs.uobj != nullptr;
}
};
namespace std {
template <typename T>
struct hash<Cy_Ref_impl<T>> {
size_t operator()(const Cy_Ref_impl<T>& ref) const {
return std::hash<T*>()(ref.uobj);
}
};
}
template <typename T, bool = std::is_convertible<T*, CyObject*>::value>
struct Cy_Ref_t {};
template <typename T>
struct Cy_Ref_t<T, true> {
using type = Cy_Ref_impl<T>;
};
template <typename T>
struct Cy_Ref_t<T, false> {
using type = T;
};
template <typename T>
using Cy_Ref = typename Cy_Ref_t<T>::type;
class Cy_rlock_guard {
CyObject* o;
public:
......
......@@ -437,26 +437,26 @@ cdef cypclass DestroyCheckValue(Value):
def test_items_destroyed():
"""
>>> test_items_destroyed()
('destroyed index', 0)
('destroyed value', 0)
('destroyed index', 1)
('destroyed index', 0)
('destroyed value', 1)
('destroyed index', 2)
('destroyed index', 1)
('destroyed value', 2)
('destroyed index', 3)
('destroyed index', 2)
('destroyed value', 3)
('destroyed index', 4)
('destroyed index', 3)
('destroyed value', 4)
('destroyed index', 5)
('destroyed index', 4)
('destroyed value', 5)
('destroyed index', 6)
('destroyed index', 5)
('destroyed value', 6)
('destroyed index', 7)
('destroyed index', 6)
('destroyed value', 7)
('destroyed index', 8)
('destroyed index', 7)
('destroyed value', 8)
('destroyed index', 9)
('destroyed index', 8)
('destroyed value', 9)
('destroyed index', 9)
"""
d = cypdict[DestroyCheckIndex, DestroyCheckValue]()
for i in range(10):
......@@ -475,7 +475,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2:
return -2
d[index] = value
if Cy_GETREF(index) != 3:
if Cy_GETREF(index) != 4:
return -3
if Cy_GETREF(value) != 3:
return -4
......@@ -485,7 +485,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2:
return -6
d[index] = value
if Cy_GETREF(index) != 3:
if Cy_GETREF(index) != 4:
return -7
if Cy_GETREF(value) != 3:
return -8
......@@ -495,7 +495,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2:
return -10
d[index] = value
if Cy_GETREF(index) != 3:
if Cy_GETREF(index) != 4:
return -11
if Cy_GETREF(value) != 3:
return -12
......@@ -522,41 +522,41 @@ def test_update_refcount():
d1[index1] = value1
d2[index2] = value2
d2[index3] = value3
if Cy_GETREF(index1) != 3:
if Cy_GETREF(index1) != 4:
return -1
if Cy_GETREF(value1) != 3:
return -2
if Cy_GETREF(index2) != 3:
if Cy_GETREF(index2) != 4:
return -3
if Cy_GETREF(value2) != 3:
return -4
if Cy_GETREF(index3) != 3:
if Cy_GETREF(index3) != 4:
return -5
if Cy_GETREF(value3) != 3:
return -6
d1.update(d2)
if Cy_GETREF(index1) != 3:
if Cy_GETREF(index1) != 4:
return -7
if Cy_GETREF(value1) != 3:
return -8
if Cy_GETREF(index2) != 4:
if Cy_GETREF(index2) != 6:
return -9
if Cy_GETREF(value2) != 4:
return -10
if Cy_GETREF(index3) != 4:
if Cy_GETREF(index3) != 6:
return -11
if Cy_GETREF(value3) != 4:
return -12
del d2
if Cy_GETREF(index1) != 3:
if Cy_GETREF(index1) != 4:
return -13
if Cy_GETREF(value1) != 3:
return -14
if Cy_GETREF(index2) != 3:
if Cy_GETREF(index2) != 4:
return -15
if Cy_GETREF(value2) != 3:
return -16
if Cy_GETREF(index3) != 3:
if Cy_GETREF(index3) != 4:
return -17
if Cy_GETREF(value3) != 3:
return -18
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment