Commit 28e5f7a8 authored by Vincent Pelletier's avatar Vincent Pelletier

Add support for complete deletion of partition tail in replication.

This can happen when enough most-recent objects have been transactionally
un-created and database was packed.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2465 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 8daf4383
...@@ -435,6 +435,13 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -435,6 +435,13 @@ class BTreeDatabaseManager(DatabaseManager):
except KeyError: except KeyError:
pass pass
def deleteTransactionsAbove(self, num_partitions, partition, tid):
tid = util.u64(tid)
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(self._trans, same_partition,
iter_kw={'min': tid, 'excludemin': True})
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
...@@ -458,6 +465,24 @@ class BTreeDatabaseManager(DatabaseManager): ...@@ -458,6 +465,24 @@ class BTreeDatabaseManager(DatabaseManager):
prune(obj[oid]) prune(obj[oid])
del obj[oid] del obj[oid]
def deleteObjectsAbove(self, num_partitions, partition, oid, serial):
obj = self._obj
u64 = util.u64
oid = u64(oid)
serial = u64(serial)
if oid % num_partitions == partition:
try:
tserial = obj[oid]
except KeyError:
pass
else:
batchDelete(tserial, lambda _, __: True,
iter_kw={'min': serial, 'excludemin': True})
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(obj, same_partition,
iter_kw={'min': oid, 'excludemin': True}, recycle_subtrees=True)
def getTransaction(self, tid, all=False): def getTransaction(self, tid, all=False):
tid = util.u64(tid) tid = util.u64(tid)
try: try:
......
...@@ -388,11 +388,21 @@ class DatabaseManager(object): ...@@ -388,11 +388,21 @@ class DatabaseManager(object):
an oid list""" an oid list"""
raise NotImplementedError raise NotImplementedError
def deleteTransactionsAbove(self, num_partitions, partition, tid):
"""Delete all transactions above given TID (excluded) in given
partition."""
raise NotImplementedError
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for """Delete given object. If serial is given, only delete that serial for
given oid.""" given oid."""
raise NotImplementedError raise NotImplementedError
def deleteObjectsAbove(self, num_partitions, partition, oid, serial):
"""Delete all objects above given OID and serial (excluded) in given
partition."""
raise NotImplementedError
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction a description, and extension information, for a given transaction
......
...@@ -557,6 +557,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -557,6 +557,19 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteTransactionsAbove(self, num_partitions, partition, tid):
self.begin()
try:
self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
'tid > %(tid)d' % {
'partition': partition,
'tid': util.u64(tid),
})
except:
self.rollback()
raise
self.commit()
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
...@@ -577,6 +590,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -577,6 +590,21 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteObjectsAbove(self, num_partitions, partition, oid, serial):
u64 = util.u64
self.begin()
try:
self.query('DELETE FROM obj WHERE partition=%(partition)d AND '
'oid > %(oid)d OR (oid = %(oid)d AND serial > %(serial)d)' % {
'partition': partition,
'oid': u64(oid),
'serial': u64(serial),
})
except:
self.rollback()
raise
self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
......
...@@ -242,7 +242,8 @@ class ReplicationHandler(EventHandler): ...@@ -242,7 +242,8 @@ class ReplicationHandler(EventHandler):
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum, def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid): max_tid):
ask = conn.ask ask = conn.ask
replicator = self.app.replicator app = self.app
replicator = app.replicator
next_tid = add64(max_tid, 1) next_tid = add64(max_tid, 1)
action, params = self._checkRange( action, params = self._checkRange(
replicator.getTIDCheckResult(min_tid, length) == ( replicator.getTIDCheckResult(min_tid, length) == (
...@@ -261,6 +262,10 @@ class ReplicationHandler(EventHandler): ...@@ -261,6 +262,10 @@ class ReplicationHandler(EventHandler):
else: else:
ask(self._doAskCheckTIDRange(min_tid, count)) ask(self._doAskCheckTIDRange(min_tid, count))
if action == CHECK_DONE: if action == CHECK_DONE:
# Delete all transactions we might have which are beyond what peer
# knows.
app.dm.deleteTransactionsAbove(app.pt.getPartitions(),
replicator.getCurrentRID(), max_tid)
# If no more TID, a replication of transactions is finished. # If no more TID, a replication of transactions is finished.
# So start to replicate objects now. # So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID)) ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID))
...@@ -269,7 +274,8 @@ class ReplicationHandler(EventHandler): ...@@ -269,7 +274,8 @@ class ReplicationHandler(EventHandler):
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count, def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial): oid_checksum, max_oid, serial_checksum, max_serial):
ask = conn.ask ask = conn.ask
replicator = self.app.replicator app = self.app
replicator = app.replicator
next_params = (max_oid, add64(max_serial, 1)) next_params = (max_oid, add64(max_serial, 1))
action, params = self._checkRange( action, params = self._checkRange(
replicator.getSerialCheckResult(min_oid, min_serial, length) == ( replicator.getSerialCheckResult(min_oid, min_serial, length) == (
...@@ -284,6 +290,10 @@ class ReplicationHandler(EventHandler): ...@@ -284,6 +290,10 @@ class ReplicationHandler(EventHandler):
((min_oid, min_serial), count) = params ((min_oid, min_serial), count) = params
ask(self._doAskCheckSerialRange(min_oid, min_serial, count)) ask(self._doAskCheckSerialRange(min_oid, min_serial, count))
if action == CHECK_DONE: if action == CHECK_DONE:
# Delete all objects we might have which are beyond what peer
# knows.
app.dm.deleteObjectsAbove(app.pt.getPartitions(),
replicator.getCurrentRID(), max_oid, max_serial)
# Nothing remains, so the replication for this partition is # Nothing remains, so the replication for this partition is
# finished. # finished.
replicator.setReplicationDone() replicator.setReplicationDone()
......
...@@ -67,10 +67,12 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -67,10 +67,12 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
pass pass
def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID), def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID), serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
tid_result=(), tid_result=(),
history_result=None, history_result=None,
rid=0, critical_tid=ZERO_TID): rid=0, critical_tid=ZERO_TID,
num_partitions=1,
):
if history_result is None: if history_result is None:
history_result = {} history_result = {}
replicator = Mock({ replicator = Mock({
...@@ -99,6 +101,9 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -99,6 +101,9 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
'storeTransaction': None, 'storeTransaction': None,
'deleteObject': None, 'deleteObject': None,
}) })
pt = Mock({
'getPartitions': num_partitions,
})
return FakeApp return FakeApp
def _checkReplicationStarted(self, conn, rid, replicator): def _checkReplicationStarted(self, conn, rid, replicator):
...@@ -360,9 +365,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -360,9 +365,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
max_tid = self.getNextTID() max_tid = self.getNextTID()
length = RANGE_LENGTH / 2 length = RANGE_LENGTH / 2
rid = 12 rid = 12
num_partitions = 13
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid, app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
conn=conn) conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid # Peer has the same data as we have: length, checksum and max_tid
# match. # match.
...@@ -378,6 +384,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -378,6 +384,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = app.replicator.mockGetNamedCalls('checkSerialRange') calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_tid)
def test_answerCheckTIDRangeDifferentBigChunk(self): def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
...@@ -514,9 +524,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -514,9 +524,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
max_serial = self.getNextTID() max_serial = self.getNextTID()
length = RANGE_LENGTH / 2 length = RANGE_LENGTH / 2
rid = 12 rid = 12
num_partitions = 13
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1, app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
max_serial), rid=rid, conn=conn) max_serial), rid=rid, conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has the same data as we have # Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
...@@ -524,6 +535,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -524,6 +535,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
# Result: mark replication as done # Result: mark replication as done
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
self.assertTrue(app.replicator.replication_done) self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, max_serial)
def test_answerCheckSerialRangeDifferentBigChunk(self): def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1) min_oid = self.getOID(1)
...@@ -590,9 +605,12 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -590,9 +605,12 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
critical_tid = self.getNextTID() critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1 length = MIN_RANGE_LENGTH - 1
rid = 12 rid = 12
num_partitions = 13
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, app = self.getApp(tid_check_result=(length - 5, 0, max_oid,
1, max_serial), rid=rid, conn=conn, critical_tid=critical_tid) 1, max_serial), rid=rid, conn=conn, critical_tid=critical_tid,
num_partitions=num_partitions,
)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has different data, and less than length # Peer has different data, and less than length
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
...@@ -611,6 +629,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -611,6 +629,10 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength, calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition) ppartition)
self.assertTrue(app.replicator.replication_done) self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(num_partitions, rid, max_oid, max_serial)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -334,6 +334,24 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -334,6 +334,24 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteTransactionsAbove(self):
self.db.setNumPartitions(2)
tid1 = self.getOID(0)
tid2 = self.getOID(1)
tid3 = self.getOID(2)
oid1 = self.getOID(1)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteTransactionsAbove(2, 0, tid1)
# Right partition, below cutoff
self.assertNotEqual(self.db.getTransaction(tid1, True), None)
# Wrong partition, above cutoff
self.assertNotEqual(self.db.getTransaction(tid2, True), None)
# Right partition, above cutoff
self.assertEqual(self.db.getTransaction(tid3, True), None)
def test_deleteObject(self): def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
...@@ -351,6 +369,31 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -351,6 +369,31 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \ self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:]) objs2[1][1:])
def test_deleteObjectsAbove(self):
self.db.setNumPartitions(2)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
oid1 = self.getOID(0)
oid2 = self.getOID(1)
oid3 = self.getOID(2)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1, oid2, oid3])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteObjectsAbove(2, 0, oid1, tid1)
# Right partition, below cutoff
self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None)
# Right partition, above tid cutoff
self.assertEqual(self.db.getObject(oid1, tid=tid2), False)
self.assertEqual(self.db.getObject(oid1, tid=tid3), False)
# Wrong partition, above cutoff
self.assertNotEqual(self.db.getObject(oid2, tid=tid1), None)
self.assertNotEqual(self.db.getObject(oid2, tid=tid2), None)
self.assertNotEqual(self.db.getObject(oid2, tid=tid3), None)
# Right partition, above cutoff
self.assertEqual(self.db.getObject(oid3), None)
def test_getTransaction(self): def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment