Commit a0bd2ae8 authored by Julien Muchembled's avatar Julien Muchembled

storage: fix memory leak in replication

The following 3 methods are renamed:
 unlockData -> releaseData
 storeData -> holdData
 _storeData -> storeData

and StorageOperationHandler use the new storeData instead of the old one.
parent f1b72dfe
...@@ -308,7 +308,7 @@ class DatabaseManager(object): ...@@ -308,7 +308,7 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def _storeData(self, checksum, data, compression): def storeData(self, checksum, data, compression):
"""To be overriden by the backend to store object raw data """To be overriden by the backend to store object raw data
If same data was already stored, the storage only has to check there's If same data was already stored, the storage only has to check there's
...@@ -316,24 +316,24 @@ class DatabaseManager(object): ...@@ -316,24 +316,24 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def storeData(self, checksum_or_id, data=None, compression=None): def holdData(self, checksum_or_id, data=None, compression=None):
"""Store object raw data """Store raw data of temporary object
checksum must be the result of neo.lib.util.makeChecksum(data) checksum must be the result of neo.lib.util.makeChecksum(data)
'compression' indicates if 'data' is compressed. 'compression' indicates if 'data' is compressed.
A volatile reference is set to this data until 'unlockData' is called A volatile reference is set to this data until 'releaseData' is called
with this checksum. with this checksum.
If called with only an id, it only increment the volatile If called with only an id, it only increment the volatile
reference to the data matching the id. reference to the data matching the id.
""" """
refcount = self._uncommitted_data refcount = self._uncommitted_data
if data is not None: if data is not None:
checksum_or_id = self._storeData(checksum_or_id, data, compression) checksum_or_id = self.storeData(checksum_or_id, data, compression)
refcount[checksum_or_id] = 1 + refcount.get(checksum_or_id, 0) refcount[checksum_or_id] = 1 + refcount.get(checksum_or_id, 0)
return checksum_or_id return checksum_or_id
def unlockData(self, data_id_list, prune=False): def releaseData(self, data_id_list, prune=False):
"""Release 1 volatile reference to given list of checksums """Release 1 volatile reference to given list of data ids
If 'prune' is true, any data that is not referenced anymore (either by If 'prune' is true, any data that is not referenced anymore (either by
a volatile reference or by a fully-committed object) is deleted. a volatile reference or by a fully-committed object) is deleted.
......
...@@ -44,7 +44,7 @@ def splitOIDField(tid, oids): ...@@ -44,7 +44,7 @@ def splitOIDField(tid, oids):
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
# WARNING: some parts are not concurrent safe (ex: storeData) # WARNING: some parts are not concurrent safe (ex: holdData)
# (there must be only 1 writable connection per DB) # (there must be only 1 writable connection per DB)
# Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because # Disabled even on MySQL 5.1-5.5 and MariaDB 5.2-5.3 because
...@@ -370,7 +370,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -370,7 +370,7 @@ class MySQLDatabaseManager(DatabaseManager):
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("TRUNCATE tobj") q("TRUNCATE tobj")
q("TRUNCATE ttrans") q("TRUNCATE ttrans")
self.unlockData(data_id_list, True) self.releaseData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
e = self.escape e = self.escape
...@@ -392,7 +392,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -392,7 +392,7 @@ class MySQLDatabaseManager(DatabaseManager):
" WHERE partition=%d AND oid=%d AND tid=%d" " WHERE partition=%d AND oid=%d AND tid=%d"
% (partition, oid, value_serial)) % (partition, oid, value_serial))
if temporary: if temporary:
self.storeData(data_id) self.holdData(data_id)
else: else:
value_serial = 'NULL' value_serial = 'NULL'
q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table, q("REPLACE INTO %s VALUES (%d, %d, %d, %s, %s)" % (obj_table,
...@@ -415,7 +415,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -415,7 +415,7 @@ class MySQLDatabaseManager(DatabaseManager):
" WHERE id IN (%s) AND data_id IS NULL" " WHERE id IN (%s) AND data_id IS NULL"
% ",".join(map(str, data_id_list))) % ",".join(map(str, data_id_list)))
def _storeData(self, checksum, data, compression): def storeData(self, checksum, data, compression):
e = self.escape e = self.escape
checksum = e(checksum) checksum = e(checksum)
try: try:
...@@ -454,7 +454,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -454,7 +454,7 @@ class MySQLDatabaseManager(DatabaseManager):
q("DELETE FROM tobj WHERE tid=%d" % tid) q("DELETE FROM tobj WHERE tid=%d" % tid)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
self.unlockData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
...@@ -464,7 +464,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -464,7 +464,7 @@ class MySQLDatabaseManager(DatabaseManager):
q = self.query q = self.query
sql = " FROM tobj WHERE tid=%d" % tid sql = " FROM tobj WHERE tid=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.unlockData(data_id_list) self.releaseData(data_id_list)
q("DELETE" + sql) q("DELETE" + sql)
q("""DELETE FROM ttrans WHERE tid = %d""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" % q("""DELETE FROM trans WHERE partition = %d AND tid = %d""" %
......
...@@ -311,7 +311,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -311,7 +311,7 @@ class SQLiteDatabaseManager(DatabaseManager):
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("DELETE FROM tobj") q("DELETE FROM tobj")
q("DELETE FROM ttrans") q("DELETE FROM ttrans")
self.unlockData(data_id_list, True) self.releaseData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary=True): def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64 u64 = util.u64
...@@ -328,7 +328,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -328,7 +328,7 @@ class SQLiteDatabaseManager(DatabaseManager):
" WHERE partition=? AND oid=? AND tid=?", " WHERE partition=? AND oid=? AND tid=?",
(partition, oid, value_serial)) (partition, oid, value_serial))
if temporary: if temporary:
self.storeData(data_id) self.holdData(data_id)
try: try:
q(obj_sql, (partition, oid, tid, data_id, value_serial)) q(obj_sql, (partition, oid, tid, data_id, value_serial))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
...@@ -361,7 +361,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -361,7 +361,7 @@ class SQLiteDatabaseManager(DatabaseManager):
q("DELETE FROM data WHERE id IN (%s)" q("DELETE FROM data WHERE id IN (%s)"
% ",".join(map(str, data_id_list))) % ",".join(map(str, data_id_list)))
def _storeData(self, checksum, data, compression, def storeData(self, checksum, data, compression,
_dup_hash=unique_constraint_message("data", "hash")): _dup_hash=unique_constraint_message("data", "hash")):
H = buffer(checksum) H = buffer(checksum)
try: try:
...@@ -399,7 +399,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -399,7 +399,7 @@ class SQLiteDatabaseManager(DatabaseManager):
q("DELETE FROM tobj WHERE tid=?", args) q("DELETE FROM tobj WHERE tid=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args) q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?", args)
q("DELETE FROM ttrans WHERE tid=?", args) q("DELETE FROM ttrans WHERE tid=?", args)
self.unlockData(data_id_list) self.releaseData(data_id_list)
self.commit() self.commit()
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
...@@ -409,7 +409,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -409,7 +409,7 @@ class SQLiteDatabaseManager(DatabaseManager):
q = self.query q = self.query
sql = " FROM tobj WHERE tid=?" sql = " FROM tobj WHERE tid=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x] data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.unlockData(data_id_list) self.releaseData(data_id_list)
q("DELETE" + sql, (tid,)) q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,)) q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?", q("DELETE FROM trans WHERE partition=? AND tid=?",
......
...@@ -299,7 +299,7 @@ class TransactionManager(object): ...@@ -299,7 +299,7 @@ class TransactionManager(object):
if data is None: if data is None:
data_id = None data_id = None
else: else:
data_id = self._app.dm.storeData(checksum, data, compression) data_id = self._app.dm.holdData(checksum, data, compression)
self._transaction_dict[ttid].addObject(oid, data_id, value_serial) self._transaction_dict[ttid].addObject(oid, data_id, value_serial)
def abort(self, ttid, even_if_locked=False): def abort(self, ttid, even_if_locked=False):
...@@ -322,7 +322,7 @@ class TransactionManager(object): ...@@ -322,7 +322,7 @@ class TransactionManager(object):
if not even_if_locked: if not even_if_locked:
return return
else: else:
self._app.dm.unlockData([data_id self._app.dm.releaseData([data_id
for oid, data_id, value_serial in transaction.getObjectList() for oid, data_id, value_serial in transaction.getObjectList()
if data_id], True) if data_id], True)
# unlock any object # unlock any object
...@@ -387,5 +387,5 @@ class TransactionManager(object): ...@@ -387,5 +387,5 @@ class TransactionManager(object):
if new_serial: if new_serial:
data_id = None data_id = None
else: else:
self._app.dm.storeData(data_id) self._app.dm.holdData(data_id)
transaction.addObject(oid, data_id, new_serial) transaction.addObject(oid, data_id, new_serial)
...@@ -115,7 +115,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -115,7 +115,7 @@ class StorageDBTests(NeoUnitTestBase):
self._last_ttid = ttid = add64(self._last_ttid, 1) self._last_ttid = ttid = add64(self._last_ttid, 1)
transaction = oid_list, 'user', 'desc', 'ext', False, ttid transaction = oid_list, 'user', 'desc', 'ext', False, ttid
H = "0" * 20 H = "0" * 20
object_list = [(oid, self.db.storeData(H, '', 1), None) object_list = [(oid, self.db.holdData(H, '', 1), None)
for oid in oid_list] for oid in oid_list]
return (transaction, object_list) return (transaction, object_list)
...@@ -528,9 +528,9 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -528,9 +528,9 @@ class StorageDBTests(NeoUnitTestBase):
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = self.getOID(1)
foo = db.storeData("3" * 20, 'foo', 0) foo = db.holdData("3" * 20, 'foo', 0)
bar = db.storeData("4" * 20, 'bar', 0) bar = db.holdData("4" * 20, 'bar', 0)
db.unlockData((foo, bar)) db.releaseData((foo, bar))
db.storeTransaction( db.storeTransaction(
tid1, ( tid1, (
(oid1, foo, None), (oid1, foo, None),
......
...@@ -120,7 +120,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -120,7 +120,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def testSimpleCase(self): def testSimpleCase(self):
""" One node, one transaction, not abort """ """ One node, one transaction, not abort """
data_id_list = random.random(), random.random() data_id_list = random.random(), random.random()
self.app.dm.mockAddReturnValues(storeData=ReturnValues(*data_id_list)) self.app.dm.mockAddReturnValues(holdData=ReturnValues(*data_id_list))
uuid = self.getClientUUID() uuid = self.getClientUUID()
ttid = self.getNextTID() ttid = self.getNextTID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
...@@ -328,7 +328,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -328,7 +328,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def test_getObjectFromTransaction(self): def test_getObjectFromTransaction(self):
data_id = random.random() data_id = random.random()
self.app.dm.mockAddReturnValues(storeData=ReturnValues(data_id)) self.app.dm.mockAddReturnValues(holdData=ReturnValues(data_id))
uuid = self.getClientUUID() uuid = self.getClientUUID()
tid1, txn1 = self._getTransaction() tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction() tid2, txn2 = self._getTransaction()
...@@ -374,8 +374,8 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -374,8 +374,8 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.register(uuid, locking_serial) self.manager.register(uuid, locking_serial)
self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20, self.manager.storeObject(locking_serial, ram_serial, oid, 0, "3" * 20,
'bar', None) 'bar', None)
storeData = self.app.dm.mockGetNamedCalls('storeData') holdData = self.app.dm.mockGetNamedCalls('holdData')
self.assertEqual(storeData.pop(0).params, ("3" * 20, 'bar', 0)) self.assertEqual(holdData.pop(0).params, ("3" * 20, 'bar', 0))
orig_object = self.manager.getObjectFromTransaction(locking_serial, orig_object = self.manager.getObjectFromTransaction(locking_serial,
oid) oid)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
...@@ -406,11 +406,11 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -406,11 +406,11 @@ class TransactionManagerTests(NeoUnitTestBase):
self.manager.storeObject(locking_serial, ram_serial, oid, None, None, self.manager.storeObject(locking_serial, ram_serial, oid, None, None,
None, orig_serial) None, orig_serial)
self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum) self.manager.updateObjectDataForPack(oid, orig_serial, None, checksum)
self.assertEqual(storeData.pop(0).params, (checksum,)) self.assertEqual(holdData.pop(0).params, (checksum,))
self.assertEqual(self.manager.getObjectFromTransaction(locking_serial, self.assertEqual(self.manager.getObjectFromTransaction(locking_serial,
oid), (oid, checksum, None)) oid), (oid, checksum, None))
self.manager.abort(locking_serial, even_if_locked=True) self.manager.abort(locking_serial, even_if_locked=True)
self.assertFalse(storeData) self.assertFalse(holdData)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -68,6 +68,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -68,6 +68,7 @@ class ReplicationTests(NEOThreadedTest):
checked = 0 checked = 0
source_dict = {x.uuid: x for x in cluster.upstream.storage_list} source_dict = {x.uuid: x for x in cluster.upstream.storage_list}
for storage in cluster.storage_list: for storage in cluster.storage_list:
self.assertFalse(storage.dm._uncommitted_data)
self.assertEqual(np, storage.pt.getPartitions()) self.assertEqual(np, storage.pt.getPartitions())
for partition in pt.getAssignedPartitionList(storage.uuid): for partition in pt.getAssignedPartitionList(storage.uuid):
cell_list = upstream_pt.getCellList(partition, readable=True) cell_list = upstream_pt.getCellList(partition, readable=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment