Commit 99a533ea authored by Xavier Thompson's avatar Xavier Thompson

Ensure cypclass references in templated code are correctly refcounted using smart pointers

parent d59a9c3a
...@@ -5490,4 +5490,7 @@ def template_parameter_code(T, for_display = 0, pyrex = 0): ...@@ -5490,4 +5490,7 @@ def template_parameter_code(T, for_display = 0, pyrex = 0):
""" """
Return the code string for a template parameter in a template instanciation. Return the code string for a template parameter in a template instanciation.
""" """
return T.declaration_code('', for_display, None, pyrex) if T.is_cyp_class and not for_display:
return "Cy_Ref<%s>" % T.empty_declaration_code()
else:
return T.declaration_code('', for_display, None, pyrex)
...@@ -45,48 +45,14 @@ cdef extern from * nogil: ...@@ -45,48 +45,14 @@ cdef extern from * nogil:
public: public:
using iterator = iterator_t; using iterator = iterator_t;
friend void swap(dict_view & first, dict_view & second)
{
using std::swap;
swap(first.urange, second.urange);
}
dict_view() = default; dict_view() = default;
dict_view(dict_view const & rhs) : urange(rhs.urange) dict_view(const dict_view & rhs) = default;
{ dict_view(dict_view && rhs) = default;
if (urange != NULL) dict_view & operator=(const dict_view& rhs) = default;
{ dict_view & operator=(dict_view&& rhs) = default;
urange->CyObject_INCREF(); ~dict_view() = default;
}
}
dict_view(dict_view && rhs) : dict_view() dict_view(dict_t urange) : urange(urange) {}
{
swap(*this, rhs);
}
dict_view & operator=(dict_view rhs)
{
swap(*this, rhs);
return *this;
}
~dict_view()
{
if (urange != NULL)
{
urange->CyObject_DECREF();
urange = NULL;
}
}
dict_view(dict_t urange) : urange(urange)
{
if (urange != NULL)
{
urange->CyObject_INCREF();
}
}
iterator begin() const iterator begin() const
{ {
...@@ -168,11 +134,6 @@ cdef cypclass cypdict[K, V]: ...@@ -168,11 +134,6 @@ cdef cypclass cypdict[K, V]:
__init__(self): __init__(self):
self._active_iterators.store(0) self._active_iterators.store(0)
__dealloc__(self):
for item in self._items:
Cy_DECREF(item.first)
Cy_DECREF(item.second)
V __getitem__(self, const key_type key) except ~ const: V __getitem__(self, const key_type key) except ~ const:
it = self._indices.find(key) it = self._indices.find(key)
if it != self._indices.end(): if it != self._indices.end():
...@@ -184,12 +145,8 @@ cdef cypclass cypdict[K, V]: ...@@ -184,12 +145,8 @@ cdef cypclass cypdict[K, V]:
it = self._indices.find(key) it = self._indices.find(key)
if it != self._indices.end(): if it != self._indices.end():
index = dereference(it).second index = dereference(it).second
Cy_INCREF(value)
Cy_DECREF(self._items[index].second)
self._items[index].second = value self._items[index].second = value
elif self._active_iterators == 0: elif self._active_iterators == 0:
Cy_INCREF(key)
Cy_INCREF(value)
self._indices[key] = self._items.size() self._indices[key] = self._items.size()
self._items.push_back(item_type(key, value)) self._items.push_back(item_type(key, value))
else: else:
...@@ -205,8 +162,6 @@ cdef cypclass cypdict[K, V]: ...@@ -205,8 +162,6 @@ cdef cypclass cypdict[K, V]:
with gil: with gil:
raise RuntimeError("Modifying a dictionary with active iterators") raise RuntimeError("Modifying a dictionary with active iterators")
index = dereference(it).second index = dereference(it).second
Cy_DECREF(self._items[index].first)
Cy_DECREF(self._items[index].second)
self._indices.erase(it) self._indices.erase(it)
if index < self._items.size() - 1: if index < self._items.size() - 1:
self._items[index] = self._items[self._indices.size() - 1] self._items[index] = self._items[self._indices.size() - 1]
...@@ -218,9 +173,6 @@ cdef cypclass cypdict[K, V]: ...@@ -218,9 +173,6 @@ cdef cypclass cypdict[K, V]:
void clear(self) except ~: void clear(self) except ~:
if self._active_iterators == 0: if self._active_iterators == 0:
for item in self._items:
Cy_DECREF(item.first)
Cy_DECREF(item.second)
self._items.clear() self._items.clear()
self._indices.clear() self._indices.clear()
else: else:
......
...@@ -20,7 +20,6 @@ cdef extern from * nogil: ...@@ -20,7 +20,6 @@ cdef extern from * nogil:
{ {
if (urange != NULL) if (urange != NULL)
{ {
urange->CyObject_INCREF();
urange->_active_iterators++; urange->_active_iterators++;
} }
} }
...@@ -41,7 +40,6 @@ cdef extern from * nogil: ...@@ -41,7 +40,6 @@ cdef extern from * nogil:
swap(static_cast<base&>(*this), rhs); swap(static_cast<base&>(*this), rhs);
if (urange != NULL) { if (urange != NULL) {
urange->_active_iterators--; urange->_active_iterators--;
urange->CyObject_DECREF();
urange = NULL; urange = NULL;
} }
return *this; return *this;
...@@ -51,7 +49,6 @@ cdef extern from * nogil: ...@@ -51,7 +49,6 @@ cdef extern from * nogil:
{ {
if (urange != NULL) { if (urange != NULL) {
urange->_active_iterators--; urange->_active_iterators--;
urange->CyObject_DECREF();
urange = NULL; urange = NULL;
} }
} }
...@@ -61,7 +58,6 @@ cdef extern from * nogil: ...@@ -61,7 +58,6 @@ cdef extern from * nogil:
cy_iterator_t(base const & b, urng_t urange) : base{b}, urange{urange} cy_iterator_t(base const & b, urng_t urange) : base{b}, urange{urange}
{ {
if (urange != NULL) { if (urange != NULL) {
urange->CyObject_INCREF();
urange->_active_iterators++; urange->_active_iterators++;
} }
} }
......
...@@ -34,10 +34,6 @@ cdef cypclass cyplist[V]: ...@@ -34,10 +34,6 @@ cdef cypclass cyplist[V]:
__init__(self): __init__(self):
self._active_iterators.store(0) self._active_iterators.store(0)
__dealloc__(self):
for value in self._elements:
Cy_DECREF(value)
V __getitem__(self, const size_type index) except ~ const: V __getitem__(self, const size_type index) except ~ const:
if index < self._elements.size(): if index < self._elements.size():
return self._elements[index] return self._elements[index]
...@@ -47,8 +43,6 @@ cdef cypclass cyplist[V]: ...@@ -47,8 +43,6 @@ cdef cypclass cyplist[V]:
void __setitem__(self, size_type index, const value_type value) except ~: void __setitem__(self, size_type index, const value_type value) except ~:
if index < self._elements.size(): if index < self._elements.size():
Cy_INCREF(value)
Cy_DECREF(self._elements[index])
self._elements[index] = value self._elements[index] = value
else: else:
with gil: with gil:
...@@ -58,7 +52,6 @@ cdef cypclass cyplist[V]: ...@@ -58,7 +52,6 @@ cdef cypclass cyplist[V]:
if index < self._elements.size(): if index < self._elements.size():
if self._active_iterators == 0: if self._active_iterators == 0:
it = self._elements.begin() + index it = self._elements.begin() + index
Cy_DECREF(dereference(it))
self._elements.erase(it) self._elements.erase(it)
else: else:
with gil: with gil:
...@@ -69,7 +62,6 @@ cdef cypclass cyplist[V]: ...@@ -69,7 +62,6 @@ cdef cypclass cyplist[V]:
void append(self, const value_type value) except ~: void append(self, const value_type value) except ~:
if self._active_iterators == 0: if self._active_iterators == 0:
Cy_INCREF(value)
self._elements.push_back(value) self._elements.push_back(value)
else: else:
with gil: with gil:
...@@ -79,7 +71,6 @@ cdef cypclass cyplist[V]: ...@@ -79,7 +71,6 @@ cdef cypclass cyplist[V]:
if self._active_iterators == 0: if self._active_iterators == 0:
if index <= self._elements.size(): if index <= self._elements.size():
it = self._elements.begin() + index it = self._elements.begin() + index
Cy_INCREF(value)
self._elements.insert(it, value) self._elements.insert(it, value)
else: else:
with gil: with gil:
...@@ -90,8 +81,6 @@ cdef cypclass cyplist[V]: ...@@ -90,8 +81,6 @@ cdef cypclass cyplist[V]:
void clear(self) except ~: void clear(self) except ~:
if self._active_iterators == 0: if self._active_iterators == 0:
for value in self._elements:
Cy_DECREF(value)
self._elements.clear() self._elements.clear()
else: else:
with gil: with gil:
...@@ -100,20 +89,13 @@ cdef cypclass cyplist[V]: ...@@ -100,20 +89,13 @@ cdef cypclass cyplist[V]:
cyplist[V] __add__(self, const cyplist[V] other) const: cyplist[V] __add__(self, const cyplist[V] other) const:
result = cyplist[V]() result = cyplist[V]()
result._elements.reserve(self._elements.size() + other._elements.size()) result._elements.reserve(self._elements.size() + other._elements.size())
for value in self._elements: result._elements.insert(result._elements.end(), self._elements.begin(), self._elements.end())
Cy_INCREF(value) result._elements.insert(result._elements.end(), other._elements.begin(), other._elements.end())
result._elements.push_back(value)
for value in other._elements:
Cy_INCREF(value)
result._elements.push_back(value)
return result return result
cyplist[V] __iadd__(self, const cyplist[V] other): cyplist[V] __iadd__(self, const cyplist[V] other):
if self._active_iterators == 0: if self._active_iterators == 0:
self._elements.reserve(self._elements.size() + other._elements.size()) self._elements.insert(self._elements.end(), other._elements.begin(), other._elements.end())
for value in other._elements:
Cy_INCREF(value)
self._elements.push_back(value)
return self return self
else: else:
with gil: with gil:
...@@ -123,9 +105,7 @@ cdef cypclass cyplist[V]: ...@@ -123,9 +105,7 @@ cdef cypclass cyplist[V]:
result = cyplist[V]() result = cyplist[V]()
result._elements.reserve(self._elements.size() * n) result._elements.reserve(self._elements.size() * n)
for i in range(n): for i in range(n):
for value in self._elements: result._elements.insert(result._elements.end(), self._elements.begin(), self._elements.end())
Cy_INCREF(value)
result._elements.push_back(value)
return result return result
cyplist[V] __imul__(self, size_type n): cyplist[V] __imul__(self, size_type n):
...@@ -134,15 +114,11 @@ cdef cypclass cyplist[V]: ...@@ -134,15 +114,11 @@ cdef cypclass cyplist[V]:
elements = self._elements elements = self._elements
self._elements.reserve(elements.size() * n) self._elements.reserve(elements.size() * n)
for i in range(1, n): for i in range(1, n):
for value in elements: self._elements.insert(self._elements.end(), elements.begin(), elements.end())
Cy_INCREF(value)
self._elements.push_back(value)
return self return self
elif n == 1: elif n == 1:
return self return self
else: else:
for value in self._elements:
Cy_DECREF(value)
self._elements.clear() self._elements.clear()
return self return self
else: else:
......
...@@ -77,31 +77,28 @@ template <typename T, bool = std::is_void<T>::value> ...@@ -77,31 +77,28 @@ template <typename T, bool = std::is_void<T>::value>
struct CheckedResult {}; struct CheckedResult {};
template <typename T> template <typename T>
class CheckedResult<T, false> { struct CheckedResult<T, false> {
enum Status { Ok, Err }; bool error;
T result;
private:
enum Status status; CheckedResult() : error(true) {}
CheckedResult(const T& value) : error(false), result(value) {}
public: template<typename U, typename std::enable_if<std::is_convertible<U, T>::value, int>::type = 0>
T result; CheckedResult(const CheckedResult<U>& rhs) : error(rhs.error), result(rhs.result) {}
// CheckedResult& operator=(const T& value) {
CheckedResult(const T& value) : status(Ok), result(value) {} // result = value;
CheckedResult() : status(Err) {} // status = Ok;
operator T() { return result; } // return *this;
void set_error() { status = Err; } // }
bool is_error() { return status == Err; } operator T() { return result; }
void set_error() { error = true; }
bool is_error() { return error; }
}; };
template <typename T> template <typename T>
class CheckedResult<T, true> { struct CheckedResult<T, true> {
enum Status { Ok, Err }; bool error;
CheckedResult() : error(false) {}
public: void set_error() { error = true; }
CheckedResult() : status(Ok) {} bool is_error() { return error; }
void set_error() { status = Err; }
bool is_error() { return status == Err; }
private:
enum Status status;
}; };
...@@ -88,6 +88,152 @@ ...@@ -88,6 +88,152 @@
int CyObject_TRYWLOCK(); int CyObject_TRYWLOCK();
}; };
template <typename T>
struct Cy_Ref_impl {
T* uobj = nullptr;
constexpr Cy_Ref_impl() noexcept = default;
// constexpr Cy_Ref_impl(std::nullptr_t null) noexcept : uobj(null) {}
Cy_Ref_impl(T* uobj) : uobj(uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
Cy_Ref_impl(const Cy_Ref_impl& rhs) : uobj(rhs.uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
template<typename U, typename std::enable_if<std::is_convertible<U*, T*>::value, int>::type = 0>
Cy_Ref_impl(const Cy_Ref_impl<U>& rhs) : uobj(rhs.uobj) {
if (uobj != nullptr) {
uobj->CyObject_INCREF();
}
}
Cy_Ref_impl(Cy_Ref_impl&& rhs) noexcept : uobj(rhs.uobj) {
rhs.uobj = nullptr;
}
template<typename U, typename std::enable_if<std::is_convertible<U*, T*>::value, int>::type = 0>
Cy_Ref_impl(Cy_Ref_impl<U>&& rhs) noexcept : uobj(rhs.uobj) {
rhs.uobj = nullptr;
}
Cy_Ref_impl& operator=(Cy_Ref_impl rhs) noexcept {
std::swap(uobj, rhs.uobj);
return *this;
}
~Cy_Ref_impl() {
if (uobj != nullptr) {
uobj->CyObject_DECREF();
uobj = nullptr;
}
}
constexpr T& operator*() const noexcept{
return *uobj;
}
constexpr T* operator->() const noexcept {
return uobj;
}
explicit operator bool() const noexcept {
return uobj;
}
constexpr operator T*() const noexcept {
return uobj;
}
template <typename U>
bool operator==(const Cy_Ref_impl<U>& rhs) const noexcept {
return uobj == rhs.uobj;
}
template <typename U>
friend bool operator==(const Cy_Ref_impl<U>& lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs.uobj == rhs.uobj;
}
template <typename U>
bool operator==(U* rhs) const noexcept {
return uobj == rhs;
}
template <typename U>
friend bool operator==(U* lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs == rhs.uobj;
}
bool operator==(std::nullptr_t) const noexcept {
return uobj == nullptr;
}
friend bool operator==(std::nullptr_t, const Cy_Ref_impl<T>& rhs) noexcept {
return rhs.uobj == nullptr;
}
template <typename U>
bool operator!=(const Cy_Ref_impl<U>& rhs) const noexcept {
return uobj != rhs.uobj;
}
template <typename U>
friend bool operator!=(const Cy_Ref_impl<U>& lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs.uobj != rhs.uobj;
}
template <typename U>
bool operator!=(U* rhs) const noexcept {
return uobj != rhs;
}
template <typename U>
friend bool operator!=(U* lhs, const Cy_Ref_impl<T>& rhs) noexcept {
return lhs != rhs.uobj;
}
bool operator!=(std::nullptr_t) const noexcept {
return uobj != nullptr;
}
friend bool operator!=(std::nullptr_t, const Cy_Ref_impl<T>& rhs) noexcept {
return rhs.uobj != nullptr;
}
};
namespace std {
template <typename T>
struct hash<Cy_Ref_impl<T>> {
size_t operator()(const Cy_Ref_impl<T>& ref) const {
return std::hash<T*>()(ref.uobj);
}
};
}
template <typename T, bool = std::is_convertible<T*, CyObject*>::value>
struct Cy_Ref_t {};
template <typename T>
struct Cy_Ref_t<T, true> {
using type = Cy_Ref_impl<T>;
};
template <typename T>
struct Cy_Ref_t<T, false> {
using type = T;
};
template <typename T>
using Cy_Ref = typename Cy_Ref_t<T>::type;
class Cy_rlock_guard { class Cy_rlock_guard {
CyObject* o; CyObject* o;
public: public:
......
...@@ -437,26 +437,26 @@ cdef cypclass DestroyCheckValue(Value): ...@@ -437,26 +437,26 @@ cdef cypclass DestroyCheckValue(Value):
def test_items_destroyed(): def test_items_destroyed():
""" """
>>> test_items_destroyed() >>> test_items_destroyed()
('destroyed index', 0)
('destroyed value', 0) ('destroyed value', 0)
('destroyed index', 1) ('destroyed index', 0)
('destroyed value', 1) ('destroyed value', 1)
('destroyed index', 2) ('destroyed index', 1)
('destroyed value', 2) ('destroyed value', 2)
('destroyed index', 3) ('destroyed index', 2)
('destroyed value', 3) ('destroyed value', 3)
('destroyed index', 4) ('destroyed index', 3)
('destroyed value', 4) ('destroyed value', 4)
('destroyed index', 5) ('destroyed index', 4)
('destroyed value', 5) ('destroyed value', 5)
('destroyed index', 6) ('destroyed index', 5)
('destroyed value', 6) ('destroyed value', 6)
('destroyed index', 7) ('destroyed index', 6)
('destroyed value', 7) ('destroyed value', 7)
('destroyed index', 8) ('destroyed index', 7)
('destroyed value', 8) ('destroyed value', 8)
('destroyed index', 9) ('destroyed index', 8)
('destroyed value', 9) ('destroyed value', 9)
('destroyed index', 9)
""" """
d = cypdict[DestroyCheckIndex, DestroyCheckValue]() d = cypdict[DestroyCheckIndex, DestroyCheckValue]()
for i in range(10): for i in range(10):
...@@ -475,7 +475,7 @@ def test_items_refcount(): ...@@ -475,7 +475,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2: if Cy_GETREF(value) != 2:
return -2 return -2
d[index] = value d[index] = value
if Cy_GETREF(index) != 3: if Cy_GETREF(index) != 4:
return -3 return -3
if Cy_GETREF(value) != 3: if Cy_GETREF(value) != 3:
return -4 return -4
...@@ -485,7 +485,7 @@ def test_items_refcount(): ...@@ -485,7 +485,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2: if Cy_GETREF(value) != 2:
return -6 return -6
d[index] = value d[index] = value
if Cy_GETREF(index) != 3: if Cy_GETREF(index) != 4:
return -7 return -7
if Cy_GETREF(value) != 3: if Cy_GETREF(value) != 3:
return -8 return -8
...@@ -495,7 +495,7 @@ def test_items_refcount(): ...@@ -495,7 +495,7 @@ def test_items_refcount():
if Cy_GETREF(value) != 2: if Cy_GETREF(value) != 2:
return -10 return -10
d[index] = value d[index] = value
if Cy_GETREF(index) != 3: if Cy_GETREF(index) != 4:
return -11 return -11
if Cy_GETREF(value) != 3: if Cy_GETREF(value) != 3:
return -12 return -12
...@@ -522,41 +522,41 @@ def test_update_refcount(): ...@@ -522,41 +522,41 @@ def test_update_refcount():
d1[index1] = value1 d1[index1] = value1
d2[index2] = value2 d2[index2] = value2
d2[index3] = value3 d2[index3] = value3
if Cy_GETREF(index1) != 3: if Cy_GETREF(index1) != 4:
return -1 return -1
if Cy_GETREF(value1) != 3: if Cy_GETREF(value1) != 3:
return -2 return -2
if Cy_GETREF(index2) != 3: if Cy_GETREF(index2) != 4:
return -3 return -3
if Cy_GETREF(value2) != 3: if Cy_GETREF(value2) != 3:
return -4 return -4
if Cy_GETREF(index3) != 3: if Cy_GETREF(index3) != 4:
return -5 return -5
if Cy_GETREF(value3) != 3: if Cy_GETREF(value3) != 3:
return -6 return -6
d1.update(d2) d1.update(d2)
if Cy_GETREF(index1) != 3: if Cy_GETREF(index1) != 4:
return -7 return -7
if Cy_GETREF(value1) != 3: if Cy_GETREF(value1) != 3:
return -8 return -8
if Cy_GETREF(index2) != 4: if Cy_GETREF(index2) != 6:
return -9 return -9
if Cy_GETREF(value2) != 4: if Cy_GETREF(value2) != 4:
return -10 return -10
if Cy_GETREF(index3) != 4: if Cy_GETREF(index3) != 6:
return -11 return -11
if Cy_GETREF(value3) != 4: if Cy_GETREF(value3) != 4:
return -12 return -12
del d2 del d2
if Cy_GETREF(index1) != 3: if Cy_GETREF(index1) != 4:
return -13 return -13
if Cy_GETREF(value1) != 3: if Cy_GETREF(value1) != 3:
return -14 return -14
if Cy_GETREF(index2) != 3: if Cy_GETREF(index2) != 4:
return -15 return -15
if Cy_GETREF(value2) != 3: if Cy_GETREF(value2) != 3:
return -16 return -16
if Cy_GETREF(index3) != 3: if Cy_GETREF(index3) != 4:
return -17 return -17
if Cy_GETREF(value3) != 3: if Cy_GETREF(value3) != 3:
return -18 return -18
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment