Commit eea5aff2 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Raise an exception when a broken node send packet through a connection. Fix

tests according to this changes and fix some others affected with previous
commits.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@506 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent aded81bc
...@@ -841,11 +841,12 @@ class Application(object): ...@@ -841,11 +841,12 @@ class Application(object):
def __del__(self): def __del__(self):
"""Clear all connection.""" """Clear all connection."""
# TODO: Stop polling thread here.
# Due to bug in ZODB, close is not always called when shutting # Due to bug in ZODB, close is not always called when shutting
# down zope, so use __del__ to close connections # down zope, so use __del__ to close connections
for conn in self.em.getConnectionList(): for conn in self.em.getConnectionList():
conn.close() conn.close()
# Stop polling thread
self.poll_thread.stop()
close = __del__ close = __del__
def sync(self): def sync(self):
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from threading import Thread from threading import Thread, Event
import logging import logging
class ThreadedPoll(Thread): class ThreadedPoll(Thread):
...@@ -25,10 +25,11 @@ class ThreadedPoll(Thread): ...@@ -25,10 +25,11 @@ class ThreadedPoll(Thread):
Thread.__init__(self, **kw) Thread.__init__(self, **kw)
self.em = em self.em = em
self.setDaemon(True) self.setDaemon(True)
self._stop = Event()
self.start() self.start()
def run(self): def run(self):
while 1: while not self._stop.isSet():
# First check if we receive any new message from other node # First check if we receive any new message from other node
try: try:
self.em.poll() self.em.poll()
...@@ -36,4 +37,7 @@ class ThreadedPoll(Thread): ...@@ -36,4 +37,7 @@ class ThreadedPoll(Thread):
# This happen when there is no connection # This happen when there is no connection
# XXX: This should be handled inside event manager, not here. # XXX: This should be handled inside event manager, not here.
logging.error('Dispatcher, run, poll returned a KeyError') logging.error('Dispatcher, run, poll returned a KeyError')
logging.info('Threaded poll stopped')
def stop(self):
self._stop.set()
...@@ -335,17 +335,8 @@ class ClientEventHandlerTest(unittest.TestCase): ...@@ -335,17 +335,8 @@ class ClientEventHandlerTest(unittest.TestCase):
self.assertEquals(app.uuid, 'C' * 16) self.assertEquals(app.uuid, 'C' * 16)
def _testHandleUnexpectedPacketCalledWithMedhod(self, client_handler, method, args=(), kw=()): def _testHandleUnexpectedPacketCalledWithMedhod(self, client_handler, method, args=(), kw=()):
# Monkey-patch handleUnexpectedPacket to check if it is called
call_list = [] call_list = []
def ClientHandler_handleUnexpectedPacket(self, conn, packet):
call_list.append((conn, packet))
original_handleUnexpectedPacket = client_handler.__class__.handleUnexpectedPacket
client_handler.__class__.handleUnexpectedPacket = ClientHandler_handleUnexpectedPacket
try:
self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw)) self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw))
finally:
# Restore original method
client_handler.__class__.handleUnexpectedPacket = original_handleUnexpectedPacket
# Master node handler # Master node handler
def test_initialAnswerPrimaryMaster(self): def test_initialAnswerPrimaryMaster(self):
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
import logging import logging
from neo import protocol from neo import protocol
from neo.protocol import Packet, PacketMalformedError, UnexpectedPacketError from neo.protocol import Packet, PacketMalformedError, UnexpectedPacketError, \
BrokenNotDisallowedError, NotReadyError
from neo.connection import ServerConnection from neo.connection import ServerConnection
from protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ from protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
...@@ -156,8 +157,11 @@ class EventHandler(object): ...@@ -156,8 +157,11 @@ class EventHandler(object):
conn.notify(protocol.protocolError(message)) conn.notify(protocol.protocolError(message))
conn.abort() conn.abort()
self.peerBroken(conn) self.peerBroken(conn)
# TODO: remove this old method name
handleUnexpectedPacket = unexpectedPacket def brokenNodeDisallowedError(conn, packet, message=None):
""" Called when a broken node send packets """
conn.notify(protocol.brokenNodeDisallowedError('go away'))
conn.abort()
def dispatch(self, conn, packet): def dispatch(self, conn, packet):
"""This is a helper method to handle various packet types.""" """This is a helper method to handle various packet types."""
...@@ -172,6 +176,9 @@ class EventHandler(object): ...@@ -172,6 +176,9 @@ class EventHandler(object):
self.unexpectedPacket(conn, packet, msg) self.unexpectedPacket(conn, packet, msg)
except PacketMalformedError, msg: except PacketMalformedError, msg:
self.packetMalformed(conn, packet, msg) self.packetMalformed(conn, packet, msg)
except BrokenNotDisallowedError, msg:
self.brokenNodeDisallowedError(conn, packet, msg)
# Packet handlers. # Packet handlers.
......
...@@ -29,10 +29,6 @@ from neo.node import MasterNode, StorageNode, ClientNode ...@@ -29,10 +29,6 @@ from neo.node import MasterNode, StorageNode, ClientNode
from neo.handler import identification_required, restrict_node_types, \ from neo.handler import identification_required, restrict_node_types, \
client_connection_required, server_connection_required client_connection_required, server_connection_required
# TODO: finalize decorators integration (identification, restriction, client...)
# TODO: here use specific decorator such as restrict_node_types which do custom
# operations such as send retryLater instead of unexpectedPacket
class ElectionEventHandler(MasterEventHandler): class ElectionEventHandler(MasterEventHandler):
"""This class deals with events for a primary master election.""" """This class deals with events for a primary master election."""
...@@ -198,10 +194,7 @@ class ElectionEventHandler(MasterEventHandler): ...@@ -198,10 +194,7 @@ class ElectionEventHandler(MasterEventHandler):
# If this node is broken, reject it. # If this node is broken, reject it.
if node.getUUID() == uuid: if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
conn.answer(protocol.brokenNodeDisallowedError( raise protocol.BrokenNotDisallowedError
'go away'), packet)
conn.abort()
return
# supplied another uuid in case of conflict # supplied another uuid in case of conflict
while not app.isValidUUID(uuid, addr): while not app.isValidUUID(uuid, addr):
......
...@@ -166,11 +166,7 @@ class RecoveryEventHandler(MasterEventHandler): ...@@ -166,11 +166,7 @@ class RecoveryEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that it is # If this node is broken, reject it. Otherwise, assume that it is
# working again. # working again.
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away') raise protocol.BrokenNotDisallowedError
conn.answer(p, packet)
conn.abort()
return
else:
node.setUUID(uuid) node.setUUID(uuid)
node.setState(RUNNING_STATE) node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node) app.broadcastNodeInformation(node)
......
...@@ -263,10 +263,7 @@ class ServiceEventHandler(MasterEventHandler): ...@@ -263,10 +263,7 @@ class ServiceEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that # If this node is broken, reject it. Otherwise, assume that
# it is working again. # it is working again.
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
conn.notify(protocol.brokenNodeDisallowedError('go away')) raise protocol.BrokenNotDisallowedError
conn.abort()
return
else:
node.setUUID(uuid) node.setUUID(uuid)
node.setState(RUNNING_STATE) node.setState(RUNNING_STATE)
logging.debug('broadcasting node information') logging.debug('broadcasting node information')
......
...@@ -22,7 +22,7 @@ from tempfile import mkstemp ...@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo import protocol from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.election import ElectionEventHandler from neo.master.election import ElectionEventHandler
from neo.master.app import Application from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
...@@ -136,12 +136,16 @@ server: 127.0.0.1:10023 ...@@ -136,12 +136,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0): def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send""" """ Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1) self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
...@@ -577,14 +581,15 @@ server: 127.0.0.1:10023 ...@@ -577,14 +581,15 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), RUNNING_STATE) self.assertEqual(node.getState(), RUNNING_STATE)
node.setState(BROKEN_STATE) node.setState(BROKEN_STATE)
self.assertEqual(node.getState(), BROKEN_STATE) self.assertEqual(node.getState(), BROKEN_STATE)
election.handleRequestNodeIdentification(conn, self.checkBrokenNotDisallowedErrorRaised(
election.handleRequestNodeIdentification,
conn,
packet=packet, packet=packet,
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
uuid=new_uuid, uuid=new_uuid,
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.master_port+1, port=self.master_port+1,
name=self.app.name,) name=self.app.name,)
self.checkCalledAbort(conn)
def test_11_handleAskPrimaryMaster(self): def test_11_handleAskPrimaryMaster(self):
......
...@@ -22,7 +22,7 @@ from tempfile import mkstemp ...@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo import protocol from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.recovery import RecoveryEventHandler from neo.master.recovery import RecoveryEventHandler
from neo.master.app import Application from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
...@@ -120,15 +120,6 @@ server: 127.0.0.1:10023 ...@@ -120,15 +120,6 @@ server: 127.0.0.1:10023
# Delete tmp file # Delete tmp file
os.remove(self.tmp_path) os.remove(self.tmp_path)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Common methods # Common methods
def getNewUUID(self): def getNewUUID(self):
uuid = INVALID_UUID uuid = INVALID_UUID
...@@ -155,12 +146,25 @@ server: 127.0.0.1:10023 ...@@ -155,12 +146,25 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
call = conn.mockGetNamedCalls("answer")[packet_number]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0): def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send""" """ Check Request Node Identification has been send"""
...@@ -443,7 +447,9 @@ server: 127.0.0.1:10023 ...@@ -443,7 +447,9 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), BROKEN_STATE) self.assertEqual(node.getState(), BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
recovery.handleRequestNodeIdentification(conn, self.checkBrokenNotDisallowedErrorRaised(
recovery.handleRequestNodeIdentification,
conn,
packet=packet, packet=packet,
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
uuid=uuid, uuid=uuid,
...@@ -451,8 +457,6 @@ server: 127.0.0.1:10023 ...@@ -451,8 +457,6 @@ server: 127.0.0.1:10023
port=self.master_port, port=self.master_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAbort(conn)
# 8. known node but down # 8. known node but down
conn = Mock({"addPacket" : None, conn = Mock({"addPacket" : None,
"abort" : None, "abort" : None,
......
...@@ -22,7 +22,7 @@ from tempfile import mkstemp ...@@ -22,7 +22,7 @@ from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo import protocol from neo import protocol
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.service import ServiceEventHandler from neo.master.service import ServiceEventHandler
from neo.master.app import Application from neo.master.app import Application
from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \ from neo.protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
...@@ -118,12 +118,16 @@ server: 127.0.0.1:10023 ...@@ -118,12 +118,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledAbort(self, conn, packet_number=0): def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent""" """Check the abort method has been called and an error packet has been sent"""
...@@ -379,14 +383,15 @@ server: 127.0.0.1:10023 ...@@ -379,14 +383,15 @@ server: 127.0.0.1:10023
sn.setState(BROKEN_STATE) sn.setState(BROKEN_STATE)
self.assertEquals(sn.getState(), BROKEN_STATE) self.assertEquals(sn.getState(), BROKEN_STATE)
service.handleRequestNodeIdentification(conn, self.checkBrokenNotDisallowedErrorRaised(
service.handleRequestNodeIdentification,
conn,
packet=packet, packet=packet,
node_type=STORAGE_NODE_TYPE, node_type=STORAGE_NODE_TYPE,
uuid=uuid, uuid=uuid,
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.storage_port, port=self.storage_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledNotifyAbort(conn)
self.assertEquals(len(self.app.nm.getStorageNodeList()), 2) self.assertEquals(len(self.app.nm.getStorageNodeList()), 2)
sn = self.app.nm.getStorageNodeList()[0] sn = self.app.nm.getStorageNodeList()[0]
self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port)) self.assertEquals(sn.getServer(), ('127.0.0.1', self.storage_port))
......
...@@ -21,7 +21,7 @@ import logging ...@@ -21,7 +21,7 @@ import logging
from tempfile import mkstemp from tempfile import mkstemp
from mock import Mock from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.protocol import Packet, UnexpectedPacketError, INVALID_UUID from neo.protocol import Packet, INVALID_UUID
from neo.master.verification import VerificationEventHandler from neo.master.verification import VerificationEventHandler
from neo.master.app import Application from neo.master.app import Application
from neo import protocol from neo import protocol
...@@ -125,12 +125,16 @@ server: 127.0.0.1:10023 ...@@ -125,12 +125,16 @@ server: 127.0.0.1:10023
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAcceptNodeIdentification(self, conn, packet_number=0): def checkCalledAcceptNodeIdentification(self, conn, packet_number=0):
""" Check Accept Node Identification has been send""" """ Check Accept Node Identification has been send"""
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1) self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
...@@ -465,7 +469,9 @@ server: 127.0.0.1:10023 ...@@ -465,7 +469,9 @@ server: 127.0.0.1:10023
self.assertEqual(node.getState(), BROKEN_STATE) self.assertEqual(node.getState(), BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(len(self.app.nm.getMasterNodeList()), 2) self.assertEqual(len(self.app.nm.getMasterNodeList()), 2)
verification.handleRequestNodeIdentification(conn, self.checkBrokenNotDisallowedErrorRaised(
verification.handleRequestNodeIdentification,
conn,
packet=packet, packet=packet,
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
uuid=uuid, uuid=uuid,
...@@ -473,8 +479,6 @@ server: 127.0.0.1:10023 ...@@ -473,8 +479,6 @@ server: 127.0.0.1:10023
port=self.master_port, port=self.master_port,
name=self.app.name,) name=self.app.name,)
self.checkCalledAbort(conn)
# 8. known node but down # 8. known node but down
conn = Mock({"addPacket" : None, conn = Mock({"addPacket" : None,
"abort" : None, "abort" : None,
......
...@@ -189,11 +189,7 @@ class VerificationEventHandler(MasterEventHandler): ...@@ -189,11 +189,7 @@ class VerificationEventHandler(MasterEventHandler):
# If this node is broken, reject it. Otherwise, assume that it is # If this node is broken, reject it. Otherwise, assume that it is
# working again. # working again.
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away') raise protocol.BrokenNotDisallowedError
conn.answer(p, packet)
conn.abort()
return
else:
node.setUUID(uuid) node.setUUID(uuid)
node.setState(RUNNING_STATE) node.setState(RUNNING_STATE)
app.broadcastNodeInformation(node) app.broadcastNodeInformation(node)
......
...@@ -318,9 +318,26 @@ UUID_NAMESPACES = { ...@@ -318,9 +318,26 @@ UUID_NAMESPACES = {
ADMIN_NODE_TYPE: ADMIN_NS, ADMIN_NODE_TYPE: ADMIN_NS,
} }
class ProtocolError(Exception): pass class ProtocolError(Exception):
class PacketMalformedError(ProtocolError): pass """ Base class for protocol errors, close the connection """
class UnexpectedPacketError(ProtocolError): pass pass
class PacketMalformedError(ProtocolError):
""" Close the connection and set the node as broken"""
pass
class UnexpectedPacketError(ProtocolError):
""" Close the connection and set the node as broken"""
pass
class NotReadyError(ProtocolError):
""" Just close the connection """
pass
class BrokenNotDisallowedError(ProtocolError):
""" Just close the connection """
pass
decode_table = {} decode_table = {}
......
...@@ -131,10 +131,7 @@ class BootstrapEventHandler(StorageEventHandler): ...@@ -131,10 +131,7 @@ class BootstrapEventHandler(StorageEventHandler):
# If this node is broken, reject it. # If this node is broken, reject it.
if node.getUUID() == uuid: if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away') raise protocol.BrokenNotDisallowedError
conn.answer(p, packet)
conn.abort()
return
# Trust the UUID sent by the peer. # Trust the UUID sent by the peer.
node.setUUID(uuid) node.setUUID(uuid)
......
...@@ -163,10 +163,7 @@ class OperationEventHandler(StorageEventHandler): ...@@ -163,10 +163,7 @@ class OperationEventHandler(StorageEventHandler):
# If this node is broken, reject it. # If this node is broken, reject it.
if node.getUUID() == uuid: if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away') raise protocol.BrokenNotDisallowedError
conn.answer(p, packet)
conn.abort()
return
# Trust the UUID sent by the peer. # Trust the UUID sent by the peer.
node.setUUID(uuid) node.setUUID(uuid)
......
...@@ -26,6 +26,7 @@ from neo.pt import PartitionTable ...@@ -26,6 +26,7 @@ from neo.pt import PartitionTable
from neo.storage.app import Application, StorageNode from neo.storage.app import Application, StorageNode
from neo.storage.bootstrap import BootstrapEventHandler from neo.storage.bootstrap import BootstrapEventHandler
from neo.storage.verification import VerificationEventHandler from neo.storage.verification import VerificationEventHandler
from neo import protocol
from neo.protocol import STORAGE_NODE_TYPE, MASTER_NODE_TYPE from neo.protocol import STORAGE_NODE_TYPE, MASTER_NODE_TYPE
from neo.protocol import BROKEN_STATE, RUNNING_STATE, Packet, INVALID_UUID from neo.protocol import BROKEN_STATE, RUNNING_STATE, Packet, INVALID_UUID
from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION
...@@ -111,12 +112,16 @@ server: 127.0.0.1:10020 ...@@ -111,12 +112,16 @@ server: 127.0.0.1:10020
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
# Method to test the kind of packet returned in answer # Method to test the kind of packet returned in answer
def checkCalledRequestNodeIdentification(self, conn, packet_number=0): def checkCalledRequestNodeIdentification(self, conn, packet_number=0):
""" Check Request Node Identification has been send""" """ Check Request Node Identification has been send"""
...@@ -284,7 +289,8 @@ server: 127.0.0.1:10020 ...@@ -284,7 +289,8 @@ server: 127.0.0.1:10020
conn = Mock({"isServerConnection": False, conn = Mock({"isServerConnection": False,
"getAddress" : ("127.0.0.1", self.master_port), }) "getAddress" : ("127.0.0.1", self.master_port), })
self.app.trying_master_node = self.trying_master_node self.app.trying_master_node = self.trying_master_node
self.bootstrap.handleRequestNodeIdentification( self.checkUnexpectedPacketRaised(
self.bootstrap.handleRequestNodeIdentification,
conn=conn, conn=conn,
uuid=self.getNewUUID(), uuid=self.getNewUUID(),
packet=packet, packet=packet,
...@@ -292,7 +298,6 @@ server: 127.0.0.1:10020 ...@@ -292,7 +298,6 @@ server: 127.0.0.1:10020
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
ip_address='127.0.0.1', ip_address='127.0.0.1',
name='',) name='',)
self.checkCalledAbort(conn)
self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0) self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0)
def test_08_handleRequestNodeIdentification2(self): def test_08_handleRequestNodeIdentification2(self):
...@@ -366,7 +371,8 @@ server: 127.0.0.1:10020 ...@@ -366,7 +371,8 @@ server: 127.0.0.1:10020
uuid=self.getNewUUID() uuid=self.getNewUUID()
master.setState(BROKEN_STATE) master.setState(BROKEN_STATE)
master.setUUID(uuid) master.setUUID(uuid)
self.bootstrap.handleRequestNodeIdentification( self.checkBrokenNotDisallowedErrorRaised(
self.bootstrap.handleRequestNodeIdentification,
conn=conn, conn=conn,
uuid=uuid, uuid=uuid,
packet=packet, packet=packet,
...@@ -374,7 +380,6 @@ server: 127.0.0.1:10020 ...@@ -374,7 +380,6 @@ server: 127.0.0.1:10020
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
ip_address='127.0.0.1', ip_address='127.0.0.1',
name=self.app.name,) name=self.app.name,)
self.checkCalledAbort(conn)
self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0) self.assertEquals(len(conn.mockGetNamedCalls("setUUID")), 0)
def test_08_handleRequestNodeIdentification6(self): def test_08_handleRequestNodeIdentification6(self):
...@@ -415,7 +420,8 @@ server: 127.0.0.1:10020 ...@@ -415,7 +420,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.master_port), }) "getAddress" : ("127.0.0.1", self.master_port), })
packet = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION) packet = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION)
self.app.trying_master_node = self.trying_master_node self.app.trying_master_node = self.trying_master_node
self.bootstrap.handleAcceptNodeIdentification( self.checkUnexpectedPacketRaised(
self.bootstrap.handleAcceptNodeIdentification,
conn=conn, conn=conn,
packet=packet, packet=packet,
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
...@@ -425,7 +431,6 @@ server: 127.0.0.1:10020 ...@@ -425,7 +431,6 @@ server: 127.0.0.1:10020
num_partitions=self.app.num_partitions, num_partitions=self.app.num_partitions,
num_replicas=self.app.num_replicas, num_replicas=self.app.num_replicas,
your_uuid=self.getNewUUID()) your_uuid=self.getNewUUID())
self.checkCalledAbort(conn)
def test_09_handleAcceptNodeIdentification2(self): def test_09_handleAcceptNodeIdentification2(self):
# not a master node -> rejected # not a master node -> rejected
...@@ -560,13 +565,13 @@ server: 127.0.0.1:10020 ...@@ -560,13 +565,13 @@ server: 127.0.0.1:10020
packet = Packet(msg_type=ANSWER_PRIMARY_MASTER) packet = Packet(msg_type=ANSWER_PRIMARY_MASTER)
self.app.trying_master_node = self.trying_master_node self.app.trying_master_node = self.trying_master_node
self.app.primary_master_node = None self.app.primary_master_node = None
self.bootstrap.handleAnswerPrimaryMaster( self.checkUnexpectedPacketRaised(
self.bootstrap.handleAnswerPrimaryMaster,
conn=conn, conn=conn,
packet=packet, packet=packet,
primary_uuid=self.getNewUUID(), primary_uuid=self.getNewUUID(),
known_master_list=() known_master_list=()
) )
self.checkCalledAbort(conn)
self.assertEquals(self.app.trying_master_node, self.trying_master_node) self.assertEquals(self.app.trying_master_node, self.trying_master_node)
self.assertEquals(self.app.primary_master_node, None) self.assertEquals(self.app.primary_master_node, None)
......
...@@ -28,6 +28,7 @@ from neo.storage.app import Application, StorageNode ...@@ -28,6 +28,7 @@ from neo.storage.app import Application, StorageNode
from neo.storage.operation import TransactionInformation, OperationEventHandler from neo.storage.operation import TransactionInformation, OperationEventHandler
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo import protocol
from neo.protocol import * from neo.protocol import *
SQL_ADMIN_USER = 'root' SQL_ADMIN_USER = 'root'
...@@ -52,12 +53,16 @@ class StorageOperationTests(unittest.TestCase): ...@@ -52,12 +53,16 @@ class StorageOperationTests(unittest.TestCase):
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0): def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent""" """Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just send it # sometimes we answer an error, sometimes we just send it
...@@ -373,7 +378,8 @@ server: 127.0.0.1:10020 ...@@ -373,7 +378,8 @@ server: 127.0.0.1:10020
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
count = len(self.app.nm.getNodeList()) count = len(self.app.nm.getNodeList())
self.operation.handleRequestNodeIdentification( self.checkBrokenNotDisallowedErrorRaised(
self.operation.handleRequestNodeIdentification,
conn=conn, conn=conn,
packet=packet, packet=packet,
node_type=MASTER_NODE_TYPE, node_type=MASTER_NODE_TYPE,
...@@ -381,7 +387,6 @@ server: 127.0.0.1:10020 ...@@ -381,7 +387,6 @@ server: 127.0.0.1:10020
ip_address='127.0.0.1', ip_address='127.0.0.1',
port=self.master_port, port=self.master_port,
name=self.app.name) name=self.app.name)
self.checkPacket(conn, packet_type=ERROR)
self.assertEquals(len(self.app.nm.getNodeList()), count) self.assertEquals(len(self.app.nm.getNodeList()), count)
def test_09_handleRequestNodeIdentification4(self): def test_09_handleRequestNodeIdentification4(self):
......
...@@ -36,7 +36,7 @@ from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION ...@@ -36,7 +36,7 @@ from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION
UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \ UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \
ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION
from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER
from neo.protocol import ANSWER_PRIMARY_MASTER, UnexpectedPacketError from neo.protocol import ANSWER_PRIMARY_MASTER
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64 from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64
...@@ -129,12 +129,16 @@ server: 127.0.0.1:10020 ...@@ -129,12 +129,16 @@ server: 127.0.0.1:10020
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception wxas raised """
self.assertRaises(UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """ """ Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNotDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNotDisallowedError exception wxas raised """
self.assertRaises(protocol.BrokenNotDisallowedError, method, *args, **kwargs)
def checkCalledAbort(self, conn, packet_number=0): def checkCalledAbort(self, conn, packet_number=0):
"""Check the abort method has been called and an error packet has been sent""" """Check the abort method has been called and an error packet has been sent"""
# sometimes we answer an error, sometimes we just notify it # sometimes we answer an error, sometimes we just notify it
...@@ -277,14 +281,10 @@ server: 127.0.0.1:10020 ...@@ -277,14 +281,10 @@ server: 127.0.0.1:10020
node = self.app.nm.getNodeByServer(conn.getAddress()) node = self.app.nm.getNodeByServer(conn.getAddress())
node.setState(BROKEN_STATE) node.setState(BROKEN_STATE)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE, self.checkBrokenNotDisallowedErrorRaised(
self.verification.handleRequestNodeIdentification,
conn, p, MASTER_NODE_TYPE,
uuid, "127.0.0.1", self.master_port, "main") uuid, "127.0.0.1", self.master_port, "main")
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
call = conn.mockGetNamedCalls("answer")[0]
packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(), ERROR)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)
# change uuid of a known node # change uuid of a known node
uuid = self.getNewUUID() uuid = self.getNewUUID()
......
...@@ -88,10 +88,7 @@ class VerificationEventHandler(StorageEventHandler): ...@@ -88,10 +88,7 @@ class VerificationEventHandler(StorageEventHandler):
# If this node is broken, reject it. # If this node is broken, reject it.
if node.getUUID() == uuid: if node.getUUID() == uuid:
if node.getState() == BROKEN_STATE: if node.getState() == BROKEN_STATE:
p = protocol.brokenNodeDisallowedError('go away') raise protocol.BrokenNotDisallowedError
conn.answer(p, packet)
conn.abort()
return
# Trust the UUID sent by the peer. # Trust the UUID sent by the peer.
node.setUUID(uuid) node.setUUID(uuid)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment