Commit aded81bc authored by Grégory Wisniewski's avatar Grégory Wisniewski

Use decorators and UnexpectedPacketError exception instead of calls

handleUnexpectedPacket() in storage handlers.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@505 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 863a3207
...@@ -23,10 +23,12 @@ from neo.protocol import INVALID_UUID, RUNNING_STATE, BROKEN_STATE, \ ...@@ -23,10 +23,12 @@ from neo.protocol import INVALID_UUID, RUNNING_STATE, BROKEN_STATE, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE
from neo.node import MasterNode, StorageNode, ClientNode from neo.node import MasterNode, StorageNode, ClientNode
from neo.connection import ClientConnection from neo.connection import ClientConnection
from neo.protocol import Packet from neo.protocol import Packet, UnexpectedPacketError
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.storage.verification import VerificationEventHandler from neo.storage.verification import VerificationEventHandler
from neo.util import dump from neo.util import dump
from neo.handler import identification_required, restrict_node_types, \
server_connection_required, client_connection_required
class BootstrapEventHandler(StorageEventHandler): class BootstrapEventHandler(StorageEventHandler):
"""This class deals with events for a bootstrap phase.""" """This class deals with events for a bootstrap phase."""
...@@ -105,11 +107,9 @@ class BootstrapEventHandler(StorageEventHandler): ...@@ -105,11 +107,9 @@ class BootstrapEventHandler(StorageEventHandler):
conn.close() conn.close()
@server_connection_required
def handleRequestNodeIdentification(self, conn, packet, node_type, def handleRequestNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, name): uuid, ip_address, port, name):
if not conn.isServerConnection():
self.handleUnexpectedPacket(conn, packet)
else:
app = self.app app = self.app
if node_type != MASTER_NODE_TYPE: if node_type != MASTER_NODE_TYPE:
logging.info('reject a connection from a non-master') logging.info('reject a connection from a non-master')
...@@ -147,12 +147,10 @@ class BootstrapEventHandler(StorageEventHandler): ...@@ -147,12 +147,10 @@ class BootstrapEventHandler(StorageEventHandler):
# Now the master node should know that I am not the right one. # Now the master node should know that I am not the right one.
conn.abort() conn.abort()
@client_connection_required
def handleAcceptNodeIdentification(self, conn, packet, node_type, def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, uuid, ip_address, port,
num_partitions, num_replicas, your_uuid): num_partitions, num_replicas, your_uuid):
if conn.isServerConnection():
self.handleUnexpectedPacket(conn, packet)
else:
app = self.app app = self.app
node = app.nm.getNodeByServer(conn.getAddress()) node = app.nm.getNodeByServer(conn.getAddress())
if node_type != MASTER_NODE_TYPE: if node_type != MASTER_NODE_TYPE:
...@@ -197,11 +195,9 @@ class BootstrapEventHandler(StorageEventHandler): ...@@ -197,11 +195,9 @@ class BootstrapEventHandler(StorageEventHandler):
# Ask a primary master. # Ask a primary master.
conn.ask(protocol.askPrimaryMaster()) conn.ask(protocol.askPrimaryMaster())
@client_connection_required
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid,
known_master_list): known_master_list):
if conn.isServerConnection():
self.handleUnexpectedPacket(conn, packet)
else:
app = self.app app = self.app
# Register new master nodes. # Register new master nodes.
for ip_address, port, uuid in known_master_list: for ip_address, port, uuid in known_master_list:
......
...@@ -18,13 +18,14 @@ ...@@ -18,13 +18,14 @@
import logging import logging
from neo.handler import EventHandler from neo.handler import EventHandler
from neo.protocol import Packet, \ from neo.protocol import Packet, UnexpectedPacketError, \
INVALID_UUID, RUNNING_STATE, BROKEN_STATE, \ INVALID_UUID, RUNNING_STATE, BROKEN_STATE, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE
from neo.util import dump from neo.util import dump
from neo.node import MasterNode, StorageNode, ClientNode from neo.node import MasterNode, StorageNode, ClientNode
from neo.connection import ClientConnection from neo.connection import ClientConnection
from neo.exception import PrimaryFailure from neo.exception import PrimaryFailure
from neo.handler import identification_required, restrict_node_types
class StorageEventHandler(EventHandler): class StorageEventHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
...@@ -69,24 +70,17 @@ class StorageEventHandler(EventHandler): ...@@ -69,24 +70,17 @@ class StorageEventHandler(EventHandler):
known_master_list): known_master_list):
raise NotImplementedError('this method must be overridden') raise NotImplementedError('this method must be overridden')
@identification_required
@restrict_node_types(MASTER_NODE_TYPE)
def handleAnnouncePrimaryMaster(self, conn, packet): def handleAnnouncePrimaryMaster(self, conn, packet):
"""Theoretically speaking, I should not get this message, """Theoretically speaking, I should not get this message,
because the primary master election must happen when I am because the primary master election must happen when I am
not connected to any master node.""" not connected to any master node."""
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
return
app = self.app app = self.app
node = app.nm.getNodeByUUID(uuid) node = app.nm.getNodeByUUID(uuid)
if node is None: if node is None:
raise RuntimeError('I do not know the uuid %r' % dump(uuid)) raise RuntimeError('I do not know the uuid %r' % dump(uuid))
if node.getNodeType() != MASTER_NODE_TYPE:
self.handleUnexpectedPacket(conn, packet)
return
if app.primary_master_node is None: if app.primary_master_node is None:
# Hmm... I am somehow connected to the primary master already. # Hmm... I am somehow connected to the primary master already.
app.primary_master_node = node app.primary_master_node = node
...@@ -106,19 +100,16 @@ class StorageEventHandler(EventHandler): ...@@ -106,19 +100,16 @@ class StorageEventHandler(EventHandler):
def handleReelectPrimaryMaster(self, conn, packet): def handleReelectPrimaryMaster(self, conn, packet):
raise PrimaryFailure('re-election occurs') raise PrimaryFailure('re-election occurs')
@identification_required
@restrict_node_types(MASTER_NODE_TYPE)
def handleNotifyNodeInformation(self, conn, packet, node_list): def handleNotifyNodeInformation(self, conn, packet, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
# XXX it might be better to implement this callback in each handler. # XXX it might be better to implement this callback in each handler.
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
return
app = self.app app = self.app
node = app.nm.getNodeByUUID(uuid) node = app.nm.getNodeByUUID(uuid)
if node.getNodeType() != MASTER_NODE_TYPE \ if app.primary_master_node is None \
or app.primary_master_node is None \
or app.primary_master_node.getUUID() != uuid: or app.primary_master_node.getUUID() != uuid:
return return
...@@ -209,21 +200,21 @@ class StorageEventHandler(EventHandler): ...@@ -209,21 +200,21 @@ class StorageEventHandler(EventHandler):
raise NotImplementedError('this method must be overridden') raise NotImplementedError('this method must be overridden')
def handleAskObject(self, conn, packet, oid, serial, tid): def handleAskObject(self, conn, packet, oid, serial, tid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskTIDs(self, conn, packet, first, last, partition): def handleAskTIDs(self, conn, packet, first, last, partition):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskObjectHistory(self, conn, packet, oid, first, last): def handleAskObjectHistory(self, conn, packet, oid, first, last):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskStoreTransaction(self, conn, packet, tid, user, desc, def handleAskStoreTransaction(self, conn, packet, tid, user, desc,
ext, oid_list): ext, oid_list):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskStoreObject(self, conn, packet, oid, serial, def handleAskStoreObject(self, conn, packet, oid, serial,
compression, checksum, data, tid): compression, checksum, data, tid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAbortTransaction(self, conn, packet, tid): def handleAbortTransaction(self, conn, packet, tid):
logging.info('ignoring abort transaction') logging.info('ignoring abort transaction')
......
...@@ -19,16 +19,18 @@ import logging ...@@ -19,16 +19,18 @@ import logging
from neo import protocol from neo import protocol
from neo.storage.handler import StorageEventHandler from neo.storage.handler import StorageEventHandler
from neo.protocol import INVALID_UUID, INVALID_SERIAL, INVALID_TID, \ from neo.protocol import INVALID_SERIAL, INVALID_TID, \
INVALID_PARTITION, \ INVALID_PARTITION, \
RUNNING_STATE, BROKEN_STATE, TEMPORARILY_DOWN_STATE, \ BROKEN_STATE, TEMPORARILY_DOWN_STATE, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
DISCARDED_STATE, OUT_OF_DATE_STATE DISCARDED_STATE, OUT_OF_DATE_STATE
from neo.util import dump from neo.util import dump
from neo.node import MasterNode, StorageNode, ClientNode from neo.node import MasterNode, StorageNode, ClientNode
from neo.connection import ClientConnection from neo.connection import ClientConnection
from neo.protocol import Packet from neo.protocol import Packet, UnexpectedPacketError
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.handler import identification_required, restrict_node_types, \
server_connection_required, client_connection_required
class TransactionInformation(object): class TransactionInformation(object):
"""This class represents information on a transaction.""" """This class represents information on a transaction."""
...@@ -131,11 +133,9 @@ class OperationEventHandler(StorageEventHandler): ...@@ -131,11 +133,9 @@ class OperationEventHandler(StorageEventHandler):
StorageEventHandler.peerBroken(self, conn) StorageEventHandler.peerBroken(self, conn)
@server_connection_required
def handleRequestNodeIdentification(self, conn, packet, node_type, def handleRequestNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, name): uuid, ip_address, port, name):
if not conn.isServerConnection():
self.handleUnexpectedPacket(conn, packet)
else:
app = self.app app = self.app
if name != app.name: if name != app.name:
logging.error('reject an alien cluster') logging.error('reject an alien cluster')
...@@ -180,31 +180,29 @@ class OperationEventHandler(StorageEventHandler): ...@@ -180,31 +180,29 @@ class OperationEventHandler(StorageEventHandler):
if node_type == MASTER_NODE_TYPE: if node_type == MASTER_NODE_TYPE:
conn.abort() conn.abort()
@client_connection_required
def handleAcceptNodeIdentification(self, conn, packet, node_type, def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, uuid, ip_address, port,
num_partitions, num_replicas, your_uuid): num_partitions, num_replicas, your_uuid):
if not conn.isServerConnection():
raise NotImplementedError raise NotImplementedError
else:
self.handleUnexpectedPacket(conn, packet)
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid,
known_master_list): known_master_list):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskLastIDs(self, conn, packet): def handleAskLastIDs(self, conn, packet):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskPartitionTable(self, conn, packet, offset_list): def handleAskPartitionTable(self, conn, packet, offset_list):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleSendPartitionTable(self, conn, packet, ptid, row_list): def handleSendPartitionTable(self, conn, packet, ptid, row_list):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
@client_connection_required
def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list): def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
if not conn.isServerConnection():
app = self.app app = self.app
nm = app.nm nm = app.nm
pt = app.pt pt = app.pt
...@@ -234,20 +232,16 @@ class OperationEventHandler(StorageEventHandler): ...@@ -234,20 +232,16 @@ class OperationEventHandler(StorageEventHandler):
# Then, the database. # Then, the database.
app.dm.changePartitionTable(ptid, cell_list) app.dm.changePartitionTable(ptid, cell_list)
else:
self.handleUnexpectedPacket(conn, packet)
def handleStartOperation(self, conn, packet): def handleStartOperation(self, conn, packet):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
@client_connection_required
def handleStopOperation(self, conn, packet): def handleStopOperation(self, conn, packet):
if not conn.isServerConnection():
raise OperationFailure('operation stopped') raise OperationFailure('operation stopped')
else:
self.handleUnexpectedPacket(conn, packet)
def handleAskUnfinishedTransactions(self, conn, packet): def handleAskUnfinishedTransactions(self, conn, packet):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleAskTransactionInformation(self, conn, packet, tid): def handleAskTransactionInformation(self, conn, packet, tid):
app = self.app app = self.app
...@@ -260,16 +254,16 @@ class OperationEventHandler(StorageEventHandler): ...@@ -260,16 +254,16 @@ class OperationEventHandler(StorageEventHandler):
conn.answer(p, packet) conn.answer(p, packet)
def handleAskObjectPresent(self, conn, packet, oid, tid): def handleAskObjectPresent(self, conn, packet, oid, tid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleDeleteTransaction(self, conn, packet, tid): def handleDeleteTransaction(self, conn, packet, tid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
def handleCommitTransaction(self, conn, packet, tid): def handleCommitTransaction(self, conn, packet, tid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
@client_connection_required
def handleLockInformation(self, conn, packet, tid): def handleLockInformation(self, conn, packet, tid):
if not conn.isServerConnection():
app = self.app app = self.app
try: try:
t = app.transaction_dict[tid] t = app.transaction_dict[tid]
...@@ -280,13 +274,10 @@ class OperationEventHandler(StorageEventHandler): ...@@ -280,13 +274,10 @@ class OperationEventHandler(StorageEventHandler):
app.dm.storeTransaction(tid, object_list, t.getTransaction()) app.dm.storeTransaction(tid, object_list, t.getTransaction())
except KeyError: except KeyError:
pass pass
conn.answer(protocol.notifyInformationLocked(tid), packet) conn.answer(protocol.notifyInformationLocked(tid), packet)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleUnlockInformation(self, conn, packet, tid): def handleUnlockInformation(self, conn, packet, tid):
if not conn.isServerConnection():
app = self.app app = self.app
try: try:
t = app.transaction_dict[tid] t = app.transaction_dict[tid]
...@@ -303,8 +294,6 @@ class OperationEventHandler(StorageEventHandler): ...@@ -303,8 +294,6 @@ class OperationEventHandler(StorageEventHandler):
app.executeQueuedEvents() app.executeQueuedEvents()
except KeyError: except KeyError:
pass pass
else:
self.handleUnexpectedPacket(conn, packet)
def handleAskObject(self, conn, packet, oid, serial, tid): def handleAskObject(self, conn, packet, oid, serial, tid):
app = self.app app = self.app
...@@ -369,25 +358,19 @@ class OperationEventHandler(StorageEventHandler): ...@@ -369,25 +358,19 @@ class OperationEventHandler(StorageEventHandler):
p = protocol.answerObjectHistory(oid, history_list) p = protocol.answerObjectHistory(oid, history_list)
conn.answer(p, packet) conn.answer(p, packet)
@identification_required
def handleAskStoreTransaction(self, conn, packet, tid, user, desc, def handleAskStoreTransaction(self, conn, packet, tid, user, desc,
ext, oid_list): ext, oid_list):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
return
app = self.app app = self.app
t = app.transaction_dict.setdefault(tid, TransactionInformation(uuid)) t = app.transaction_dict.setdefault(tid, TransactionInformation(uuid))
t.addTransaction(oid_list, user, desc, ext) t.addTransaction(oid_list, user, desc, ext)
conn.answer(protocol.answerStoreTransaction(tid), packet) conn.answer(protocol.answerStoreTransaction(tid), packet)
@identification_required
def handleAskStoreObject(self, conn, packet, oid, serial, def handleAskStoreObject(self, conn, packet, oid, serial,
compression, checksum, data, tid): compression, checksum, data, tid):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
return
# First, check for the locking state. # First, check for the locking state.
app = self.app app = self.app
locking_tid = app.store_lock_dict.get(oid) locking_tid = app.store_lock_dict.get(oid)
...@@ -421,12 +404,9 @@ class OperationEventHandler(StorageEventHandler): ...@@ -421,12 +404,9 @@ class OperationEventHandler(StorageEventHandler):
conn.answer(p, packet) conn.answer(p, packet)
app.store_lock_dict[oid] = tid app.store_lock_dict[oid] = tid
@identification_required
def handleAbortTransaction(self, conn, packet, tid): def handleAbortTransaction(self, conn, packet, tid):
uuid = conn.getUUID() uuid = conn.getUUID()
if uuid is None:
self.handleUnexpectedPacket(conn, packet)
return
app = self.app app = self.app
try: try:
t = app.transaction_dict[tid] t = app.transaction_dict[tid]
...@@ -446,17 +426,13 @@ class OperationEventHandler(StorageEventHandler): ...@@ -446,17 +426,13 @@ class OperationEventHandler(StorageEventHandler):
except KeyError: except KeyError:
pass pass
@client_connection_required
def handleAnswerLastIDs(self, conn, packet, loid, ltid, lptid): def handleAnswerLastIDs(self, conn, packet, loid, ltid, lptid):
if not conn.isServerConnection():
self.app.replicator.setCriticalTID(packet, ltid) self.app.replicator.setCriticalTID(packet, ltid)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleAnswerUnfinishedTransactions(self, conn, packet, tid_list): def handleAnswerUnfinishedTransactions(self, conn, packet, tid_list):
if not conn.isServerConnection():
self.app.replicator.setUnfinishedTIDList(tid_list) self.app.replicator.setUnfinishedTIDList(tid_list)
else:
self.handleUnexpectedPacket(conn, packet)
def handleAskOIDs(self, conn, packet, first, last, partition): def handleAskOIDs(self, conn, packet, first, last, partition):
# This method is complicated, because I must return OIDs only # This method is complicated, because I must return OIDs only
......
...@@ -109,6 +109,14 @@ server: 127.0.0.1:10020 ...@@ -109,6 +109,14 @@ server: 127.0.0.1:10020
def getLastUUID(self): def getLastUUID(self):
return self.uuid return self.uuid
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0): def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send""" """ Check Request Node Identification has been send"""
......
...@@ -50,6 +50,14 @@ class StorageOperationTests(unittest.TestCase): ...@@ -50,6 +50,14 @@ class StorageOperationTests(unittest.TestCase):
return min(ptids), max(ptids) return min(ptids), max(ptids)
ptid = min(ptids) ptid = min(ptids)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0): def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent""" """Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just send it # sometimes we answer an error, sometimes we just send it
...@@ -81,9 +89,7 @@ class StorageOperationTests(unittest.TestCase): ...@@ -81,9 +89,7 @@ class StorageOperationTests(unittest.TestCase):
packet = Packet(msg_type=_msg_type) packet = Packet(msg_type=_msg_type)
# hook # hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled() self.operation.peerBroken = lambda c: c.peerBrokendCalled()
_call(conn=conn, packet=packet, **kwargs) self.checkUnexpectedPacketRaised(_call, conn=conn, packet=packet, **kwargs)
self.checkCalledAbort(conn)
self.assertEquals(len(conn.mockGetNamedCalls("peerBrokendCalled")), 1)
def checkNoPacketSent(self, conn): def checkNoPacketSent(self, conn):
# no packet should be sent # no packet should be sent
......
...@@ -36,7 +36,7 @@ from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION ...@@ -36,7 +36,7 @@ from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION
UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \ UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \
ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION
from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER
from neo.protocol import ANSWER_PRIMARY_MASTER from neo.protocol import ANSWER_PRIMARY_MASTER, UnexpectedPacketError
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64 from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64
...@@ -127,6 +127,14 @@ server: 127.0.0.1:10020 ...@@ -127,6 +127,14 @@ server: 127.0.0.1:10020
return min(ptids), max(ptids) return min(ptids), max(ptids)
ptid = min(ptids) ptid = min(ptids)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0): def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent""" """Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just notify it # sometimes we answer an error, sometimes we just notify it
...@@ -306,9 +314,8 @@ server: 127.0.0.1:10020 ...@@ -306,9 +314,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
p = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION) p = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION)
self.verification.handleAcceptNodeIdentification(conn, p, CLIENT_NODE_TYPE, self.checkUnexpectedPacketRaised(self.verification.handleAcceptNodeIdentification,
self.getNewUUID(),"127.0.0.1", self.client_port, 1009, 2, uuid) conn, p, CLIENT_NODE_TYPE, self.getNewUUID(),"127.0.0.1", self.client_port, 1009, 2, uuid)
self.checkCalledAbort(conn)
def test_07_handleAnswerPrimaryMaster(self): def test_07_handleAnswerPrimaryMaster(self):
# reject server connection # reject server connection
...@@ -317,8 +324,7 @@ server: 127.0.0.1:10020 ...@@ -317,8 +324,7 @@ server: 127.0.0.1:10020
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
self.verification.handleAnswerPrimaryMaster(conn, packet,self.getNewUUID(), ()) self.checkUnexpectedPacketRaised(self.verification.handleAnswerPrimaryMaster, conn, packet,self.getNewUUID(), ())
self.checkCalledAbort(conn)
# raise id uuid is different # raise id uuid is different
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
...@@ -343,8 +349,7 @@ server: 127.0.0.1:10020 ...@@ -343,8 +349,7 @@ server: 127.0.0.1:10020
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
self.verification.handleAskLastIDs(conn, packet) self.checkUnexpectedPacketRaised(self.verification.handleAskLastIDs, conn, packet)
self.checkCalledAbort(conn)
# return invalid if db store nothing # return invalid if db store nothing
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
...@@ -402,8 +407,7 @@ server: 127.0.0.1:10020 ...@@ -402,8 +407,7 @@ server: 127.0.0.1:10020
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
self.verification.handleAskPartitionTable(conn, packet, [1,]) self.checkUnexpectedPacketRaised(self.verification.handleAskPartitionTable, conn, packet, [1,])
self.checkCalledAbort(conn)
# try to get unknown offset # try to get unknown offset
self.assertEqual(len(self.app.pt.getNodeList()), 0) self.assertEqual(len(self.app.pt.getNodeList()), 0)
...@@ -449,9 +453,8 @@ server: 127.0.0.1:10020 ...@@ -449,9 +453,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
self.app.ptid = 1 self.app.ptid = 1
self.verification.handleSendPartitionTable(conn, packet, 0, ()) self.checkUnexpectedPacketRaised(self.verification.handleSendPartitionTable, conn, packet, 0, ())
self.assertEquals(self.app.ptid, 1) self.assertEquals(self.app.ptid, 1)
self.checkCalledAbort(conn)
# send a table # send a table
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
...@@ -496,9 +499,8 @@ server: 127.0.0.1:10020 ...@@ -496,9 +499,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.client_port), "getAddress" : ("127.0.0.1", self.client_port),
"isServerConnection" : True}) "isServerConnection" : True})
self.app.ptid = 1 self.app.ptid = 1
self.verification.handleNotifyPartitionChanges(conn, packet, 0, ()) self.checkUnexpectedPacketRaised(self.verification.handleNotifyPartitionChanges, conn, packet, 0, ())
self.assertEquals(self.app.ptid, 1) self.assertEquals(self.app.ptid, 1)
self.checkCalledAbort(conn)
# old partition change # old partition change
conn = Mock({ conn = Mock({
...@@ -534,8 +536,7 @@ server: 127.0.0.1:10020 ...@@ -534,8 +536,7 @@ server: 127.0.0.1:10020
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': True }) 'isServerConnection': True })
packet = Packet(msg_type=STOP_OPERATION) packet = Packet(msg_type=STOP_OPERATION)
self.verification.handleStartOperation(conn, packet) self.checkUnexpectedPacketRaised(self.verification.handleStartOperation, conn, packet)
self.checkCalledAbort(conn)
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': False }) 'isServerConnection': False })
self.assertFalse(self.app.operational) self.assertFalse(self.app.operational)
...@@ -547,8 +548,7 @@ server: 127.0.0.1:10020 ...@@ -547,8 +548,7 @@ server: 127.0.0.1:10020
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': True }) 'isServerConnection': True })
packet = Packet(msg_type=STOP_OPERATION) packet = Packet(msg_type=STOP_OPERATION)
self.verification.handleStopOperation(conn, packet) self.checkUnexpectedPacketRaised(self.verification.handleStopOperation, conn, packet)
self.checkCalledAbort(conn)
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': False }) 'isServerConnection': False })
packet = Packet(msg_type=STOP_OPERATION) packet = Packet(msg_type=STOP_OPERATION)
...@@ -559,8 +559,7 @@ server: 127.0.0.1:10020 ...@@ -559,8 +559,7 @@ server: 127.0.0.1:10020
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': True }) 'isServerConnection': True })
packet = Packet(msg_type=ASK_UNFINISHED_TRANSACTIONS) packet = Packet(msg_type=ASK_UNFINISHED_TRANSACTIONS)
self.verification.handleAskUnfinishedTransactions(conn, packet) self.checkUnexpectedPacketRaised(self.verification.handleAskUnfinishedTransactions, conn, packet)
self.checkCalledAbort(conn)
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': False}) 'isServerConnection': False})
...@@ -688,8 +687,7 @@ server: 127.0.0.1:10020 ...@@ -688,8 +687,7 @@ server: 127.0.0.1:10020
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': True }) 'isServerConnection': True })
packet = Packet(msg_type=ASK_OBJECT_PRESENT) packet = Packet(msg_type=ASK_OBJECT_PRESENT)
self.verification.handleAskObjectPresent(conn, packet, p64(1), p64(2)) self.checkUnexpectedPacketRaised(self.verification.handleAskObjectPresent, conn, packet, p64(1), p64(2))
self.checkCalledAbort(conn)
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': False}) 'isServerConnection': False})
...@@ -724,8 +722,7 @@ server: 127.0.0.1:10020 ...@@ -724,8 +722,7 @@ server: 127.0.0.1:10020
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': True }) 'isServerConnection': True })
packet = Packet(msg_type=ASK_OBJECT_PRESENT) packet = Packet(msg_type=ASK_OBJECT_PRESENT)
self.verification.handleDeleteTransaction(conn, packet, p64(1)) self.checkUnexpectedPacketRaised(self.verification.handleDeleteTransaction, conn, packet, p64(1))
self.checkCalledAbort(conn)
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServerConnection': False}) 'isServerConnection': False})
...@@ -747,8 +744,7 @@ server: 127.0.0.1:10020 ...@@ -747,8 +744,7 @@ server: 127.0.0.1:10020
dm = Mock() dm = Mock()
self.app.dm = dm self.app.dm = dm
packet = Packet(msg_type=COMMIT_TRANSACTION) packet = Packet(msg_type=COMMIT_TRANSACTION)
self.verification.handleCommitTransaction(conn, packet, p64(1)) self.checkUnexpectedPacketRaised(self.verification.handleCommitTransaction, conn, packet, p64(1))
self.checkCalledAbort(conn)
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 0) self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 0)
# commit a transaction # commit a transaction
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
......
...@@ -20,13 +20,15 @@ import logging ...@@ -20,13 +20,15 @@ import logging
from neo import protocol from neo import protocol
from neo.storage.handler import StorageEventHandler from neo.storage.handler import StorageEventHandler
from neo.protocol import INVALID_OID, INVALID_TID, \ from neo.protocol import INVALID_OID, INVALID_TID, \
RUNNING_STATE, BROKEN_STATE, TEMPORARILY_DOWN_STATE, \ BROKEN_STATE, TEMPORARILY_DOWN_STATE, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ MASTER_NODE_TYPE, STORAGE_NODE_TYPE, \
Packet Packet, UnexpectedPacketError
from neo.util import dump from neo.util import dump
from neo.node import MasterNode, StorageNode, ClientNode from neo.node import MasterNode, StorageNode, ClientNode
from neo.connection import ClientConnection from neo.connection import ClientConnection
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.handler import identification_required, restrict_node_types, \
server_connection_required, client_connection_required
class VerificationEventHandler(StorageEventHandler): class VerificationEventHandler(StorageEventHandler):
"""This class deals with events for a verification phase.""" """This class deals with events for a verification phase."""
...@@ -61,11 +63,9 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -61,11 +63,9 @@ class VerificationEventHandler(StorageEventHandler):
StorageEventHandler.peerBroken(self, conn) StorageEventHandler.peerBroken(self, conn)
@server_connection_required
def handleRequestNodeIdentification(self, conn, packet, node_type, def handleRequestNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, name): uuid, ip_address, port, name):
if not conn.isServerConnection():
self.handleUnexpectedPacket(conn, packet)
else:
app = self.app app = self.app
if node_type != MASTER_NODE_TYPE: if node_type != MASTER_NODE_TYPE:
logging.info('reject a connection from a non-master') logging.info('reject a connection from a non-master')
...@@ -108,32 +108,28 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -108,32 +108,28 @@ class VerificationEventHandler(StorageEventHandler):
def handleAcceptNodeIdentification(self, conn, packet, node_type, def handleAcceptNodeIdentification(self, conn, packet, node_type,
uuid, ip_address, port, uuid, ip_address, port,
num_partitions, num_replicas, your_uuid): num_partitions, num_replicas, your_uuid):
self.handleUnexpectedPacket(conn, packet) raise UnexpectedPacketError
@client_connection_required
def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid,
known_master_list): known_master_list):
if not conn.isServerConnection():
app = self.app app = self.app
if app.primary_master_node.getUUID() != primary_uuid: if app.primary_master_node.getUUID() != primary_uuid:
raise PrimaryFailure('the primary master node seems to have changed') raise PrimaryFailure('the primary master node seems to have changed')
# XXX is it better to deal with known_master_list here? # XXX is it better to deal with known_master_list here?
# But a primary master node is supposed not to send any info # But a primary master node is supposed not to send any info
# with this packet, so it would be useless. # with this packet, so it would be useless.
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleAskLastIDs(self, conn, packet): def handleAskLastIDs(self, conn, packet):
if not conn.isServerConnection():
app = self.app app = self.app
oid = app.dm.getLastOID() or INVALID_OID oid = app.dm.getLastOID() or INVALID_OID
tid = app.dm.getLastTID() or INVALID_TID tid = app.dm.getLastTID() or INVALID_TID
p = protocol.answerLastIDs(oid, tid, app.ptid) p = protocol.answerLastIDs(oid, tid, app.ptid)
conn.answer(p, packet) conn.answer(p, packet)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleAskPartitionTable(self, conn, packet, offset_list): def handleAskPartitionTable(self, conn, packet, offset_list):
if not conn.isServerConnection():
app = self.app app = self.app
row_list = [] row_list = []
try: try:
...@@ -152,13 +148,11 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -152,13 +148,11 @@ class VerificationEventHandler(StorageEventHandler):
p = protocol.answerPartitionTable(app.ptid, row_list) p = protocol.answerPartitionTable(app.ptid, row_list)
conn.answer(p, packet) conn.answer(p, packet)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleSendPartitionTable(self, conn, packet, ptid, row_list): def handleSendPartitionTable(self, conn, packet, ptid, row_list):
"""A primary master node sends this packet to synchronize a partition """A primary master node sends this packet to synchronize a partition
table. Note that the message can be split into multiple packets.""" table. Note that the message can be split into multiple packets."""
if not conn.isServerConnection():
app = self.app app = self.app
nm = app.nm nm = app.nm
pt = app.pt pt = app.pt
...@@ -186,13 +180,11 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -186,13 +180,11 @@ class VerificationEventHandler(StorageEventHandler):
cell_list.append((offset, cell.getUUID(), cell_list.append((offset, cell.getUUID(),
cell.getState())) cell.getState()))
app.dm.setPartitionTable(ptid, cell_list) app.dm.setPartitionTable(ptid, cell_list)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list): def handleNotifyPartitionChanges(self, conn, packet, ptid, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
if not conn.isServerConnection():
app = self.app app = self.app
nm = app.nm nm = app.nm
pt = app.pt pt = app.pt
...@@ -215,29 +207,20 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -215,29 +207,20 @@ class VerificationEventHandler(StorageEventHandler):
# Then, the database. # Then, the database.
app.dm.changePartitionTable(ptid, cell_list) app.dm.changePartitionTable(ptid, cell_list)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleStartOperation(self, conn, packet): def handleStartOperation(self, conn, packet):
if not conn.isServerConnection():
self.app.operational = True self.app.operational = True
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleStopOperation(self, conn, packet): def handleStopOperation(self, conn, packet):
if not conn.isServerConnection():
raise OperationFailure('operation stopped') raise OperationFailure('operation stopped')
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleAskUnfinishedTransactions(self, conn, packet): def handleAskUnfinishedTransactions(self, conn, packet):
if not conn.isServerConnection(): tid_list = self.app.dm.getUnfinishedTIDList()
app = self.app
tid_list = app.dm.getUnfinishedTIDList()
p = protocol.answerUnfinishedTransactions(tid_list) p = protocol.answerUnfinishedTransactions(tid_list)
conn.answer(p, packet) conn.answer(p, packet)
else:
self.handleUnexpectedPacket(conn, packet)
def handleAskTransactionInformation(self, conn, packet, tid): def handleAskTransactionInformation(self, conn, packet, tid):
app = self.app app = self.app
...@@ -255,31 +238,22 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -255,31 +238,22 @@ class VerificationEventHandler(StorageEventHandler):
p = protocol.answerTransactionInformation(tid, t[1], t[2], t[3], t[0]) p = protocol.answerTransactionInformation(tid, t[1], t[2], t[3], t[0])
conn.answer(p, packet) conn.answer(p, packet)
@client_connection_required
def handleAskObjectPresent(self, conn, packet, oid, tid): def handleAskObjectPresent(self, conn, packet, oid, tid):
if not conn.isServerConnection(): if self.app.dm.objectPresent(oid, tid):
app = self.app
if app.dm.objectPresent(oid, tid):
p = protocol.answerObjectPresent(oid, tid) p = protocol.answerObjectPresent(oid, tid)
else: else:
p = protocol.oidNotFound( p = protocol.oidNotFound(
'%s:%s do not exist' % (dump(oid), dump(tid))) '%s:%s do not exist' % (dump(oid), dump(tid)))
conn.answer(p, packet) conn.answer(p, packet)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleDeleteTransaction(self, conn, packet, tid): def handleDeleteTransaction(self, conn, packet, tid):
if not conn.isServerConnection(): self.app.dm.deleteTransaction(tid, all = True)
app = self.app
app.dm.deleteTransaction(tid, all = True)
else:
self.handleUnexpectedPacket(conn, packet)
@client_connection_required
def handleCommitTransaction(self, conn, packet, tid): def handleCommitTransaction(self, conn, packet, tid):
if not conn.isServerConnection(): self.app.dm.finishTransaction(tid)
app = self.app
app.dm.finishTransaction(tid)
else:
self.handleUnexpectedPacket(conn, packet)
def handleLockInformation(self, conn, packet, tid): def handleLockInformation(self, conn, packet, tid):
pass pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment