Commit 65f5516d authored by Julien Muchembled's avatar Julien Muchembled

WIP: Do not send invalidations for new oids, just like ZEO

parent 49631a9f
...@@ -647,12 +647,15 @@ class Application(ThreadedApplication): ...@@ -647,12 +647,15 @@ class Application(ThreadedApplication):
# Call finish on master # Call finish on master
txn_context = txn_container.pop(transaction) txn_context = txn_container.pop(transaction)
cache_dict = txn_context.cache_dict cache_dict = txn_context.cache_dict
checked_list = [oid for oid, data in cache_dict.iteritems() checked = [oid for oid, data in cache_dict.iteritems()
if data is CHECKED_SERIAL] if data is CHECKED_SERIAL]
for oid in checked_list: for oid in checked:
del cache_dict[oid] del cache_dict[oid]
created = txn_context.created_list
modified = set(cache_dict)
modified.difference_update(created)
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, modified, checked, created)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
from ZODB.POSException import StorageTransactionError from ZODB.POSException import StorageTransactionError
from neo.lib.connection import ConnectionClosed from neo.lib.connection import ConnectionClosed
from neo.lib.locking import SimpleQueue from neo.lib.locking import SimpleQueue
from neo.lib.protocol import Packets from neo.lib.protocol import Packets, ZERO_OID
from .exception import NEOStorageError from .exception import NEOStorageError
@apply @apply
...@@ -52,6 +52,8 @@ class Transaction(object): ...@@ -52,6 +52,8 @@ class Transaction(object):
# if the id is still known by the NodeManager. # if the id is still known by the NodeManager.
# status: 0 -> check only, 1 -> store, 2 -> failed # status: 0 -> check only, 1 -> store, 2 -> failed
self.involved_nodes = {} # {node_id: status} self.involved_nodes = {} # {node_id: status}
# new oids
self.created_list = []
def wakeup(self, conn): def wakeup(self, conn):
self.queue.put((conn, _WakeupPacket, {})) self.queue.put((conn, _WakeupPacket, {}))
...@@ -128,6 +130,8 @@ class Transaction(object): ...@@ -128,6 +130,8 @@ class Transaction(object):
# would just flush it on tpc_finish. This also # would just flush it on tpc_finish. This also
# prevents memory errors for big transactions. # prevents memory errors for big transactions.
data = None data = None
if serial == ZERO_OID:
self.created_list.append(oid)
self.cache_dict[oid] = data self.cache_dict[oid] = data
def nodeLost(self, app, uuid): def nodeLost(self, app, uuid):
......
...@@ -22,7 +22,7 @@ from struct import Struct ...@@ -22,7 +22,7 @@ from struct import Struct
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and # to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and
# the high order byte 0 is different from TLS Handshake (0x16). # the high order byte 0 is different from TLS Handshake (0x16).
PROTOCOL_VERSION = 1 PROTOCOL_VERSION = 2
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION) ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION)
# Avoid memory errors on corrupted data. # Avoid memory errors on corrupted data.
...@@ -867,6 +867,9 @@ class FinishTransaction(Packet): ...@@ -867,6 +867,9 @@ class FinishTransaction(Packet):
PList('checked_list', PList('checked_list',
POID('oid'), POID('oid'),
), ),
PList('created_list',
POID('oid'),
),
) )
_answer = PStruct('answer_information_locked', _answer = PStruct('answer_information_locked',
......
...@@ -57,15 +57,9 @@ class ClientServiceHandler(MasterHandler): ...@@ -57,15 +57,9 @@ class ClientServiceHandler(MasterHandler):
conn.answer((Errors.Ack if app.tm.vote(app, *args) else conn.answer((Errors.Ack if app.tm.vote(app, *args) else
Errors.IncompleteTransaction)()) Errors.IncompleteTransaction)())
def askFinishTransaction(self, conn, ttid, oid_list, checked_list): def askFinishTransaction(self, conn, ttid, *args):
app = self.app app = self.app
tid, node_list = app.tm.prepare( tid, node_list = app.tm.prepare(app, ttid, conn.getPeerId(), *args)
app,
ttid,
oid_list,
checked_list,
conn.getPeerId(),
)
if tid: if tid:
p = Packets.AskLockInformation(ttid, tid) p = Packets.AskLockInformation(ttid, tid)
for node in node_list: for node in node_list:
......
...@@ -91,10 +91,10 @@ class Transaction(object): ...@@ -91,10 +91,10 @@ class Transaction(object):
def getOIDList(self): def getOIDList(self):
""" """
Returns the list of OIDs used in the transaction Returns the list of OIDs modified in the transaction
""" """
return list(self._oid_list) return self._oid_list
def isPrepared(self): def isPrepared(self):
""" """
...@@ -348,7 +348,7 @@ class TransactionManager(EventQueue): ...@@ -348,7 +348,7 @@ class TransactionManager(EventQueue):
txn._failed = failed txn._failed = failed
return True return True
def prepare(self, app, ttid, oid_list, checked_list, msg_id): def prepare(self, app, ttid, msg_id, modified, checked, oid_list):
""" """
Prepare a transaction to be finished Prepare a transaction to be finished
""" """
...@@ -360,8 +360,9 @@ class TransactionManager(EventQueue): ...@@ -360,8 +360,9 @@ class TransactionManager(EventQueue):
return None, None return None, None
ready = app.getStorageReadySet(txn._storage_readiness) ready = app.getStorageReadySet(txn._storage_readiness)
getPartition = pt.getPartition getPartition = pt.getPartition
oid_list += modified
partition_set = set(map(getPartition, oid_list)) partition_set = set(map(getPartition, oid_list))
partition_set.update(map(getPartition, checked_list)) partition_set.update(map(getPartition, checked))
partition_set.add(getPartition(ttid)) partition_set.add(getPartition(ttid))
node_list = [] node_list = []
uuid_set = set() uuid_set = set()
...@@ -394,7 +395,7 @@ class TransactionManager(EventQueue): ...@@ -394,7 +395,7 @@ class TransactionManager(EventQueue):
self._queue.append(ttid) self._queue.append(ttid)
logging.debug('Finish TXN %s for %s (was %s)', logging.debug('Finish TXN %s for %s (was %s)',
dump(tid), txn.getNode(), dump(ttid)) dump(tid), txn.getNode(), dump(ttid))
txn.prepare(tid, oid_list, uuid_set, msg_id) txn.prepare(tid, modified, uuid_set, msg_id)
# check if greater and foreign OID was stored # check if greater and foreign OID was stored
if oid_list: if oid_list:
self.setLastOID(max(oid_list)) self.setLastOID(max(oid_list))
......
...@@ -905,11 +905,16 @@ class Test(NEOThreadedTest): ...@@ -905,11 +905,16 @@ class Test(NEOThreadedTest):
def testExternalInvalidation(self, cluster): def testExternalInvalidation(self, cluster):
# Initialize objects # Initialize objects
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
old_zodb = hasattr(c1, '_invalidated')
if old_zodb: # BBB: ZODB < 5
invalidations = lambda conn: conn._invalidated
else:
invalidations = lambda conn: conn._storage._invalidations
c1.root()['x'] = x1 = PCounter() c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter() c1.root()['y'] = y = PCounter()
y.value = 1 y.value = 1
t1.commit() t1.commit()
# Get pickle of y # Get pickles of 0 and 1
t1.begin() t1.begin()
x = c1._storage.load(x1._p_oid)[0] x = c1._storage.load(x1._p_oid)[0]
y = c1._storage.load(y._p_oid)[0] y = c1._storage.load(y._p_oid)[0]
...@@ -922,6 +927,10 @@ class Test(NEOThreadedTest): ...@@ -922,6 +927,10 @@ class Test(NEOThreadedTest):
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(None, txn) client.tpc_begin(None, txn)
client.store(x1._p_oid, x1._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# At the same time, we check that there's no invalidation sent
# for new oids.
new_oid = client.new_oid()
client.store(new_oid, ZERO_TID, x, '', txn)
# Delay invalidation for x # Delay invalidation for x
with cluster.master.filterConnection(cluster.client) as m2c: with cluster.master.filterConnection(cluster.client) as m2c:
m2c.delayInvalidateObjects() m2c.delayInvalidateObjects()
...@@ -932,12 +941,21 @@ class Test(NEOThreadedTest): ...@@ -932,12 +941,21 @@ class Test(NEOThreadedTest):
x2 = c2.root()['x'] x2 = c2.root()['x']
cache.clear() # bypass cache cache.clear() # bypass cache
self.assertEqual(x2.value, 0) self.assertEqual(x2.value, 0)
self.assertRaises(POSException.POSKeyError if old_zodb else
POSException.ReadConflictError, c2.get, new_oid)
x2._p_deactivate() x2._p_deactivate()
t1.begin() # process invalidation and sync connection storage t1.begin() # process invalidation and sync connection storage
if old_zodb:
self.assertEqual(c2.get(new_oid).value, 0)
else:
self.assertRaises(POSException.ReadConflictError,
c2.get, new_oid)
self.assertEqual(x2.value, 0) self.assertEqual(x2.value, 0)
self.assertEqual({x2._p_oid}, invalidations(c2))
# New testing transaction. Now we can see the last value of x. # New testing transaction. Now we can see the last value of x.
t2.begin() t2.begin()
self.assertEqual(x2.value, 1) self.assertEqual(x2.value, 1)
self.assertEqual(c2.get(new_oid).value, 0)
# Now test cache invalidation during a load from a storage # Now test cache invalidation during a load from a storage
ll = LockLock() ll = LockLock()
...@@ -952,9 +970,9 @@ class Test(NEOThreadedTest): ...@@ -952,9 +970,9 @@ class Test(NEOThreadedTest):
with ll, Patch(cluster.client, _loadFromStorage=break_after): with ll, Patch(cluster.client, _loadFromStorage=break_after):
t = self.newThread(x2._p_activate) t = self.newThread(x2._p_activate)
ll() ll()
# At this point, x could not be found the cache and the result # At this point, x could not be found in the cache and the
# from the storage (which is <value=1, next_tid=None>) is about # result from the storage (which is <value=1, next_tid=None>)
# to be processed. # is about to be processed.
# Now modify x to receive an invalidation for it. # Now modify x to receive an invalidation for it.
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(None, txn) client.tpc_begin(None, txn)
...@@ -967,12 +985,6 @@ class Test(NEOThreadedTest): ...@@ -967,12 +985,6 @@ class Test(NEOThreadedTest):
self.assertEqual(x2.value, 1) self.assertEqual(x2.value, 1)
self.assertEqual(x1.value, 0) self.assertEqual(x1.value, 0)
def invalidations(conn):
try:
return conn._storage._invalidations
except AttributeError: # BBB: ZODB < 5
return conn._invalidated
# Change x again from 0 to 1, while the checking connection c1 # Change x again from 0 to 1, while the checking connection c1
# is suspended at the beginning of the transaction t1, # is suspended at the beginning of the transaction t1,
# between Storage.sync() and flush of invalidations. # between Storage.sync() and flush of invalidations.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment