Commit 4bedd3fc authored by Kirill Smelkov's avatar Kirill Smelkov

.

parents 0f30552f 8eb14b01
[run]
source = neo
omit =
neo/debug.py
neo/scripts/runner.py
neo/tests/*
...@@ -34,7 +34,7 @@ ZODB API is fully implemented except: ...@@ -34,7 +34,7 @@ ZODB API is fully implemented except:
for garbage collection) for garbage collection)
- blobs: not implemented (not considered yet) - blobs: not implemented (not considered yet)
Any ZODB like FileStorage can be converted to NEO instanteously, Any ZODB like FileStorage can be converted to NEO instantaneously,
which means the database is operational before all data are imported. which means the database is operational before all data are imported.
There's also a tool to convert back to FileStorage. There's also a tool to convert back to FileStorage.
......
...@@ -105,13 +105,9 @@ class Application(BaseApplication): ...@@ -105,13 +105,9 @@ class Application(BaseApplication):
""" """
self.cluster_state = None self.cluster_state = None
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN, bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server)
self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node
self.master_conn = conn
self.uuid = uuid
if self.pt is None: if self.pt is None:
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
...@@ -125,7 +121,6 @@ class Application(BaseApplication): ...@@ -125,7 +121,6 @@ class Application(BaseApplication):
# passive handler # passive handler
self.master_conn.setHandler(self.master_event_handler) self.master_conn.setHandler(self.master_event_handler)
self.master_conn.ask(Packets.AskClusterState()) self.master_conn.ask(Packets.AskClusterState())
self.master_conn.ask(Packets.AskNodeInformation())
self.master_conn.ask(Packets.AskPartitionTable()) self.master_conn.ask(Packets.AskPartitionTable())
def sendPartitionTable(self, conn, min_offset, max_offset, uuid): def sendPartitionTable(self, conn, min_offset, max_offset, uuid):
......
...@@ -74,7 +74,7 @@ class AdminEventHandler(EventHandler): ...@@ -74,7 +74,7 @@ class AdminEventHandler(EventHandler):
class MasterEventHandler(EventHandler): class MasterEventHandler(EventHandler):
""" This class is just used to dispacth message to right handler""" """ This class is just used to dispatch message to right handler"""
def _connectionLost(self, conn): def _connectionLost(self, conn):
app = self.app app = self.app
...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler): ...@@ -106,11 +106,6 @@ class MasterEventHandler(EventHandler):
def answerClusterState(self, conn, state): def answerClusterState(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def answerNodeInformation(self, conn):
# XXX: This will no more exists when the initialization module will be
# implemented for factorize code (as done for bootstrap)
logging.debug("answerNodeInformation")
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
...@@ -125,8 +120,6 @@ class MasterEventHandler(EventHandler): ...@@ -125,8 +120,6 @@ class MasterEventHandler(EventHandler):
def notifyClusterInformation(self, conn, cluster_state): def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state self.app.cluster_state = cluster_state
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
class MasterRequestEventHandler(EventHandler): class MasterRequestEventHandler(EventHandler):
""" This class handle all answer from primary master node""" """ This class handle all answer from primary master node"""
......
...@@ -108,7 +108,7 @@ class Storage(BaseStorage.BaseStorage, ...@@ -108,7 +108,7 @@ class Storage(BaseStorage.BaseStorage,
def deleteObject(self, oid, serial, transaction): def deleteObject(self, oid, serial, transaction):
self.app.store(oid, serial, None, None, transaction) self.app.store(oid, serial, None, None, transaction)
# mutliple revisions # multiple revisions
def loadSerial(self, oid, serial): def loadSerial(self, oid, serial):
try: try:
return self.app.load(oid, serial)[0] return self.app.load(oid, serial)[0]
......
...@@ -87,4 +87,4 @@ def patch(): ...@@ -87,4 +87,4 @@ def patch():
patch() patch()
import app # set up signal handers early enough to do it in the main thread import app # set up signal handlers early enough to do it in the main thread
...@@ -132,7 +132,7 @@ class Application(ThreadedApplication): ...@@ -132,7 +132,7 @@ class Application(ThreadedApplication):
self._cache_lock_acquire = lock.acquire self._cache_lock_acquire = lock.acquire
self._cache_lock_release = lock.release self._cache_lock_release = lock.release
# _connecting_to_master_node is used to prevent simultaneous master # _connecting_to_master_node is used to prevent simultaneous master
# node connection attemps # node connection attempts
self._connecting_to_master_node = Lock() self._connecting_to_master_node = Lock()
self.compress = compress self.compress = compress
...@@ -240,10 +240,10 @@ class Application(ThreadedApplication): ...@@ -240,10 +240,10 @@ class Application(ThreadedApplication):
self.notifications_handler, self.notifications_handler,
node=self.trying_master_node, node=self.trying_master_node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification(
NodeTypes.CLIENT, self.uuid, None, self.name, None)
try: try:
ask(conn, Packets.RequestIdentification( ask(conn, p, handler=handler)
NodeTypes.CLIENT, self.uuid, None, self.name),
handler=handler)
except ConnectionClosed: except ConnectionClosed:
continue continue
# If we reached the primary master node, mark as connected # If we reached the primary master node, mark as connected
...@@ -256,7 +256,6 @@ class Application(ThreadedApplication): ...@@ -256,7 +256,6 @@ class Application(ThreadedApplication):
# operational. Might raise ConnectionClosed so that the new # operational. Might raise ConnectionClosed so that the new
# primary can be looked-up again. # primary can be looked-up again.
logging.info('Initializing from master') logging.info('Initializing from master')
ask(conn, Packets.AskNodeInformation(), handler=handler)
ask(conn, Packets.AskPartitionTable(), handler=handler) ask(conn, Packets.AskPartitionTable(), handler=handler)
ask(conn, Packets.AskLastTransaction(), handler=handler) ask(conn, Packets.AskLastTransaction(), handler=handler)
if self.pt.operational(): if self.pt.operational():
...@@ -324,7 +323,7 @@ class Application(ThreadedApplication): ...@@ -324,7 +323,7 @@ class Application(ThreadedApplication):
object existed, but its creation was undone object existed, but its creation was undone
Note that loadSerial is used during conflict resolution to load Note that loadSerial is used during conflict resolution to load
object's current version, which is not visible to us normaly (it was object's current version, which is not visible to us normally (it was
committed after our snapshot was taken). committed after our snapshot was taken).
""" """
# TODO: # TODO:
...@@ -987,7 +986,7 @@ class Application(ThreadedApplication): ...@@ -987,7 +986,7 @@ class Application(ThreadedApplication):
queue = txn_context['queue'] queue = txn_context['queue']
txn_context['object_stored_counter_dict'][oid] = {} txn_context['object_stored_counter_dict'][oid] = {}
# ZODB.Connection performs calls 'checkCurrentSerialInTransaction' # ZODB.Connection performs calls 'checkCurrentSerialInTransaction'
# after stores, and skips oids that have been succeessfully stored. # after stores, and skips oids that have been successfully stored.
assert oid not in txn_context['cache_dict'], (oid, txn_context) assert oid not in txn_context['cache_dict'], (oid, txn_context)
txn_context['data_dict'].setdefault(oid, CHECKED_SERIAL) txn_context['data_dict'].setdefault(oid, CHECKED_SERIAL)
checked_nodes = txn_context['checked_nodes'] checked_nodes = txn_context['checked_nodes']
......
...@@ -203,7 +203,7 @@ class ClientCache(object): ...@@ -203,7 +203,7 @@ class ClientCache(object):
item = self._load(oid, next_tid) item = self._load(oid, next_tid)
if item: if item:
# We don't handle late invalidations for cached oids, because # We don't handle late invalidations for cached oids, because
# the caller is not supposed to explicitely asks for tids after # the caller is not supposed to explicitly asks for tids after
# app.last_tid (and the cache should be empty when app.last_tid # app.last_tid (and the cache should be empty when app.last_tid
# is still None). # is still None).
assert item.tid == tid, (item, tid) assert item.tid == tid, (item, tid)
......
...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -30,6 +30,16 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
self.app.trying_master_node = None self.app.trying_master_node = None
conn.close() conn.close()
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerLastTransaction(*args):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
app = self.app app = self.app
...@@ -77,27 +87,13 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -77,27 +87,13 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
raise ProtocolError('No UUID supplied') raise ProtocolError('No UUID supplied')
app.uuid = your_uuid app.uuid = your_uuid
logging.info('Got an UUID: %s', dump(app.uuid)) logging.info('Got an UUID: %s', dump(app.uuid))
app.id_timestamp = None
# Always create partition table # Always create partition table
app.pt = PartitionTable(num_partitions, num_replicas) app.pt = PartitionTable(num_partitions, num_replicas)
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
pass
class PrimaryNotificationsHandler(MTEventHandler):
""" Handler that process the notifications from the primary master """
def packetReceived(self, conn, packet, kw={}):
if type(packet) is Packets.AnswerLastTransaction:
app = self.app app = self.app
ltid = packet.decode()[0]
if app.last_tid != ltid: if app.last_tid != ltid:
# Either we're connecting or we already know the last tid # Either we're connecting or we already know the last tid
# via invalidations. # via invalidations.
...@@ -124,15 +120,15 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -124,15 +120,15 @@ class PrimaryNotificationsHandler(MTEventHandler):
db = app.getDB() db = app.getDB()
db is None or db.invalidateCache() db is None or db.invalidateCache()
app.last_tid = ltid app.last_tid = ltid
elif type(packet) is Packets.AnswerTransactionFinished:
def answerTransactionFinished(self, conn, _, tid, callback, cache_dict):
app = self.app app = self.app
app.last_tid = tid = packet.decode()[1] app.last_tid = tid
callback = kw.pop('callback')
# Update cache # Update cache
cache = app._cache cache = app._cache
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
for oid, data in kw.pop('cache_dict').iteritems(): for oid, data in cache_dict.iteritems():
# Update ex-latest value in cache # Update ex-latest value in cache
cache.invalidate(oid, tid) cache.invalidate(oid, tid)
if data is not None: if data is not None:
...@@ -142,7 +138,6 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -142,7 +138,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
callback(tid) callback(tid)
finally: finally:
app._cache_lock_release() app._cache_lock_release()
MTEventHandler.packetReceived(self, conn, packet, kw)
def connectionClosed(self, conn): def connectionClosed(self, conn):
app = self.app app = self.app
...@@ -185,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -185,13 +180,14 @@ class PrimaryNotificationsHandler(MTEventHandler):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
nm = self.app.nm super(PrimaryNotificationsHandler, self).notifyNodeInformation(
nm.update(node_list) conn, node_list)
# XXX: 'update' automatically closes DOWN nodes. Do we really want # XXX: 'update' automatically closes DOWN nodes. Do we really want
# to do the same thing for nodes in other non-running states ? # to do the same thing for nodes in other non-running states ?
for node_type, addr, uuid, state in node_list: getByUUID = self.app.nm.getByUUID
if state != NodeStates.RUNNING: for node in node_list:
node = nm.getByUUID(uuid) if node[3] != NodeStates.RUNNING:
node = getByUUID(node[2])
if node and node.isConnected(): if node and node.isConnected():
node.getConnection().close() node.getConnection().close()
......
...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler): ...@@ -41,14 +41,6 @@ class StorageEventHandler(MTEventHandler):
self.app.cp.removeConnection(node) self.app.cp.removeConnection(node)
super(StorageEventHandler, self).connectionFailed(conn) super(StorageEventHandler, self).connectionFailed(conn)
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
def _acceptIdentification(self, node, def _acceptIdentification(self, node,
uuid, num_partitions, num_replicas, your_uuid, primary, uuid, num_partitions, num_replicas, your_uuid, primary,
master_list): master_list):
...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler): ...@@ -57,6 +49,13 @@ class StorageBootstrapHandler(AnswerBaseHandler):
primary, self.app.master_conn) primary, self.app.master_conn)
assert uuid == node.getUUID(), (uuid, node.getUUID()) assert uuid == node.getUUID(), (uuid, node.getUUID())
class StorageBootstrapHandler(AnswerBaseHandler):
""" Handler used when connecting to a storage node """
def notReady(self, conn, message):
conn.close()
raise NodeNotReady(message)
class StorageAnswersHandler(AnswerBaseHandler): class StorageAnswersHandler(AnswerBaseHandler):
""" Handle all messages related to ZODB operations """ """ Handle all messages related to ZODB operations """
...@@ -170,7 +169,7 @@ class StorageAnswersHandler(AnswerBaseHandler): ...@@ -170,7 +169,7 @@ class StorageAnswersHandler(AnswerBaseHandler):
raise ConflictError, 'Lock wait timeout for oid %s on %r' % ( raise ConflictError, 'Lock wait timeout for oid %s on %r' % (
dump(oid), conn) dump(oid), conn)
# HasLock design required that storage is multi-threaded so that # HasLock design required that storage is multi-threaded so that
# it can answer to AskHasLock while processing store resquests. # it can answer to AskHasLock while processing store requests.
# This means that the 2 cases (granted to us or nobody) are legitimate, # This means that the 2 cases (granted to us or nobody) are legitimate,
# either because it gave us the lock but is/was slow to store our data, # either because it gave us the lock but is/was slow to store our data,
# or because the storage took a lot of time processing a previous # or because the storage took a lot of time processing a previous
......
...@@ -57,7 +57,7 @@ class ConnectionPool(object): ...@@ -57,7 +57,7 @@ class ConnectionPool(object):
conn = MTClientConnection(app, app.storage_event_handler, node, conn = MTClientConnection(app, app.storage_event_handler, node,
dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name, app.id_timestamp)
try: try:
app._ask(conn, p, handler=app.storage_bootstrap_handler) app._ask(conn, p, handler=app.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
......
...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler): ...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler):
""" """
accepted = False accepted = False
def __init__(self, app, name, node_type, uuid=None, server=None): def __init__(self, app, node_type, server=None):
""" """
Manage the bootstrap stage of a non-master node, it lookup for the Manage the bootstrap stage of a non-master node, it lookup for the
primary master node, connect to it then returns when the master node primary master node, connect to it then returns when the master node
...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler): ...@@ -35,12 +35,12 @@ class BootstrapManager(EventHandler):
self.primary = None self.primary = None
self.server = server self.server = server
self.node_type = node_type self.node_type = node_type
self.uuid = uuid
self.name = name
self.num_replicas = None self.num_replicas = None
self.num_partitions = None self.num_partitions = None
self.current = None self.current = None
uuid = property(lambda self: self.app.uuid)
def announcePrimary(self, conn): def announcePrimary(self, conn):
# We found the primary master early enough to be notified of election # We found the primary master early enough to be notified of election
# end. Lucky. Anyway, we must carry on with identification request, so # end. Lucky. Anyway, we must carry on with identification request, so
...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler): ...@@ -55,7 +55,7 @@ class BootstrapManager(EventHandler):
EventHandler.connectionCompleted(self, conn) EventHandler.connectionCompleted(self, conn)
self.current.setRunning() self.current.setRunning()
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid, conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.name)) self.server, self.app.name, None))
def connectionFailed(self, conn): def connectionFailed(self, conn):
""" """
...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler): ...@@ -106,8 +106,9 @@ class BootstrapManager(EventHandler):
self.num_replicas = num_replicas self.num_replicas = num_replicas
if self.uuid != your_uuid: if self.uuid != your_uuid:
# got an uuid from the primary master # got an uuid from the primary master
self.uuid = your_uuid self.app.uuid = your_uuid
logging.info('Got a new UUID: %s', uuid_str(self.uuid)) logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.app.id_timestamp = None
self.accepted = True self.accepted = True
def getPrimaryConnection(self): def getPrimaryConnection(self):
...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler): ...@@ -141,8 +142,4 @@ class BootstrapManager(EventHandler):
continue continue
# still processing # still processing
poll(1) poll(1)
return (self.current, conn, self.uuid, self.num_partitions, return self.current, conn, self.num_partitions, self.num_replicas
self.num_replicas)
...@@ -72,7 +72,7 @@ class HandlerSwitcher(object): ...@@ -72,7 +72,7 @@ class HandlerSwitcher(object):
_pending = self._pending _pending = self._pending
if self._is_handling: if self._is_handling:
# If this is called while handling a packet, the response is to # If this is called while handling a packet, the response is to
# be excpected for the current handler... # be expected for the current handler...
(request_dict, _) = _pending[0] (request_dict, _) = _pending[0]
else: else:
# ...otherwise, queue for the latest handler # ...otherwise, queue for the latest handler
...@@ -100,7 +100,7 @@ class HandlerSwitcher(object): ...@@ -100,7 +100,7 @@ class HandlerSwitcher(object):
# on_timeout sent a packet with a smaller timeout # on_timeout sent a packet with a smaller timeout
# so keep the connection open # so keep the connection open
return return
# Notify that a timeout occured # Notify that a timeout occurred
return msg_id return msg_id
def handle(self, connection, packet): def handle(self, connection, packet):
......
...@@ -124,8 +124,8 @@ class SocketConnector(object): ...@@ -124,8 +124,8 @@ class SocketConnector(object):
def getDescriptor(self): def getDescriptor(self):
# this descriptor must only be used by the event manager, where it # this descriptor must only be used by the event manager, where it
# guarantee unicity only while the connector is opened and registered # guarantee uniqueness only while the connector is opened and
# in epoll # registered in epoll
return self.socket_fd return self.socket_fd
@staticmethod @staticmethod
......
...@@ -165,6 +165,10 @@ class EventHandler(object): ...@@ -165,6 +165,10 @@ class EventHandler(object):
return return
conn.close() conn.close()
def notifyNodeInformation(self, conn, node_list):
app = self.app
app.nm.update(app, node_list)
def ping(self, conn): def ping(self, conn):
conn.answer(Packets.Pong()) conn.answer(Packets.Pong())
...@@ -227,6 +231,9 @@ class MTEventHandler(EventHandler): ...@@ -227,6 +231,9 @@ class MTEventHandler(EventHandler):
def packetReceived(self, conn, packet, kw={}): def packetReceived(self, conn, packet, kw={}):
"""Redirect all received packet to dispatcher thread.""" """Redirect all received packet to dispatcher thread."""
if packet.isResponse(): if packet.isResponse():
if packet.poll_thread:
self.dispatch(conn, packet, kw)
kw = {}
if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw) if not (self.dispatcher.dispatch(conn, packet.getId(), packet, kw)
or type(packet) is Packets.Pong): or type(packet) is Packets.Pong):
raise ProtocolError('Unexpected response packet from %r: %r' raise ProtocolError('Unexpected response packet from %r: %r'
...@@ -254,3 +261,6 @@ class AnswerBaseHandler(EventHandler): ...@@ -254,3 +261,6 @@ class AnswerBaseHandler(EventHandler):
packetReceived = unexpectedInAnswerHandler packetReceived = unexpectedInAnswerHandler
peerBroken = unexpectedInAnswerHandler peerBroken = unexpectedInAnswerHandler
protocolError = unexpectedInAnswerHandler protocolError = unexpectedInAnswerHandler
def acceptIdentification(*args):
pass
...@@ -12,7 +12,7 @@ from Queue import Empty ...@@ -12,7 +12,7 @@ from Queue import Empty
Python threading module contains a simple logging mechanism, but: Python threading module contains a simple logging mechanism, but:
- It's limitted to RLock class - It's limitted to RLock class
- It's enabled instance by instance - It's enabled instance by instance
- Choice to log or not is done at instanciation - Choice to log or not is done at instantiation
- It does not emit any log before trying to acquire lock - It does not emit any log before trying to acquire lock
This file defines a VerboseLock class implementing basic lock API and This file defines a VerboseLock class implementing basic lock API and
...@@ -29,7 +29,7 @@ class LockUser(object): ...@@ -29,7 +29,7 @@ class LockUser(object):
def __init__(self, message, level=0): def __init__(self, message, level=0):
t = threading.currentThread() t = threading.currentThread()
ident = getattr(t, 'node_name', t.name) ident = getattr(t, 'node_name', t.name)
# This class is instanciated from a place desiring to known what # This class is instantiated from a place desiring to known what
# called it. # called it.
# limit=1 would return execution position in this method # limit=1 would return execution position in this method
# limit=2 would return execution position in caller # limit=2 would return execution position in caller
......
...@@ -26,6 +26,8 @@ class Node(object): ...@@ -26,6 +26,8 @@ class Node(object):
"""This class represents a node.""" """This class represents a node."""
_connection = None _connection = None
_identified = False
id_timestamp = None
def __init__(self, manager, address=None, uuid=None, def __init__(self, manager, address=None, uuid=None,
state=NodeStates.UNKNOWN): state=NodeStates.UNKNOWN):
...@@ -34,7 +36,6 @@ class Node(object): ...@@ -34,7 +36,6 @@ class Node(object):
self._uuid = uuid self._uuid = uuid
self._manager = manager self._manager = manager
self._last_state_change = time() self._last_state_change = time()
self._identified = False
manager.add(self) manager.add(self)
def notify(self, packet): def notify(self, packet):
...@@ -83,7 +84,6 @@ class Node(object): ...@@ -83,7 +84,6 @@ class Node(object):
old_uuid = self._uuid old_uuid = self._uuid
self._uuid = uuid self._uuid = uuid
self._manager._updateUUID(self, old_uuid) self._manager._updateUUID(self, old_uuid)
self._manager._updateIdentified(self)
if self._connection is not None: if self._connection is not None:
self._connection.setUUID(uuid) self._connection.setUUID(uuid)
...@@ -97,7 +97,6 @@ class Node(object): ...@@ -97,7 +97,6 @@ class Node(object):
assert self._connection is not None assert self._connection is not None
del self._connection del self._connection
self._identified = False self._identified = False
self._manager._updateIdentified(self)
def setConnection(self, connection, force=None): def setConnection(self, connection, force=None):
""" """
...@@ -136,7 +135,6 @@ class Node(object): ...@@ -136,7 +135,6 @@ class Node(object):
conn.close() conn.close()
assert not connection.isClosed(), connection assert not connection.isClosed(), connection
connection.setOnClose(self.onConnectionClosed) connection.setOnClose(self.onConnectionClosed)
self._manager._updateIdentified(self)
def getConnection(self): def getConnection(self):
""" """
...@@ -163,72 +161,20 @@ class Node(object): ...@@ -163,72 +161,20 @@ class Node(object):
return self._identified return self._identified
def __repr__(self): def __repr__(self):
return '<%s(uuid=%s, address=%s, state=%s, connection=%r) at %x>' % ( return '<%s(uuid=%s, address=%s, state=%s, connection=%r%s) at %x>' % (
self.__class__.__name__, self.__class__.__name__,
uuid_str(self._uuid), uuid_str(self._uuid),
self._address, self._address,
self._state, self._state,
self._connection, self._connection,
'' if self._identified else ', not identified',
id(self), id(self),
) )
def isMaster(self):
return False
def isStorage(self):
return False
def isClient(self):
return False
def isAdmin(self):
return False
def isRunning(self):
return self._state == NodeStates.RUNNING
def isUnknown(self):
return self._state == NodeStates.UNKNOWN
def isTemporarilyDown(self):
return self._state == NodeStates.TEMPORARILY_DOWN
def isDown(self):
return self._state == NodeStates.DOWN
def isBroken(self):
return self._state == NodeStates.BROKEN
def isHidden(self):
return self._state == NodeStates.HIDDEN
def isPending(self):
return self._state == NodeStates.PENDING
def setRunning(self):
self.setState(NodeStates.RUNNING)
def setUnknown(self):
self.setState(NodeStates.UNKNOWN)
def setTemporarilyDown(self):
self.setState(NodeStates.TEMPORARILY_DOWN)
def setDown(self):
self.setState(NodeStates.DOWN)
def setBroken(self):
self.setState(NodeStates.BROKEN)
def setHidden(self):
self.setState(NodeStates.HIDDEN)
def setPending(self):
self.setState(NodeStates.PENDING)
def asTuple(self): def asTuple(self):
""" Returned tuple is intented to be used in procotol encoders """ """ Returned tuple is intended to be used in protocol encoders """
return (self.getType(), self._address, self._uuid, self._state) return (self.getType(), self._address, self._uuid, self._state,
self.id_timestamp)
def __gt__(self, node): def __gt__(self, node):
# sort per UUID if defined # sort per UUID if defined
...@@ -236,12 +182,6 @@ class Node(object): ...@@ -236,12 +182,6 @@ class Node(object):
return self._uuid > node._uuid return self._uuid > node._uuid
return self._address > node._address return self._address > node._address
def getType(self):
try:
return NODE_CLASS_MAPPING[self.__class__]
except KeyError:
raise NotImplementedError
def whoSetState(self): def whoSetState(self):
""" """
Debugging method: call this method to know who set the current Debugging method: call this method to know who set the current
...@@ -251,43 +191,6 @@ class Node(object): ...@@ -251,43 +191,6 @@ class Node(object):
attributeTracker.track(Node) attributeTracker.track(Node)
class MasterNode(Node):
"""This class represents a master node."""
def isMaster(self):
return True
class StorageNode(Node):
"""This class represents a storage node."""
def isStorage(self):
return True
class ClientNode(Node):
"""This class represents a client node."""
def isClient(self):
return True
class AdminNode(Node):
"""This class represents an admin node."""
def isAdmin(self):
return True
NODE_TYPE_MAPPING = {
NodeTypes.MASTER: MasterNode,
NodeTypes.STORAGE: StorageNode,
NodeTypes.CLIENT: ClientNode,
NodeTypes.ADMIN: AdminNode,
}
NODE_CLASS_MAPPING = {
StorageNode: NodeTypes.STORAGE,
MasterNode: NodeTypes.MASTER,
ClientNode: NodeTypes.CLIENT,
AdminNode: NodeTypes.ADMIN,
}
class MasterDB(object): class MasterDB(object):
""" """
...@@ -337,7 +240,7 @@ class NodeManager(object): ...@@ -337,7 +240,7 @@ class NodeManager(object):
def __init__(self, master_db=None): def __init__(self, master_db=None):
""" """
master_db (string) master_db (string)
Path to a file containing master nodes's addresses. Used to automate Path to a file containing master nodes' addresses. Used to automate
master list updates. If not provided, no automation will happen. master list updates. If not provided, no automation will happen.
""" """
self._node_set = set() self._node_set = set()
...@@ -345,7 +248,6 @@ class NodeManager(object): ...@@ -345,7 +248,6 @@ class NodeManager(object):
self._uuid_dict = {} self._uuid_dict = {}
self._type_dict = {} self._type_dict = {}
self._state_dict = {} self._state_dict = {}
self._identified_dict = {}
if master_db is not None: if master_db is not None:
self._master_db = db = MasterDB(master_db) self._master_db = db = MasterDB(master_db)
for addr in db: for addr in db:
...@@ -361,9 +263,8 @@ class NodeManager(object): ...@@ -361,9 +263,8 @@ class NodeManager(object):
self._node_set.add(node) self._node_set.add(node)
self._updateAddress(node, None) self._updateAddress(node, None)
self._updateUUID(node, None) self._updateUUID(node, None)
self.__updateSet(self._type_dict, None, node.__class__, node) self.__updateSet(self._type_dict, None, node.getType(), node)
self.__updateSet(self._state_dict, None, node.getState(), node) self.__updateSet(self._state_dict, None, node.getState(), node)
self._updateIdentified(node)
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.add(node.getAddress()) self._master_db.add(node.getAddress())
...@@ -372,25 +273,17 @@ class NodeManager(object): ...@@ -372,25 +273,17 @@ class NodeManager(object):
logging.warning('removing unknown node %r, ignoring', node) logging.warning('removing unknown node %r, ignoring', node)
return return
self._node_set.remove(node) self._node_set.remove(node)
self.__drop(self._address_dict, node.getAddress()) # a node may have not be indexed by uuid or address, eg.:
self.__drop(self._uuid_dict, node.getUUID()) # - a client or admin node that don't have listening address
self._address_dict.pop(node.getAddress(), None)
# - a master known by address but without UUID
self._uuid_dict.pop(node.getUUID(), None)
self.__dropSet(self._state_dict, node.getState(), node) self.__dropSet(self._state_dict, node.getState(), node)
self.__dropSet(self._type_dict, node.__class__, node) self.__dropSet(self._type_dict, node.getType(), node)
uuid = node.getUUID() uuid = node.getUUID()
if uuid in self._identified_dict:
del self._identified_dict[uuid]
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.discard(node.getAddress()) self._master_db.discard(node.getAddress())
def __drop(self, index_dict, key):
try:
del index_dict[key]
except KeyError:
# a node may have not be indexed by uuid or address, eg.:
# - a master known by address but without UUID
# - a client or admin node that don't have listening address
pass
def __update(self, index_dict, old_key, new_key, node): def __update(self, index_dict, old_key, new_key, node):
""" Update an index from old to new key """ """ Update an index from old to new key """
if old_key is not None: if old_key is not None:
...@@ -403,17 +296,6 @@ class NodeManager(object): ...@@ -403,17 +296,6 @@ class NodeManager(object):
'would overwrite %r' % (node, new_key, index_dict[new_key]) 'would overwrite %r' % (node, new_key, index_dict[new_key])
index_dict[new_key] = node index_dict[new_key] = node
def _updateIdentified(self, node):
uuid = node.getUUID()
if uuid:
# XXX: It's probably a bug to include connecting nodes but there's
# no API yet to update manager when connection is established.
if node.isConnected(connecting=True):
assert node in self._node_set, node
self._identified_dict[uuid] = node
else:
self._identified_dict.pop(uuid, None)
def _updateAddress(self, node, old_address): def _updateAddress(self, node, old_address):
self.__update(self._address_dict, old_address, node.getAddress(), node) self.__update(self._address_dict, old_address, node.getAddress(), node)
...@@ -421,15 +303,14 @@ class NodeManager(object): ...@@ -421,15 +303,14 @@ class NodeManager(object):
self.__update(self._uuid_dict, old_uuid, node.getUUID(), node) self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
def __dropSet(self, set_dict, key, node): def __dropSet(self, set_dict, key, node):
if key in set_dict and node in set_dict[key]: if key in set_dict:
set_dict[key].remove(node) set_dict[key].remove(node)
def __updateSet(self, set_dict, old_key, new_key, node): def __updateSet(self, set_dict, old_key, new_key, node):
""" Update a set index from old to new key """ """ Update a set index from old to new key """
if old_key in set_dict: if old_key in set_dict:
set_dict[old_key].remove(node) set_dict[old_key].remove(node)
if new_key is not None: set_dict.setdefault(new_key, set()).add(node)
set_dict.setdefault(new_key, set()).add(node)
def _updateState(self, node, old_state): def _updateState(self, node, old_state):
assert not node.isDown(), node assert not node.isDown(), node
...@@ -445,10 +326,8 @@ class NodeManager(object): ...@@ -445,10 +326,8 @@ class NodeManager(object):
Returns a generator to iterate over identified nodes Returns a generator to iterate over identified nodes
pool_set is an iterable of UUIDs allowed pool_set is an iterable of UUIDs allowed
""" """
if pool_set is not None: return [x for x in self._node_set if x.isIdentified() and (
identified_nodes = self._identified_dict.items() pool_set is None or x.getUUID() in pool_set)]
return [v for k, v in identified_nodes if k in pool_set]
return self._identified_dict.values()
def getConnectedList(self): def getConnectedList(self):
""" """
...@@ -457,48 +336,25 @@ class NodeManager(object): ...@@ -457,48 +336,25 @@ class NodeManager(object):
# TODO: use an index # TODO: use an index
return [x for x in self._node_set if x.isConnected()] return [x for x in self._node_set if x.isConnected()]
def __getList(self, index_dict, key):
return index_dict.setdefault(key, set())
def getByStateList(self, state): def getByStateList(self, state):
""" Get a node list filtered per the node state """ """ Get a node list filtered per the node state """
return list(self.__getList(self._state_dict, state)) return list(self._state_dict.get(state, ()))
def __getTypeList(self, type_klass, only_identified=False): def _getTypeList(self, node_type, only_identified=False):
node_set = self.__getList(self._type_dict, type_klass) node_set = self._type_dict.get(node_type, ())
if only_identified: if only_identified:
return [x for x in node_set if x.getUUID() in self._identified_dict] return [x for x in node_set if x.isIdentified()]
return list(node_set) return list(node_set)
def getMasterList(self, only_identified=False):
""" Return a list with master nodes """
return self.__getTypeList(MasterNode, only_identified)
def getStorageList(self, only_identified=False):
""" Return a list with storage nodes """
return self.__getTypeList(StorageNode, only_identified)
def getClientList(self, only_identified=False):
""" Return a list with client nodes """
return self.__getTypeList(ClientNode, only_identified)
def getAdminList(self, only_identified=False):
""" Return a list with admin nodes """
return self.__getTypeList(AdminNode, only_identified)
def getByAddress(self, address): def getByAddress(self, address):
""" Return the node that match with a given address """ """ Return the node that match with a given address """
return self._address_dict.get(address, None) return self._address_dict.get(address, None)
def getByUUID(self, uuid): def getByUUID(self, uuid, *id_timestamp):
""" Return the node that match with a given UUID """ """ Return the node that match with a given UUID """
return self._uuid_dict.get(uuid, None) node = self._uuid_dict.get(uuid)
if not id_timestamp or node and (node.id_timestamp,) == id_timestamp:
def hasAddress(self, address): return node
return address in self._address_dict
def hasUUID(self, uuid):
return uuid in self._uuid_dict
def _createNode(self, klass, address=None, uuid=None, **kw): def _createNode(self, klass, address=None, uuid=None, **kw):
by_address = self.getByAddress(address) by_address = self.getByAddress(address)
...@@ -531,50 +387,29 @@ class NodeManager(object): ...@@ -531,50 +387,29 @@ class NodeManager(object):
assert node.__class__ is klass, (node.__class__, klass) assert node.__class__ is klass, (node.__class__, klass)
return node return node
def createMaster(self, **kw):
""" Create and register a new master """
return self._createNode(MasterNode, **kw)
def createStorage(self, **kw):
""" Create and register a new storage """
return self._createNode(StorageNode, **kw)
def createClient(self, **kw):
""" Create and register a new client """
return self._createNode(ClientNode, **kw)
def createAdmin(self, **kw):
""" Create and register a new admin """
return self._createNode(AdminNode, **kw)
def _getClassFromNodeType(self, node_type):
klass = NODE_TYPE_MAPPING.get(node_type)
if klass is None:
raise ValueError('Unknown node type : %s' % node_type)
return klass
def createFromNodeType(self, node_type, **kw): def createFromNodeType(self, node_type, **kw):
return self._createNode(self._getClassFromNodeType(node_type), **kw) return self._createNode(NODE_TYPE_MAPPING[node_type], **kw)
def update(self, node_list): def update(self, app, node_list):
for node_type, addr, uuid, state in node_list: node_set = self._node_set.copy() if app.id_timestamp is None else None
for node_type, addr, uuid, state, id_timestamp in node_list:
# This should be done here (although klass might not be used in this # This should be done here (although klass might not be used in this
# iteration), as it raises if type is not valid. # iteration), as it raises if type is not valid.
klass = self._getClassFromNodeType(node_type) klass = NODE_TYPE_MAPPING[node_type]
# lookup in current table # lookup in current table
node_by_uuid = self.getByUUID(uuid) node_by_uuid = self.getByUUID(uuid)
node_by_addr = self.getByAddress(addr) node_by_addr = self.getByAddress(addr)
node = node_by_uuid or node_by_addr node = node_by_uuid or node_by_addr
log_args = node_type, uuid_str(uuid), addr, state log_args = node_type, uuid_str(uuid), addr, state, id_timestamp
if node is None: if node is None:
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('NOT creating node %s %s %s %s', *log_args) logging.debug('NOT creating node %s %s %s %s %s', *log_args)
else: continue
node = self._createNode(klass, address=addr, uuid=uuid, node = self._createNode(klass, address=addr, uuid=uuid,
state=state) state=state)
logging.debug('creating node %r', node) logging.debug('creating node %r', node)
else: else:
assert isinstance(node, klass), 'node %r is not ' \ assert isinstance(node, klass), 'node %r is not ' \
'of expected type: %r' % (node, klass) 'of expected type: %r' % (node, klass)
...@@ -583,8 +418,8 @@ class NodeManager(object): ...@@ -583,8 +418,8 @@ class NodeManager(object):
'Discrepancy between node_by_uuid (%r) and ' \ 'Discrepancy between node_by_uuid (%r) and ' \
'node_by_addr (%r)' % (node_by_uuid, node_by_addr) 'node_by_addr (%r)' % (node_by_uuid, node_by_addr)
if state == NodeStates.DOWN: if state == NodeStates.DOWN:
logging.debug('droping node %r (%r), found with %s ' logging.debug('dropping node %r (%r), found with %s '
'%s %s %s', node, node.isConnected(), *log_args) '%s %s %s %s', node, node.isConnected(), *log_args)
if node.isConnected(): if node.isConnected():
# Cut this connection, node removed by handler. # Cut this connection, node removed by handler.
# It's important for a storage to disconnect nodes that # It's important for a storage to disconnect nodes that
...@@ -594,12 +429,20 @@ class NodeManager(object): ...@@ -594,12 +429,20 @@ class NodeManager(object):
# partition table upon disconnection. # partition table upon disconnection.
node.getConnection().close() node.getConnection().close()
self.remove(node) self.remove(node)
else: continue
logging.debug('updating node %r to %s %s %s %s', logging.debug('updating node %r to %s %s %s %s %s',
node, *log_args) node, *log_args)
node.setUUID(uuid) node.setUUID(uuid)
node.setAddress(addr) node.setAddress(addr)
node.setState(state) node.setState(state)
node.id_timestamp = id_timestamp
if app.uuid == uuid:
app.id_timestamp = id_timestamp
if node_set:
# For the first notification, we receive a full list of nodes from
# the master. Remove all unknown nodes from a previous connection.
for node in node_set - self._node_set:
self.remove(node)
self.log() self.log()
def log(self): def log(self):
...@@ -614,3 +457,40 @@ class NodeManager(object): ...@@ -614,3 +457,40 @@ class NodeManager(object):
address = '%s:%d' % address address = '%s:%d' % address
logging.info(' * %*s | %8s | %22s | %s', logging.info(' * %*s | %8s | %22s | %s',
max_len, uuid, node.getType(), address, node.getState()) max_len, uuid, node.getType(), address, node.getState())
@apply
def NODE_TYPE_MAPPING():
def setmethod(cls, attr, value):
assert not hasattr(cls, attr), (cls, attr)
setattr(cls, attr, value)
def setfullmethod(cls, attr, value):
value.__name__ = attr
setmethod(cls, attr, value)
def camel_case(enum):
return str(enum).replace('_', ' ').title().replace(' ', '')
def setStateAccessors(state):
name = camel_case(state)
setfullmethod(Node, 'set' + name, lambda self: self.setState(state))
setfullmethod(Node, 'is' + name, lambda self: self._state == state)
map(setStateAccessors, NodeStates)
node_type_dict = {}
getType = lambda node_type: staticmethod(lambda: node_type)
true = staticmethod(lambda: True)
createNode = lambda cls: lambda self, **kw: self._createNode(cls, **kw)
getList = lambda node_type: lambda self, only_identified=False: \
self._getTypeList(node_type, only_identified)
bases = Node,
for node_type in NodeTypes:
name = camel_case(node_type)
is_name = 'is' + name
setmethod(Node, is_name, bool)
node_type_dict[node_type] = cls = type(name + 'Node', bases, {
'getType': getType(node_type),
is_name: true,
})
setfullmethod(NodeManager, 'create' + name, createNode(cls))
setfullmethod(NodeManager, 'get%sList' % name, getList(node_type))
return node_type_dict
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# #
def speedupFileStorageTxnLookup(): def speedupFileStorageTxnLookup():
"""Speed up lookup of start position when instanciating an iterator """Speed up lookup of start position when instantiating an iterator
FileStorage does not index the file positions of transactions. FileStorage does not index the file positions of transactions.
With this patch, we use the existing {oid->file_pos} index to bisect the With this patch, we use the existing {oid->file_pos} index to bisect the
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
PROTOCOL_VERSION = 7 PROTOCOL_VERSION = 8
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -235,6 +235,7 @@ class Packet(object): ...@@ -235,6 +235,7 @@ class Packet(object):
_code = None _code = None
_fmt = None _fmt = None
_id = None _id = None
poll_thread = False
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
...@@ -330,7 +331,7 @@ class ParseError(Exception): ...@@ -330,7 +331,7 @@ class ParseError(Exception):
class PItem(object): class PItem(object):
""" """
Base class for any packet item, _encode and _decode must be overriden Base class for any packet item, _encode and _decode must be overridden
by subclasses. by subclasses.
""" """
def __init__(self, name): def __init__(self, name):
...@@ -386,9 +387,9 @@ class PStructItem(PItem): ...@@ -386,9 +387,9 @@ class PStructItem(PItem):
""" """
A single value encoded with struct A single value encoded with struct
""" """
def __init__(self, name, fmt): def __init__(self, name):
PItem.__init__(self, name) PItem.__init__(self, name)
struct = Struct(fmt) struct = Struct(self._fmt)
self.pack = struct.pack self.pack = struct.pack
self.unpack = struct.unpack self.unpack = struct.unpack
self.size = struct.size self.size = struct.size
...@@ -399,12 +400,23 @@ class PStructItem(PItem): ...@@ -399,12 +400,23 @@ class PStructItem(PItem):
def _decode(self, reader): def _decode(self, reader):
return self.unpack(reader(self.size))[0] return self.unpack(reader(self.size))[0]
class PStructItemOrNone(PStructItem):
def _encode(self, writer, value):
return writer(self._None if value is None else self.pack(value))
def _decode(self, reader):
value = reader(self.size)
return None if value == self._None else self.unpack(value)[0]
class PList(PStructItem): class PList(PStructItem):
""" """
A list of homogeneous items A list of homogeneous items
""" """
_fmt = '!L'
def __init__(self, name, item): def __init__(self, name, item):
PStructItem.__init__(self, name, '!L') PStructItem.__init__(self, name)
self._item = item self._item = item
def _encode(self, writer, items): def _encode(self, writer, items):
...@@ -422,8 +434,10 @@ class PDict(PStructItem): ...@@ -422,8 +434,10 @@ class PDict(PStructItem):
""" """
A dictionary with custom key and value formats A dictionary with custom key and value formats
""" """
_fmt = '!L'
def __init__(self, name, key, value): def __init__(self, name, key, value):
PStructItem.__init__(self, name, '!L') PStructItem.__init__(self, name)
self._key = key self._key = key
self._value = value self._value = value
...@@ -449,15 +463,15 @@ class PEnum(PStructItem): ...@@ -449,15 +463,15 @@ class PEnum(PStructItem):
""" """
Encapsulate an enumeration value Encapsulate an enumeration value
""" """
_fmt = '!l'
def __init__(self, name, enum): def __init__(self, name, enum):
PStructItem.__init__(self, name, '!l') PStructItem.__init__(self, name)
self._enum = enum self._enum = enum
def _encode(self, writer, item): def _encode(self, writer, item):
if item is None: if item is None:
item = -1 item = -1
else:
assert isinstance(item, int), item
writer(self.pack(item)) writer(self.pack(item))
def _decode(self, reader): def _decode(self, reader):
...@@ -474,8 +488,7 @@ class PString(PStructItem): ...@@ -474,8 +488,7 @@ class PString(PStructItem):
""" """
A variable-length string A variable-length string
""" """
def __init__(self, name): _fmt = '!L'
PStructItem.__init__(self, name, '!L')
def _encode(self, writer, value): def _encode(self, writer, value):
writer(self.pack(len(value))) writer(self.pack(len(value)))
...@@ -512,46 +525,26 @@ class PBoolean(PStructItem): ...@@ -512,46 +525,26 @@ class PBoolean(PStructItem):
""" """
A boolean value, encoded as a single byte A boolean value, encoded as a single byte
""" """
def __init__(self, name): _fmt = '!?'
PStructItem.__init__(self, name, '!B')
def _encode(self, writer, value):
writer(self.pack(bool(value)))
def _decode(self, reader):
return bool(self.unpack(reader(self.size))[0])
class PNumber(PStructItem): class PNumber(PStructItem):
""" """
A integer number (4-bytes length) A integer number (4-bytes length)
""" """
def __init__(self, name): _fmt = '!L'
PStructItem.__init__(self, name, '!L')
class PIndex(PStructItem): class PIndex(PStructItem):
""" """
A big integer to defined indexes in a huge list. A big integer to defined indexes in a huge list.
""" """
def __init__(self, name): _fmt = '!Q'
PStructItem.__init__(self, name, '!Q')
class PPTID(PStructItem): class PPTID(PStructItemOrNone):
""" """
A None value means an invalid PTID A None value means an invalid PTID
""" """
def __init__(self, name): _fmt = '!Q'
PStructItem.__init__(self, name, '!Q') _None = Struct(_fmt).pack(0)
def _encode(self, writer, value):
if value is None:
value = 0
PStructItem._encode(self, writer, value)
def _decode(self, reader):
value = PStructItem._decode(self, reader)
if value == 0:
value = None
return value
class PProtocol(PNumber): class PProtocol(PNumber):
""" """
...@@ -577,18 +570,12 @@ class PChecksum(PItem): ...@@ -577,18 +570,12 @@ class PChecksum(PItem):
def _decode(self, reader): def _decode(self, reader):
return reader(20) return reader(20)
class PUUID(PStructItem): class PUUID(PStructItemOrNone):
""" """
An UUID (node identifier, 4-bytes signed integer) An UUID (node identifier, 4-bytes signed integer)
""" """
def __init__(self, name): _fmt = '!l'
PStructItem.__init__(self, name, '!l') _None = Struct(_fmt).pack(0)
def _encode(self, writer, uuid):
writer(self.pack(uuid or 0))
def _decode(self, reader):
return self.unpack(reader(self.size))[0] or None
class PTID(PItem): class PTID(PItem):
""" """
...@@ -609,6 +596,13 @@ class PTID(PItem): ...@@ -609,6 +596,13 @@ class PTID(PItem):
# same definition, for now # same definition, for now
POID = PTID POID = PTID
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions # common definitions
PFEmpty = PStruct('no_content') PFEmpty = PStruct('no_content')
...@@ -622,6 +616,7 @@ PFNodeList = PList('node_list', ...@@ -622,6 +616,7 @@ PFNodeList = PList('node_list',
PAddress('address'), PAddress('address'),
PUUID('uuid'), PUUID('uuid'),
PFNodeState, PFNodeState,
PFloat('id_timestamp'),
), ),
) )
...@@ -695,6 +690,7 @@ class RequestIdentification(Packet): ...@@ -695,6 +690,7 @@ class RequestIdentification(Packet):
Request a node identification. This must be the first packet for any Request a node identification. This must be the first packet for any
connection. Any -> Any. connection. Any -> Any.
""" """
poll_thread = True
_fmt = PStruct('request_identification', _fmt = PStruct('request_identification',
PProtocol('protocol_version'), PProtocol('protocol_version'),
...@@ -702,6 +698,7 @@ class RequestIdentification(Packet): ...@@ -702,6 +698,7 @@ class RequestIdentification(Packet):
PUUID('uuid'), PUUID('uuid'),
PAddress('address'), PAddress('address'),
PString('name'), PString('name'),
PFloat('id_timestamp'),
) )
_answer = PStruct('accept_identification', _answer = PStruct('accept_identification',
...@@ -882,6 +879,8 @@ class FinishTransaction(Packet): ...@@ -882,6 +879,8 @@ class FinishTransaction(Packet):
Finish a transaction. C -> PM. Finish a transaction. C -> PM.
Answer when a transaction is finished. PM -> C. Answer when a transaction is finished. PM -> C.
""" """
poll_thread = True
_fmt = PStruct('ask_finish_transaction', _fmt = PStruct('ask_finish_transaction',
PTID('tid'), PTID('tid'),
PFOidList, PFOidList,
...@@ -1167,12 +1166,6 @@ class NotifyNodeInformation(Packet): ...@@ -1167,12 +1166,6 @@ class NotifyNodeInformation(Packet):
PFNodeList, PFNodeList,
) )
class NodeInformation(Packet):
"""
Ask node information
"""
_answer = PFEmpty
class SetClusterState(Packet): class SetClusterState(Packet):
""" """
Set the cluster state Set the cluster state
...@@ -1388,6 +1381,7 @@ class LastTransaction(Packet): ...@@ -1388,6 +1381,7 @@ class LastTransaction(Packet):
Answer last committed TID. Answer last committed TID.
M -> C M -> C
""" """
poll_thread = True
_answer = PStruct('answer_last_transaction', _answer = PStruct('answer_last_transaction',
PTID('tid'), PTID('tid'),
...@@ -1492,8 +1486,8 @@ class Replicate(Packet): ...@@ -1492,8 +1486,8 @@ class Replicate(Packet):
class ReplicationDone(Packet): class ReplicationDone(Packet):
""" """
Notify the master node that a partition has been successully replicated from Notify the master node that a partition has been successfully replicated
a storage to another. from a storage to another.
S -> M S -> M
""" """
_fmt = PStruct('notify_replication_done', _fmt = PStruct('notify_replication_done',
...@@ -1528,7 +1522,7 @@ def register(request, ignore_when_closed=None): ...@@ -1528,7 +1522,7 @@ def register(request, ignore_when_closed=None):
# By default, on a closed connection: # By default, on a closed connection:
# - request: ignore # - request: ignore
# - answer: keep # - answer: keep
# - nofitication: keep # - notification: keep
ignore_when_closed = answer is not None ignore_when_closed = answer is not None
request._ignore_when_closed = ignore_when_closed request._ignore_when_closed = ignore_when_closed
if answer in (Error, None): if answer in (Error, None):
...@@ -1536,6 +1530,7 @@ def register(request, ignore_when_closed=None): ...@@ -1536,6 +1530,7 @@ def register(request, ignore_when_closed=None):
# build a class for the answer # build a class for the answer
answer = type('Answer%s' % (request.__name__, ), (Packet, ), {}) answer = type('Answer%s' % (request.__name__, ), (Packet, ), {})
answer._fmt = request._answer answer._fmt = request._answer
answer.poll_thread = request.poll_thread
# compute the answer code # compute the answer code
code = code | RESPONSE_MASK code = code | RESPONSE_MASK
answer._request = request answer._request = request
...@@ -1565,7 +1560,7 @@ class ParserState(object): ...@@ -1565,7 +1560,7 @@ class ParserState(object):
class Packets(dict): class Packets(dict):
""" """
Packet registry that check packet code unicity and provide an index Packet registry that checks packet code uniqueness and provides an index
""" """
def __metaclass__(name, base, d): def __metaclass__(name, base, d):
for k, v in d.iteritems(): for k, v in d.iteritems():
...@@ -1688,8 +1683,6 @@ class Packets(dict): ...@@ -1688,8 +1683,6 @@ class Packets(dict):
AddPendingNodes, ignore_when_closed=False) AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable = register( TweakPartitionTable = register(
TweakPartitionTable, ignore_when_closed=False) TweakPartitionTable, ignore_when_closed=False)
AskNodeInformation, AnswerNodeInformation = register(
NodeInformation)
SetClusterState = register( SetClusterState = register(
SetClusterState, ignore_when_closed=False) SetClusterState, ignore_when_closed=False)
NotifyClusterInformation = register( NotifyClusterInformation = register(
......
...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local): ...@@ -43,6 +43,8 @@ class ThreadContainer(threading.local):
class ThreadedApplication(BaseApplication): class ThreadedApplication(BaseApplication):
"""The client node application.""" """The client node application."""
uuid = None
def __init__(self, master_nodes, name, **kw): def __init__(self, master_nodes, name, **kw):
super(ThreadedApplication, self).__init__(**kw) super(ThreadedApplication, self).__init__(**kw)
self.poll_thread = threading.Thread(target=self.run, name=name) self.poll_thread = threading.Thread(target=self.run, name=name)
...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication): ...@@ -56,8 +58,6 @@ class ThreadedApplication(BaseApplication):
for address in master_nodes: for address in master_nodes:
self.nm.createMaster(address=address) self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one
self.uuid = None
# Internal attribute distinct between thread # Internal attribute distinct between thread
self._thread_container = ThreadContainer() self._thread_container = ThreadContainer()
app_set.add(self) # to register self.on_log app_set.add(self) # to register self.on_log
...@@ -150,7 +150,7 @@ class ThreadedApplication(BaseApplication): ...@@ -150,7 +150,7 @@ class ThreadedApplication(BaseApplication):
if msg_id == qpacket.getId(): if msg_id == qpacket.getId():
if is_forgotten: if is_forgotten:
raise ValueError, 'ForgottenPacket for an ' \ raise ValueError, 'ForgottenPacket for an ' \
'explicitely expected packet.' 'explicitly expected packet.'
_handlePacket(qconn, qpacket, kw, handler) _handlePacket(qconn, qpacket, kw, handler)
break break
if not is_forgotten and qpacket is not None: if not is_forgotten and qpacket is not None:
......
...@@ -142,7 +142,7 @@ def parseNodeAddress(address, port_opt=None): ...@@ -142,7 +142,7 @@ def parseNodeAddress(address, port_opt=None):
else: else:
host = address host = address
port = port_opt port = port_opt
# Resolve (maybe) and cast to cannonical form # Resolve (maybe) and cast to canonical form
# XXX: Always pick the first result. This might not be what is desired, and # XXX: Always pick the first result. This might not be what is desired, and
# if so this function should either take a hint on the desired address type # if so this function should either take a hint on the desired address type
# or return either raw host & port or getaddrinfo return value. # or return either raw host & port or getaddrinfo return value.
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys, weakref import sys, weakref
from collections import defaultdict
from time import time from time import time
from neo.lib import logging from neo.lib import logging
...@@ -40,11 +41,10 @@ from .verification import VerificationManager ...@@ -40,11 +41,10 @@ from .verification import VerificationManager
class Application(BaseApplication): class Application(BaseApplication):
"""The master node application.""" """The master node application."""
packing = None packing = None
# Latest completely commited TID # Latest completely committed TID
last_transaction = ZERO_TID last_transaction = ZERO_TID
backup_tid = None backup_tid = None
backup_app = None backup_app = None
uuid = None
truncate_tid = None truncate_tid = None
def __init__(self, config): def __init__(self, config):
...@@ -79,9 +79,7 @@ class Application(BaseApplication): ...@@ -79,9 +79,7 @@ class Application(BaseApplication):
self.primary_master_node = None self.primary_master_node = None
self.cluster_state = None self.cluster_state = None
uuid = config.getUUID() self.uuid = config.getUUID()
if uuid:
self.uuid = uuid
# election related data # election related data
self.unconnected_master_node_set = set() self.unconnected_master_node_set = set()
...@@ -227,19 +225,20 @@ class Application(BaseApplication): ...@@ -227,19 +225,20 @@ class Application(BaseApplication):
Broadcast changes for a set a nodes Broadcast changes for a set a nodes
Send only one packet per connection to reduce bandwidth Send only one packet per connection to reduce bandwidth
""" """
node_dict = {} node_dict = defaultdict(list)
# group modified nodes by destination node type # group modified nodes by destination node type
for node in node_list: for node in node_list:
node_info = node.asTuple() node_info = node.asTuple()
def assign_for_notification(node_type): if node.isAdmin():
# helper function continue
node_dict.setdefault(node_type, []).append(node_info) node_dict[NodeTypes.ADMIN].append(node_info)
if node.isMaster() or node.isStorage(): node_dict[NodeTypes.STORAGE].append(node_info)
# client get notifications for master and storage only if node.isClient():
assign_for_notification(NodeTypes.CLIENT) continue
if node.isMaster() or node.isStorage() or node.isClient(): node_dict[NodeTypes.CLIENT].append(node_info)
assign_for_notification(NodeTypes.STORAGE) if node.isStorage():
assign_for_notification(NodeTypes.ADMIN) continue
node_dict[NodeTypes.MASTER].append(node_info)
# send at most one non-empty notification packet per node # send at most one non-empty notification packet per node
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
...@@ -261,7 +260,7 @@ class Application(BaseApplication): ...@@ -261,7 +260,7 @@ class Application(BaseApplication):
def provideService(self): def provideService(self):
""" """
This is the normal mode for a primary master node. Handle transactions This is the normal mode for a primary master node. Handle transactions
and stop the service only if a catastrophy happens or the user commits and stop the service only if a catastrophe happens or the user commits
a shutdown. a shutdown.
""" """
logging.info('provide service') logging.info('provide service')
...@@ -298,7 +297,7 @@ class Application(BaseApplication): ...@@ -298,7 +297,7 @@ class Application(BaseApplication):
# secondaries, rather than the other way around. This requires # secondaries, rather than the other way around. This requires
# a bit more work when a new master joins a cluster but makes # a bit more work when a new master joins a cluster but makes
# it easier to resolve UUID conflicts with minimal cluster # it easier to resolve UUID conflicts with minimal cluster
# impact, and ensure primary master unicity (primary masters # impact, and ensure primary master uniqueness (primary masters
# become noisy, in that they actively try to maintain # become noisy, in that they actively try to maintain
# connections to all other master nodes, so duplicate # connections to all other master nodes, so duplicate
# primaries will eventually get in touch with each other and # primaries will eventually get in touch with each other and
...@@ -308,6 +307,11 @@ class Application(BaseApplication): ...@@ -308,6 +307,11 @@ class Application(BaseApplication):
# masters will reconnect nevertheless, but it's dirty. # masters will reconnect nevertheless, but it's dirty.
# Currently, it's not trivial to preserve connected nodes, # Currently, it's not trivial to preserve connected nodes,
# because of poor node status tracking during election. # because of poor node status tracking during election.
# XXX: The above comment is partially wrong in that the primary
# master is now responsible of allocating node ids, and all
# other nodes must only create/update/remove nodes when
# processing node notification. We probably want to keep the
# current behaviour: having only server connections.
conn.abort() conn.abort()
# If I know any storage node, make sure that they are not in the # If I know any storage node, make sure that they are not in the
...@@ -493,7 +497,7 @@ class Application(BaseApplication): ...@@ -493,7 +497,7 @@ class Application(BaseApplication):
conn.setHandler(handler) conn.setHandler(handler)
conn.notify(Packets.NotifyNodeInformation((( conn.notify(Packets.NotifyNodeInformation(((
node.getType(), node.getAddress(), node.getUUID(), node.getType(), node.getAddress(), node.getUUID(),
NodeStates.TEMPORARILY_DOWN),))) NodeStates.TEMPORARILY_DOWN, None),)))
conn.abort() conn.abort()
elif conn.pending(): elif conn.pending():
conn.abort() conn.abort()
......
...@@ -66,6 +66,7 @@ There is no UUID conflict between the 2 clusters: ...@@ -66,6 +66,7 @@ There is no UUID conflict between the 2 clusters:
class BackupApplication(object): class BackupApplication(object):
pt = None pt = None
uuid = None
def __init__(self, app, name, master_addresses): def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app) self.app = weakref.proxy(app)
...@@ -93,7 +94,7 @@ class BackupApplication(object): ...@@ -93,7 +94,7 @@ class BackupApplication(object):
pt = app.pt pt = app.pt
while True: while True:
app.changeClusterState(ClusterStates.STARTING_BACKUP) app.changeClusterState(ClusterStates.STARTING_BACKUP)
bootstrap = BootstrapManager(self, self.name, NodeTypes.CLIENT) bootstrap = BootstrapManager(self, NodeTypes.CLIENT)
# {offset -> node} (primary storage for off which will be talking to upstream cluster) # {offset -> node} (primary storage for off which will be talking to upstream cluster)
self.primary_partition_dict = {} self.primary_partition_dict = {}
# [[tid]] part -> []tid↑ (currently scheduled-for-sync txns) # [[tid]] part -> []tid↑ (currently scheduled-for-sync txns)
...@@ -106,7 +107,7 @@ class BackupApplication(object): ...@@ -106,7 +107,7 @@ class BackupApplication(object):
else: else:
break break
poll(1) poll(1)
node, conn, uuid, num_partitions, num_replicas = \ node, conn, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
...@@ -115,7 +116,6 @@ class BackupApplication(object): ...@@ -115,7 +116,6 @@ class BackupApplication(object):
raise RuntimeError("inconsistent number of partitions") raise RuntimeError("inconsistent number of partitions")
self.pt = PartitionTable(num_partitions, num_replicas) self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskNodeInformation())
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
......
...@@ -18,7 +18,7 @@ from neo.lib import logging ...@@ -18,7 +18,7 @@ from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets, from neo.lib.protocol import (uuid_str, NodeTypes, NodeStates, Packets,
BrokenNodeDisallowedError, BrokenNodeDisallowedError, ProtocolError,
) )
X = 0 X = 0
...@@ -29,18 +29,19 @@ class MasterHandler(EventHandler): ...@@ -29,18 +29,19 @@ class MasterHandler(EventHandler):
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn, new=None):
if new is None: if new is None:
super(MasterHandler, self).connectionCompleted(conn) super(MasterHandler, self).connectionCompleted(conn)
elif new:
self._notifyNodeInformation(conn)
def requestIdentification(self, conn, node_type, uuid, address, name): def requestIdentification(self, conn, node_type, uuid, address, name, _):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
if node: if node:
assert node_type is not NodeTypes.MASTER or node.getAddress() in ( if node_type is NodeTypes.MASTER and not (
address, None), (node, address) None != address == node.getAddress()):
raise ProtocolError
if node.isBroken(): if node.isBroken():
raise BrokenNodeDisallowedError raise BrokenNodeDisallowedError
else:
node = app.nm.getByAddress(address)
peer_uuid = self._setupNode(conn, node_type, uuid, address, node) peer_uuid = self._setupNode(conn, node_type, uuid, address, node)
if app.primary: if app.primary:
primary_address = app.server primary_address = app.server
...@@ -99,10 +100,6 @@ class MasterHandler(EventHandler): ...@@ -99,10 +100,6 @@ class MasterHandler(EventHandler):
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
def askNodeInformation(self, conn):
self._notifyNodeInformation(conn)
conn.answer(Packets.AnswerNodeInformation())
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList()))
......
...@@ -31,12 +31,6 @@ class BackupHandler(EventHandler): ...@@ -31,12 +31,6 @@ class BackupHandler(EventHandler):
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm) self.app.pt.update(ptid, cell_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
# NOTE invalidation from M -> Mb (all partitions) # NOTE invalidation from M -> Mb (all partitions)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
app = self.app app = self.app
......
...@@ -31,14 +31,12 @@ class ClientServiceHandler(MasterHandler): ...@@ -31,14 +31,12 @@ class ClientServiceHandler(MasterHandler):
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
app.nm.remove(node) app.nm.remove(node)
def askNodeInformation(self, conn): def _notifyNodeInformation(self, conn):
# send informations about master and storages only
nm = self.app.nm nm = self.app.nm
node_list = [] node_list = [nm.getByUUID(conn.getUUID()).asTuple()] # for id_timestamp
node_list.extend(n.asTuple() for n in nm.getMasterList()) node_list.extend(n.asTuple() for n in nm.getMasterList())
node_list.extend(n.asTuple() for n in nm.getStorageList()) node_list.extend(n.asTuple() for n in nm.getStorageList())
conn.notify(Packets.NotifyNodeInformation(node_list)) conn.notify(Packets.NotifyNodeInformation(node_list))
conn.answer(Packets.AnswerNodeInformation())
def askBeginTransaction(self, conn, tid): def askBeginTransaction(self, conn, tid):
""" """
......
...@@ -23,6 +23,9 @@ from . import MasterHandler ...@@ -23,6 +23,9 @@ from . import MasterHandler
class BaseElectionHandler(EventHandler): class BaseElectionHandler(EventHandler):
def _notifyNodeInformation(self, conn):
pass
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
...@@ -53,6 +56,11 @@ class BaseElectionHandler(EventHandler): ...@@ -53,6 +56,11 @@ class BaseElectionHandler(EventHandler):
class ClientElectionHandler(BaseElectionHandler): class ClientElectionHandler(BaseElectionHandler):
def notifyNodeInformation(self, conn, node_list):
# XXX: For the moment, do nothing because
# we'll close this connection and reconnect.
pass
def connectionFailed(self, conn): def connectionFailed(self, conn):
addr = conn.getAddress() addr = conn.getAddress()
node = self.app.nm.getByAddress(addr) node = self.app.nm.getByAddress(addr)
...@@ -68,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler): ...@@ -68,6 +76,7 @@ class ClientElectionHandler(BaseElectionHandler):
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(ClientElectionHandler, self).connectionCompleted(conn) super(ClientElectionHandler, self).connectionCompleted(conn)
...@@ -126,8 +135,8 @@ class ServerElectionHandler(BaseElectionHandler, MasterHandler): ...@@ -126,8 +135,8 @@ class ServerElectionHandler(BaseElectionHandler, MasterHandler):
logging.info('reject a connection from a non-master') logging.info('reject a connection from a non-master')
raise NotReadyError raise NotReadyError
if node is None: if node is None is app.nm.getByAddress(address):
node = app.nm.createMaster(address=address) app.nm.createMaster(address=address)
self.elect(conn, address) self.elect(conn, address)
return uuid return uuid
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \
NotReadyError, ProtocolError, uuid_str NotReadyError, ProtocolError, uuid_str
...@@ -30,14 +31,32 @@ class IdentificationHandler(MasterHandler): ...@@ -30,14 +31,32 @@ class IdentificationHandler(MasterHandler):
def _setupNode(self, conn, node_type, uuid, address, node): def _setupNode(self, conn, node_type, uuid, address, node):
app = self.app app = self.app
if node: by_addr = address and app.nm.getByAddress(address)
if node.isRunning(): while 1:
# cloned/evil/buggy node connecting to us if by_addr:
raise ProtocolError('already connected') if not by_addr.isConnected():
if node is by_addr:
break
if not node or uuid < 0:
# In case of address conflict for a peer with temporary
# ids, we'll generate a new id.
node = by_addr
break
elif node:
if node.isConnected():
if uuid < 0:
# The peer wants a temporary id that's already assigned.
# Let's give it another one.
node = uuid = None
break
else:
node.setAddress(address)
break
# Id conflict for a storage node.
else: else:
assert not node.isConnected() break
node.setAddress(address) # cloned/evil/buggy node connecting to us
node.setRunning() raise ProtocolError('already connected')
state = NodeStates.RUNNING state = NodeStates.RUNNING
if node_type == NodeTypes.CLIENT: if node_type == NodeTypes.CLIENT:
...@@ -64,14 +83,16 @@ class IdentificationHandler(MasterHandler): ...@@ -64,14 +83,16 @@ class IdentificationHandler(MasterHandler):
handler = app.administration_handler handler = app.administration_handler
human_readable_node_type = 'n admin ' human_readable_node_type = 'n admin '
else: else:
raise NotImplementedError(node_type) raise ProtocolError
uuid = app.getNewUUID(uuid, address, node_type) uuid = app.getNewUUID(uuid, address, node_type)
logging.info('Accept a' + human_readable_node_type + uuid_str(uuid)) logging.info('Accept a' + human_readable_node_type + uuid_str(uuid))
if node is None: if node is None:
node = app.nm.createFromNodeType(node_type, node = app.nm.createFromNodeType(node_type,
uuid=uuid, address=address) uuid=uuid, address=address)
node.setUUID(uuid) else:
node.setUUID(uuid)
node.id_timestamp = time()
node.setState(state) node.setState(state)
node.setConnection(conn) node.setConnection(conn)
conn.setHandler(handler) conn.setHandler(handler)
......
...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler): ...@@ -36,6 +36,10 @@ class SecondaryMasterHandler(MasterHandler):
def reelectPrimary(self, conn): def reelectPrimary(self, conn):
raise ElectionFailure, 'reelection requested' raise ElectionFailure, 'reelection requested'
def _notifyNodeInformation(self, conn):
node_list = [n.asTuple() for n in self.app.nm.getMasterList()]
conn.notify(Packets.NotifyNodeInformation(node_list))
class PrimaryHandler(EventHandler): class PrimaryHandler(EventHandler):
""" Handler used by secondaries to handle primary master""" """ Handler used by secondaries to handle primary master"""
...@@ -51,13 +55,14 @@ class PrimaryHandler(EventHandler): ...@@ -51,13 +55,14 @@ class PrimaryHandler(EventHandler):
app = self.app app = self.app
addr = conn.getAddress() addr = conn.getAddress()
node = app.nm.getByAddress(addr) node = app.nm.getByAddress(addr)
# connection successfull, set it as running # connection successful, set it as running
node.setRunning() node.setRunning()
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(
NodeTypes.MASTER, NodeTypes.MASTER,
app.uuid, app.uuid,
app.server, app.server,
app.name, app.name,
None,
)) ))
super(PrimaryHandler, self).connectionCompleted(conn) super(PrimaryHandler, self).connectionCompleted(conn)
...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler): ...@@ -68,27 +73,11 @@ class PrimaryHandler(EventHandler):
self.app.cluster_state = state self.app.cluster_state = state
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
app = self.app super(PrimaryHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if node_type != NodeTypes.MASTER: assert node_type == NodeTypes.MASTER, node_type
# No interest. if uuid == self.app.uuid and state == NodeStates.UNKNOWN:
continue
if uuid == app.uuid and state == NodeStates.UNKNOWN:
sys.exit() sys.exit()
# Register new master nodes.
if app.server == addr:
# This is self.
continue
else:
n = app.nm.getByAddress(addr)
# master node must be known
assert n is not None
if uuid is not None:
# If I don't know the UUID yet, believe what the peer
# told me at the moment.
if n.getUUID() is None:
n.setUUID(uuid)
def _acceptIdentification(self, node, uuid, num_partitions, def _acceptIdentification(self, node, uuid, num_partitions,
num_replicas, your_uuid, primary, known_master_list): num_replicas, your_uuid, primary, known_master_list):
...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler): ...@@ -101,4 +90,5 @@ class PrimaryHandler(EventHandler):
logging.info('My UUID: ' + uuid_str(your_uuid)) logging.info('My UUID: ' + uuid_str(your_uuid))
node.setUUID(uuid) node.setUUID(uuid)
app.id_timestamp = None
...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -27,13 +27,11 @@ class StorageServiceHandler(BaseServiceHandler):
def connectionCompleted(self, conn, new): def connectionCompleted(self, conn, new):
app = self.app app = self.app
uuid = conn.getUUID() uuid = conn.getUUID()
node = app.nm.getByUUID(uuid)
app.setStorageNotReady(uuid) app.setStorageNotReady(uuid)
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).connectionCompleted(conn, new)
# XXX: what other values could happen ? if app.nm.getByUUID(uuid).isRunning(): # node may be PENDING
if node.isRunning(): conn.notify(Packets.StartOperation(app.backup_tid))
conn.notify(Packets.StartOperation(bool(app.backup_tid)))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
...@@ -85,7 +83,7 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -85,7 +83,7 @@ class StorageServiceHandler(BaseServiceHandler):
try: try:
cell_list = self.app.pt.setUpToDate(node, offset) cell_list = self.app.pt.setUpToDate(node, offset)
if not cell_list: if not cell_list:
raise ProtocolError('Non-oudated partition') raise ProtocolError('Non-outdated partition')
except PartitionTableException, e: except PartitionTableException, e:
raise ProtocolError(str(e)) raise ProtocolError(str(e))
logging.debug("%s is up for offset %s", node, offset) logging.debug("%s is up for offset %s", node, offset)
......
...@@ -334,7 +334,7 @@ class TransactionManager(object): ...@@ -334,7 +334,7 @@ class TransactionManager(object):
""" """
Set that a node has locked the transaction. Set that a node has locked the transaction.
If transaction is completely locked, calls function given at If transaction is completely locked, calls function given at
instanciation time. instantiation time.
""" """
logging.debug('Lock TXN %s for %s', dump(ttid), uuid_str(uuid)) logging.debug('Lock TXN %s for %s', dump(ttid), uuid_str(uuid))
if self[ttid].lock(uuid) and self._queue[0] == ttid: if self[ttid].lock(uuid) and self._queue[0] == ttid:
......
...@@ -174,7 +174,7 @@ class TerminalNeoCTL(object): ...@@ -174,7 +174,7 @@ class TerminalNeoCTL(object):
def tweakPartitionTable(self, params): def tweakPartitionTable(self, params):
""" """
Optimize partition table. Optimize partition table.
No partitition will be assigned to specified storage nodes. No partition will be assigned to specified storage nodes.
Parameters: [node [...]] Parameters: [node [...]]
""" """
return self.neoctl.tweakPartitionTable(map(self.asNode, params)) return self.neoctl.tweakPartitionTable(map(self.asNode, params))
...@@ -294,7 +294,7 @@ class Application(object): ...@@ -294,7 +294,7 @@ class Application(object):
if docstring is None: if docstring is None:
docstring = '(no docstring)' docstring = '(no docstring)'
docstring_line_list = docstring.split('\n') docstring_line_list = docstring.split('\n')
# Strip empty lines at begining & end of line list # Strip empty lines at beginning & end of line list
for end in (0, -1): for end in (0, -1):
while len(docstring_line_list) \ while len(docstring_line_list) \
and docstring_line_list[end] == '': and docstring_line_list[end] == '':
......
...@@ -146,15 +146,14 @@ class Log(object): ...@@ -146,15 +146,14 @@ class Log(object):
def notifyNodeInformation(self, node_list): def notifyNodeInformation(self, node_list):
node_list.sort(key=lambda x: x[2]) node_list.sort(key=lambda x: x[2])
node_list = [(self.uuid_str(uuid), str(node_type), node_list = [(self.uuid_str(x[2]), str(x[0]),
'%s:%u' % address if address else '?', state) '%s:%u' % x[1] if x[1] else '?', str(x[3]))
for node_type, address, uuid, state in node_list] + ((repr(x[4]),) if len(x) > 4 else ()) # BBB
for x in node_list]
if node_list: if node_list:
t = ' ! %%%us | %%%us | %%%us | %%s' % ( t = ''.join(' %%%us |' % max(len(x[i]) for x in node_list)
max(len(x[0]) for x in node_list), for i in xrange(len(node_list[0]) - 1))
max(len(x[1]) for x in node_list), return map((' !' + t + ' %s').__mod__, node_list)
max(len(x[2]) for x in node_list))
return map(t.__mod__, node_list)
return () return ()
......
...@@ -43,7 +43,7 @@ def main(args=None): ...@@ -43,7 +43,7 @@ def main(args=None):
# TODO: Forbid using "reset" along with any unneeded argument. # TODO: Forbid using "reset" along with any unneeded argument.
# "reset" is too dangerous to let user a chance of accidentally # "reset" is too dangerous to let user a chance of accidentally
# letting it slip through in a long option list. # letting it slip through in a long option list.
# We should drop support configation files to make such check useful. # We should drop support configuration files to make such check useful.
(options, args) = parser.parse_args(args=args) (options, args) = parser.parse_args(args=args)
config = ConfigurationManager(defaults, options, 'storage') config = ConfigurationManager(defaults, options, 'storage')
......
...@@ -29,11 +29,7 @@ from unittest.runner import _WritelnDecorator ...@@ -29,11 +29,7 @@ from unittest.runner import _WritelnDecorator
if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]): if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
# Start coverage as soon as possible. # Start coverage as soon as possible.
import coverage import coverage
coverage = coverage.Coverage(source=('neo',), omit=( coverage = coverage.Coverage()
'neo/debug.py',
'neo/scripts/runner.py',
'neo/tests/*',
))
coverage.start() coverage.start()
import neo import neo
......
...@@ -219,14 +219,11 @@ class Application(BaseApplication): ...@@ -219,14 +219,11 @@ class Application(BaseApplication):
conn.close() conn.close()
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server)
NodeTypes.STORAGE, self.uuid, self.server) self.master_node, self.master_conn, num_partitions, num_replicas = \
data = bootstrap.getPrimaryConnection() bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data uuid = self.uuid
self.master_node = node
self.master_conn = conn
logging.info('I am %s', uuid_str(uuid)) logging.info('I am %s', uuid_str(uuid))
self.uuid = uuid
self.dm.setUUID(uuid) self.dm.setUUID(uuid)
# Reload a partition table from the database. This is necessary # Reload a partition table from the database. This is necessary
......
...@@ -50,8 +50,8 @@ class Checker(object): ...@@ -50,8 +50,8 @@ class Checker(object):
conn.asClient() conn.asClient()
else: else:
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
NodeTypes.STORAGE, uuid, app.server, name)) uuid, app.server, name, app.id_timestamp))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict) conn_set = set(self.conn_dict)
conn_set.discard(None) conn_set.discard(None)
......
...@@ -78,7 +78,7 @@ class DatabaseManager(object): ...@@ -78,7 +78,7 @@ class DatabaseManager(object):
@abstract @abstract
def _parse(self, database): def _parse(self, database):
"""Called during instanciation, to process database parameter.""" """Called during instantiation, to process database parameter."""
def setup(self, reset=0): def setup(self, reset=0):
"""Set up a database, discarding existing data first if reset is True """Set up a database, discarding existing data first if reset is True
...@@ -94,7 +94,7 @@ class DatabaseManager(object): ...@@ -94,7 +94,7 @@ class DatabaseManager(object):
@abstract @abstract
def _setup(self): def _setup(self):
"""To be overriden by the backend to set up a database """To be overridden by the backend to set up a database
It must recover self._uncommitted_data from temporary object table. It must recover self._uncommitted_data from temporary object table.
_uncommitted_data is already instantiated and must be updated with _uncommitted_data is already instantiated and must be updated with
...@@ -417,7 +417,7 @@ class DatabaseManager(object): ...@@ -417,7 +417,7 @@ class DatabaseManager(object):
@abstract @abstract
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
"""To be overriden by the backend to delete any unreferenced data """To be overridden by the backend to delete any unreferenced data
'unreferenced' means: 'unreferenced' means:
- not in self._uncommitted_data - not in self._uncommitted_data
...@@ -427,7 +427,7 @@ class DatabaseManager(object): ...@@ -427,7 +427,7 @@ class DatabaseManager(object):
@abstract @abstract
def storeData(self, checksum, data, compression): def storeData(self, checksum, data, compression):
"""To be overriden by the backend to store object raw data """To be overridden by the backend to store object raw data
If same data was already stored, the storage only has to check there's If same data was already stored, the storage only has to check there's
no hash collision. no hash collision.
...@@ -491,7 +491,7 @@ class DatabaseManager(object): ...@@ -491,7 +491,7 @@ class DatabaseManager(object):
tid tid
Transation doing the undo Transation doing the undo
ltid ltid
Upper (exclued) bound of transactions visible to transaction doing Upper (excluded) bound of transactions visible to transaction doing
the undo. the undo.
undone_tid undone_tid
Transaction to undo Transaction to undo
...@@ -643,7 +643,7 @@ class DatabaseManager(object): ...@@ -643,7 +643,7 @@ class DatabaseManager(object):
@abstract @abstract
def checkTIDRange(self, partition, length, min_tid, max_tid): def checkTIDRange(self, partition, length, min_tid, max_tid):
""" """
Generate a diggest from transaction list. Generate a digest from transaction list.
min_tid (packed) min_tid (packed)
TID at which verification starts. TID at which verification starts.
length (int) length (int)
...@@ -660,7 +660,7 @@ class DatabaseManager(object): ...@@ -660,7 +660,7 @@ class DatabaseManager(object):
@abstract @abstract
def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid): def checkSerialRange(self, partition, length, min_tid, max_tid, min_oid):
""" """
Generate a diggest from object list. Generate a digest from object list.
min_oid (packed) min_oid (packed)
OID at which verification starts. OID at which verification starts.
min_tid (packed) min_tid (packed)
......
...@@ -216,7 +216,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -216,7 +216,7 @@ class MySQLDatabaseManager(DatabaseManager):
engine += " compression='tokudb_uncompressed'" engine += " compression='tokudb_uncompressed'"
# The table "data" stores object data. # The table "data" stores object data.
# We'd like to have partial index on 'hash' colum (e.g. hash(4)) # We'd like to have partial index on 'hash' column (e.g. hash(4))
# but 'UNIQUE' constraint would not work as expected. # but 'UNIQUE' constraint would not work as expected.
q("""CREATE TABLE IF NOT EXISTS data ( q("""CREATE TABLE IF NOT EXISTS data (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
...@@ -634,7 +634,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -634,7 +634,7 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed), util.p64(ttid) return oid_list, user, desc, ext, bool(packed), util.p64(ttid)
def getObjectHistory(self, oid, offset, length): def getObjectHistory(self, oid, offset, length):
# FIXME: This method doesn't take client's current ransaction id as # FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of # parameter, which means it can return transactions in the future of
# client's transaction. # client's transaction.
oid = util.u64(oid) oid = util.u64(oid)
......
...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler): ...@@ -38,8 +38,8 @@ class BaseMasterHandler(EventHandler):
def notifyNodeInformation(self, conn, node_list): def notifyNodeInformation(self, conn, node_list):
"""Store information on nodes, only if this is sent by a primary """Store information on nodes, only if this is sent by a primary
master node.""" master node."""
self.app.nm.update(node_list) super(BaseMasterHandler, self).notifyNodeInformation(conn, node_list)
for node_type, addr, uuid, state in node_list: for node_type, _, uuid, state, _ in node_list:
if uuid == self.app.uuid: if uuid == self.app.uuid:
# This is me, do what the master tell me # This is me, do what the master tell me
logging.info("I was told I'm %s", state) logging.info("I was told I'm %s", state)
......
...@@ -58,13 +58,6 @@ class ClientOperationHandler(EventHandler): ...@@ -58,13 +58,6 @@ class ClientOperationHandler(EventHandler):
compression, checksum, data, data_serial) compression, checksum, data, data_serial)
conn.answer(p) conn.answer(p)
def connectionLost(self, conn, new_state):
uuid = conn.getUUID()
node = self.app.nm.getByUUID(uuid)
if self.app.listening_conn: # if running
assert node is not None, conn
self.app.nm.remove(node)
def abortTransaction(self, conn, ttid): def abortTransaction(self, conn, ttid):
self.app.tm.abort(ttid) self.app.tm.abort(ttid)
...@@ -88,7 +81,7 @@ class ClientOperationHandler(EventHandler): ...@@ -88,7 +81,7 @@ class ClientOperationHandler(EventHandler):
except DelayedError: except DelayedError:
# locked by a previous transaction, retry later # locked by a previous transaction, retry later
# If we are unlocking, we want queueEvent to raise # If we are unlocking, we want queueEvent to raise
# AlreadyPendingError, to avoid making lcient wait for an unneeded # AlreadyPendingError, to avoid making client wait for an unneeded
# response. # response.
try: try:
self.app.queueEvent(self._askStoreObject, conn, (oid, serial, self.app.queueEvent(self._askStoreObject, conn, (oid, serial,
......
...@@ -27,13 +27,13 @@ class IdentificationHandler(EventHandler): ...@@ -27,13 +27,13 @@ class IdentificationHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
logging.warning('A connection was lost during identification') logging.warning('A connection was lost during identification')
def requestIdentification(self, conn, node_type, def requestIdentification(self, conn, node_type, uuid, address, name,
uuid, address, name): id_timestamp):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app
# reject any incoming connections if not ready # reject any incoming connections if not ready
if not self.app.ready: if not app.ready:
raise NotReadyError raise NotReadyError
app = self.app
if uuid is None: if uuid is None:
if node_type != NodeTypes.STORAGE: if node_type != NodeTypes.STORAGE:
raise ProtocolError('reject anonymous non-storage node') raise ProtocolError('reject anonymous non-storage node')
...@@ -42,9 +42,14 @@ class IdentificationHandler(EventHandler): ...@@ -42,9 +42,14 @@ class IdentificationHandler(EventHandler):
else: else:
if uuid == app.uuid: if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection") raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid, id_timestamp)
# If this node is broken, reject it. if node is None:
if node is not None and node.isBroken(): # Do never create node automatically, or we could get id
# conflicts. We must only rely on the notifications from the
# master to recognize nodes. So this is not always an error:
# maybe there are incoming notifications.
raise NotReadyError('unknown node: retry later')
if node.isBroken():
raise BrokenNodeDisallowedError raise BrokenNodeDisallowedError
# choose the handler according to the node type # choose the handler according to the node type
if node_type == NodeTypes.CLIENT: if node_type == NodeTypes.CLIENT:
...@@ -52,20 +57,9 @@ class IdentificationHandler(EventHandler): ...@@ -52,20 +57,9 @@ class IdentificationHandler(EventHandler):
handler = ClientReadOnlyOperationHandler handler = ClientReadOnlyOperationHandler
else: else:
handler = ClientOperationHandler handler = ClientOperationHandler
if node is None: assert not node.isConnected(), node
node = app.nm.createClient(uuid=uuid) assert node.isRunning(), node
elif node.isConnected():
# This can happen if we haven't processed yet a notification
# from the master, telling us the existing node is not
# running anymore. If we accept the new client, we won't
# know what to do with this late notification.
raise NotReadyError('uuid conflict: retry later')
node.setRunning()
elif node_type == NodeTypes.STORAGE: elif node_type == NodeTypes.STORAGE:
if node is None:
logging.error('reject an unknown storage node %s',
uuid_str(uuid))
raise NotReadyError
handler = StorageOperationHandler handler = StorageOperationHandler
else: else:
raise ProtocolError('reject non-client-or-storage node') raise ProtocolError('reject non-client-or-storage node')
......
...@@ -20,16 +20,13 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID ...@@ -20,16 +20,13 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn):
pass
def sendPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
pt.load(ptid, row_list, self.app.nm) pt.load(ptid, row_list, self.app.nm)
if not pt.filled(): if not pt.filled():
raise ProtocolError('Partial partition table received') raise ProtocolError('Partial partition table received')
# Install the partition table into the database for persistency. # Install the partition table into the database for persistence.
cell_list = [] cell_list = []
num_partitions = pt.getPartitions() num_partitions = pt.getPartitions()
unassigned_set = set(xrange(num_partitions)) unassigned_set = set(xrange(num_partitions))
......
...@@ -258,7 +258,8 @@ class Replicator(object): ...@@ -258,7 +258,8 @@ class Replicator(object):
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
try: try:
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name,
app.id_timestamp))
except ConnectionClosed: except ConnectionClosed:
if previous_node is self.current_node: if previous_node is self.current_node:
return return
...@@ -336,7 +337,7 @@ class Replicator(object): ...@@ -336,7 +337,7 @@ class Replicator(object):
offset, message and ' (%s)' % message) offset, message and ' (%s)' % message)
if offset in self.partition_dict: if offset in self.partition_dict:
# XXX: Try another partition if possible, to increase probability to # XXX: Try another partition if possible, to increase probability to
# connect to another node. It would be better to explicitely # connect to another node. It would be better to explicitly
# search for another node instead. # search for another node instead.
tid = self.replicate_dict.pop(offset, None) or self.replicate_tid tid = self.replicate_dict.pop(offset, None) or self.replicate_tid
if self.replicate_dict: if self.replicate_dict:
......
...@@ -32,6 +32,7 @@ from functools import wraps ...@@ -32,6 +32,7 @@ from functools import wraps
from mock import Mock from mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -194,6 +195,15 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -194,6 +195,15 @@ class NeoUnitTestBase(NeoTestBase):
self.uuid_dict = {} self.uuid_dict = {}
NeoTestBase.setUp(self) NeoTestBase.setUp(self)
@cached_property
def nm(self):
from neo.lib import node
return node.NodeManager()
def createStorage(self, *args):
return self.nm.createStorage(**dict(zip(
('address', 'uuid', 'state'), args)))
def prepareDatabase(self, number, prefix=DB_PREFIX): def prepareDatabase(self, number, prefix=DB_PREFIX):
""" create empty databases """ """ create empty databases """
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
...@@ -312,7 +322,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -312,7 +322,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs) self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs): def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception wxas raised """ """ Check if the UnexpectedPacketError exception was raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs) self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs): def checkIdenficationRequired(self, method, *args, **kwargs):
...@@ -320,11 +330,11 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -320,11 +330,11 @@ class NeoUnitTestBase(NeoTestBase):
self.checkUnexpectedPacketRaised(method, *args, **kwargs) self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs): def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception wxas raised """ """ Check if the BrokenNodeDisallowedError exception was raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs) self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs): def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception wxas raised """ """ Check if the NotReadyError exception was raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs) self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
def checkAborted(self, conn): def checkAborted(self, conn):
...@@ -372,7 +382,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -372,7 +382,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(found_uuid, uuid) self.assertEqual(found_uuid, uuid)
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurates checks are required # in tests if more accurate checks are required
def checkErrorPacket(self, conn, decode=False): def checkErrorPacket(self, conn, decode=False):
""" Check if an error packet was answered """ """ Check if an error packet was answered """
......
...@@ -81,7 +81,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -81,7 +81,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# stop threads # stop threads
for app in self._to_stop_list: for app in self._to_stop_list:
app.close() app.close()
# restore environnement # restore environment
Application._ask = self._ask Application._ask = self._ask
Application._getMasterConnection = self._getMasterConnection Application._getMasterConnection = self._getMasterConnection
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
...@@ -596,7 +596,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -596,7 +596,7 @@ class ClientApplicationTests(NeoUnitTestBase):
Object oid previous revision before tid1 is tid0. Object oid previous revision before tid1 is tid0.
Transaction tid2 modified oid (and contains its data). Transaction tid2 modified oid (and contains its data).
Undo is rejeced with a raise, because conflict resolution fails. Undo is rejected with a raise, because conflict resolution fails.
""" """
oid0 = self.makeOID(1) oid0 = self.makeOID(1)
tid0 = self.getNextTID() tid0 = self.getNextTID()
...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -753,11 +753,7 @@ class ClientApplicationTests(NeoUnitTestBase):
# will raise IndexError at the third iteration # will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011') app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases # TODO: test more connection failure cases
all_passed = []
# askLastTransaction # askLastTransaction
def _ask9(_):
all_passed.append(1)
# Seventh packet : askNodeInformation succeeded
def _ask8(_): def _ask8(_):
pass pass
# Sixth packet : askPartitionTable succeeded # Sixth packet : askPartitionTable succeeded
...@@ -789,19 +785,18 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -789,19 +785,18 @@ class ClientApplicationTests(NeoUnitTestBase):
# telling us what its address is.) # telling us what its address is.)
def _ask1(_): def _ask1(_):
pass pass
ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, ask_func_list = [_ask1, _ask2, _ask3, _ask4, _ask6, _ask7, _ask8]
_ask8, _ask9]
def _ask_base(conn, _, handler=None): def _ask_base(conn, _, handler=None):
ask_func_list.pop(0)(conn) ask_func_list.pop(0)(conn)
app.nm.getByAddress(conn.getAddress())._connection = None app.nm.getByAddress(conn.getAddress())._connection = None
app._ask = _ask_base app._ask = _ask_base
# faked environnement # fake environment
app.em.close() app.em.close()
app.em = Mock({'getConnectionList': []}) app.em = Mock({'getConnectionList': []})
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
app.start = lambda: None app.start = lambda: None
app.master_conn = app._connectToPrimaryNode() app.master_conn = app._connectToPrimaryNode()
self.assertEqual(len(all_passed), 1) self.assertFalse(ask_func_list)
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
...@@ -831,11 +826,11 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -831,11 +826,11 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertTrue(self.test_ok) self.assertTrue(self.test_ok)
# check NEOStorageError is raised when the primary connection is lost # check NEOStorageError is raised when the primary connection is lost
app.master_conn = None app.master_conn = None
# check disabled since we reonnect to pmn # check disabled since we reconnect to pmn
#self.assertRaises(NEOStorageError, app._askPrimary, packet) #self.assertRaises(NEOStorageError, app._askPrimary, packet)
def test_threadContextIsolation(self): def test_threadContextIsolation(self):
""" Thread context properties must not be visible accross instances """ Thread context properties must not be visible across instances
while remaining in the same thread """ while remaining in the same thread """
app1 = self.getApp() app1 = self.getApp()
app1_local = app1._thread_container app1_local = app1._thread_container
......
...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase): ...@@ -44,69 +44,6 @@ class MasterHandlerTests(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return node, conn return node, conn
class MasterBootstrapHandlerTests(MasterHandlerTests):
def setUp(self):
super(MasterBootstrapHandlerTests, self).setUp()
self.handler = PrimaryBootstrapHandler(self.app)
def checkCalledOnApp(self, method, index=0):
calls = self.app.mockGetNamedCalls(method)
self.assertTrue(len(calls) > index)
return calls[index].params
def test_notReady(self):
conn = self.getFakeConnection()
self.handler.notReady(conn, 'message')
self.assertEqual(self.app.trying_master_node, None)
def test_acceptIdentification1(self):
""" Non-master node """
node, conn = self.getKnownMaster()
self.handler.acceptIdentification(conn, NodeTypes.CLIENT,
node.getUUID(), 100, 0, None, None, [])
self.checkClosed(conn)
def test_acceptIdentification2(self):
""" No UUID supplied """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
self.checkProtocolErrorRaised(self.handler.acceptIdentification,
conn, NodeTypes.MASTER, uuid, 100, 0, None,
addr, [(addr, uuid)],
)
def test_acceptIdentification3(self):
""" identification accepted """
node, conn = self.getKnownMaster()
uuid = self.getMasterUUID()
addr = conn.getAddress()
your_uuid = self.getClientUUID()
self.handler.acceptIdentification(conn, NodeTypes.MASTER, uuid,
100, 2, your_uuid, addr, [(addr, uuid)])
self.assertEqual(self.app.uuid, your_uuid)
self.assertEqual(node.getUUID(), uuid)
self.assertTrue(isinstance(self.app.pt, PartitionTable))
def _getMasterList(self, uuid_list):
port = 1000
master_list = []
for uuid in uuid_list:
master_list.append((('127.0.0.1', port), uuid))
port += 1
return master_list
def test_answerPartitionTable(self):
conn = self.getFakeConnection()
self.app.pt = Mock()
ptid = 0
row_list = ([], [])
self.handler.answerPartitionTable(conn, ptid, row_list)
load_calls = self.app.pt.mockGetNamedCalls('load')
self.assertEqual(len(load_calls), 1)
# load_calls[0].checkArgs(ptid, row_list, self.app.nm)
class MasterNotificationsHandlerTests(MasterHandlerTests): class MasterNotificationsHandlerTests(MasterHandlerTests):
......
...@@ -119,7 +119,7 @@ class NEOProcess(object): ...@@ -119,7 +119,7 @@ class NEOProcess(object):
except ImportError: except ImportError:
raise NotFound, '%s not found' % (command) raise NotFound, '%s not found' % (command)
self.command = command self.command = command
self.arg_dict = {'--' + k: v for k, v in arg_dict.iteritems()} self.arg_dict = arg_dict
self.with_uuid = True self.with_uuid = True
self.setUUID(uuid) self.setUUID(uuid)
...@@ -131,11 +131,11 @@ class NEOProcess(object): ...@@ -131,11 +131,11 @@ class NEOProcess(object):
args = [] args = []
self.with_uuid = with_uuid self.with_uuid = with_uuid
for arg, param in self.arg_dict.iteritems(): for arg, param in self.arg_dict.iteritems():
if with_uuid is False and arg == '--uuid': args.append('--' + arg)
continue
args.append(arg)
if param is not None: if param is not None:
args.append(str(param)) args.append(str(param))
if with_uuid:
args += '--uuid', str(self.uuid)
self.pid = os.fork() self.pid = os.fork()
if self.pid == 0: if self.pid == 0:
# Child # Child
...@@ -183,7 +183,7 @@ class NEOProcess(object): ...@@ -183,7 +183,7 @@ class NEOProcess(object):
self.wait() self.wait()
except: except:
# We can ignore all exceptions at this point, since there is no # We can ignore all exceptions at this point, since there is no
# garanteed way to handle them (other objects we would depend on # guaranteed way to handle them (other objects we would depend on
# might already have been deleted). # might already have been deleted).
pass pass
...@@ -213,7 +213,6 @@ class NEOProcess(object): ...@@ -213,7 +213,6 @@ class NEOProcess(object):
Note: for this change to take effect, the node must be restarted. Note: for this change to take effect, the node must be restarted.
""" """
self.uuid = uuid self.uuid = uuid
self.arg_dict['--uuid'] = str(uuid)
def isAlive(self): def isAlive(self):
try: try:
...@@ -305,7 +304,6 @@ class NEOCluster(object): ...@@ -305,7 +304,6 @@ class NEOCluster(object):
def _newProcess(self, node_type, logfile=None, port=None, **kw): def _newProcess(self, node_type, logfile=None, port=None, **kw):
self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0) self.uuid_dict[node_type] = uuid = 1 + self.uuid_dict.get(node_type, 0)
uuid += UUID_NAMESPACES[node_type] << 24 uuid += UUID_NAMESPACES[node_type] << 24
kw['uuid'] = uuid
kw['cluster'] = self.cluster_name kw['cluster'] = self.cluster_name
kw['masters'] = self.master_nodes kw['masters'] = self.master_nodes
if logfile: if logfile:
...@@ -491,13 +489,9 @@ class NEOCluster(object): ...@@ -491,13 +489,9 @@ class NEOCluster(object):
return self.__getNodeList(NodeTypes.CLIENT, state) return self.__getNodeList(NodeTypes.CLIENT, state)
def __getNodeState(self, node_type, uuid): def __getNodeState(self, node_type, uuid):
node_list = self.__getNodeList(node_type) for node in self.__getNodeList(node_type):
for node_type, address, node_uuid, state in node_list: if node[2] == uuid:
if node_uuid == uuid: return node[3]
break
else:
state = None
return state
def getMasterNodeState(self, uuid): def getMasterNodeState(self, uuid):
return self.__getNodeState(NodeTypes.MASTER, uuid) return self.__getNodeState(NodeTypes.MASTER, uuid)
...@@ -573,7 +567,7 @@ class NEOCluster(object): ...@@ -573,7 +567,7 @@ class NEOCluster(object):
def callback(last_try): def callback(last_try):
current_try = self.getPrimary() current_try = self.getPrimary()
if None not in (uuid, current_try) and uuid != current_try: if None not in (uuid, current_try) and uuid != current_try:
raise AssertionError, 'An unexpected primary arised: %r, ' \ raise AssertionError, 'An unexpected primary arose: %r, ' \
'expected %r' % (dump(current_try), dump(uuid)) 'expected %r' % (dump(current_try), dump(uuid))
return uuid is None or uuid == current_try, current_try return uuid is None or uuid == current_try, current_try
self.expectCondition(callback, *args, **kw) self.expectCondition(callback, *args, **kw)
...@@ -581,12 +575,12 @@ class NEOCluster(object): ...@@ -581,12 +575,12 @@ class NEOCluster(object):
def expectOudatedCells(self, number, *args, **kw): def expectOudatedCells(self, number, *args, **kw):
def callback(last_try): def callback(last_try):
row_list = self.neoctl.getPartitionRowList()[1] row_list = self.neoctl.getPartitionRowList()[1]
number_of_oudated = 0 number_of_outdated = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row[1]:
if cell[1] == CellStates.OUT_OF_DATE: if cell[1] == CellStates.OUT_OF_DATE:
number_of_oudated += 1 number_of_outdated += 1
return number_of_oudated == number, number_of_oudated return number_of_outdated == number, number_of_outdated
self.expectCondition(callback, *args, **kw) self.expectCondition(callback, *args, **kw)
def expectAssignedCells(self, process, number, *args, **kw): def expectAssignedCells(self, process, number, *args, **kw):
......
...@@ -43,7 +43,7 @@ class Tree(Persistent): ...@@ -43,7 +43,7 @@ class Tree(Persistent):
self.left = Tree(depth) self.left = Tree(depth)
# simple persitent object with conflict resolution # simple persistent object with conflict resolution
class PCounter(Persistent): class PCounter(Persistent):
_value = 0 _value = 0
...@@ -131,7 +131,7 @@ class ClientTests(NEOFunctionalTest): ...@@ -131,7 +131,7 @@ class ClientTests(NEOFunctionalTest):
c2.root()['other'] c2.root()['other']
c1.root()['item'] = 1 c1.root()['item'] = 1
t1.commit() t1.commit()
# load objet from zope cache # load object from zope cache
self.assertEqual(c1.root()['item'], 1) self.assertEqual(c1.root()['item'], 1)
self.assertEqual(c2.root()['item'], 0) self.assertEqual(c2.root()['item'], 0)
...@@ -334,7 +334,7 @@ class ClientTests(NEOFunctionalTest): ...@@ -334,7 +334,7 @@ class ClientTests(NEOFunctionalTest):
t3.user = 'user' t3.user = 'user'
t3.description = 'desc' t3.description = 'desc'
st3.tpc_begin(t3) st3.tpc_begin(t3)
# retreive the last revision # retrieve the last revision
data, serial = st3.load(oid) data, serial = st3.load(oid)
# try to store again, should not be delayed # try to store again, should not be delayed
st3.store(oid, serial, data, '', t3) st3.store(oid, serial, data, '', t3)
......
...@@ -63,7 +63,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -63,7 +63,7 @@ class MasterTests(NEOFunctionalTest):
# BUG: The following check expects neoctl to reconnect before # BUG: The following check expects neoctl to reconnect before
# the election finishes. # the election finishes.
self.assertEqual(self.neo.getPrimary(), None) self.assertEqual(self.neo.getPrimary(), None)
# Check that a primary master arised. # Check that a primary master arose.
self.neo.expectPrimary(timeout=10) self.neo.expectPrimary(timeout=10)
# Check that the uuid really changed. # Check that the uuid really changed.
new_uuid = self.neo.getPrimary() new_uuid = self.neo.getPrimary()
...@@ -83,7 +83,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -83,7 +83,7 @@ class MasterTests(NEOFunctionalTest):
uuid, = self.neo.killPrimary() uuid, = self.neo.killPrimary()
# Check the state of the primary we just killed # Check the state of the primary we just killed
self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN)) self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN))
# Check that a primary master arised. # Check that a primary master arose.
self.neo.expectPrimary(timeout=10) self.neo.expectPrimary(timeout=10)
# Check that the uuid really changed. # Check that the uuid really changed.
self.assertNotEqual(self.neo.getPrimary(), uuid) self.assertNotEqual(self.neo.getPrimary(), uuid)
......
...@@ -69,7 +69,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -69,7 +69,7 @@ class StorageTests(NEOFunctionalTest):
def __checkDatabase(self, db_name): def __checkDatabase(self, db_name):
db = self.neo.getSQLConnection(db_name) db = self.neo.getSQLConnection(db_name)
# wait for the sql transaction to be commited # wait for the sql transaction to be committed
def callback(last_try): def callback(last_try):
db.commit() # to get a fresh view db.commit() # to get a fresh view
# One revision per object and two for the root, before and after # One revision per object and two for the root, before and after
...@@ -157,7 +157,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -157,7 +157,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
def testOudatedCellsOnDownStorage(self): def testOudatedCellsOnDownStorage(self):
""" Check that the storage cells are set as oudated when the node is """ Check that the storage cells are set as outdated when the node is
down, the cluster remains up since there is a replica """ down, the cluster remains up since there is a replica """
# populate the two storages # populate the two storages
...@@ -185,7 +185,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -185,7 +185,7 @@ class StorageTests(NEOFunctionalTest):
def testVerificationTriggered(self): def testVerificationTriggered(self):
""" Check that the verification stage is executed when a storage node """ Check that the verification stage is executed when a storage node
required to be operationnal is lost, and the cluster come back in required to be operational is lost, and the cluster come back in
running state when the storage is up again """ running state when the storage is up again """
# start neo with one storages # start neo with one storages
...@@ -444,7 +444,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -444,7 +444,7 @@ class StorageTests(NEOFunctionalTest):
st.tpc_begin(t) st.tpc_begin(t)
st.store(oid, rev, data, '', t) st.store(oid, rev, data, '', t)
# start the oudated storage # start the outdated storage
stopped[0].start() stopped[0].start()
self.neo.expectPending(stopped[0]) self.neo.expectPending(stopped[0])
self.neo.neoctl.enableStorageList([stopped[0].getUUID()]) self.neo.neoctl.enableStorageList([stopped[0].getUUID()])
......
...@@ -108,10 +108,12 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -108,10 +108,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
# do the right job # do the right job
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
storage_uuid = self.storage_uuid storage_uuid = self.storage_uuid
storage_conn = self.getFakeConnection(storage_uuid, self.storage_address) storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address, is_server=True)
storage2_uuid = self.identifyToMasterNode(port=10022) storage2_uuid = self.identifyToMasterNode(port=10022)
storage2_conn = self.getFakeConnection(storage2_uuid, storage2_conn = self.getFakeConnection(storage2_uuid,
(self.storage_address[0], self.storage_address[1] + 1)) (self.storage_address[0], self.storage_address[1] + 1),
is_server=True)
self.app.setStorageReady(storage2_uuid) self.app.setStorageReady(storage2_uuid)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
self.app.pt = Mock({ self.app.pt = Mock({
...@@ -142,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -142,18 +144,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.assertEqual(len(txn.getOIDList()), 0) self.assertEqual(len(txn.getOIDList()), 0)
self.assertEqual(len(txn.getUUIDList()), 1) self.assertEqual(len(txn.getUUIDList()), 1)
def test_askNodeInformations(self):
# check that only informations about master and storages nodes are
# send to a client
self.app.nm.createClient()
conn = self.getFakeConnection()
self.service.askNodeInformation(conn)
calls = conn.mockGetNamedCalls('notify')
self.assertEqual(len(calls), 1)
packet = calls[0].getParam(0)
(node_list, ) = packet.decode()
self.assertEqual(len(node_list), 2)
def test_connectionClosed(self): def test_connectionClosed(self):
# give a client uuid which have unfinished transactions # give a client uuid which have unfinished transactions
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
...@@ -176,7 +166,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -176,7 +166,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
conn = self.getFakeConnection(peer_id=peer_id) conn = self.getFakeConnection(peer_id=peer_id)
storage_uuid = self.storage_uuid storage_uuid = self.storage_uuid
storage_conn = self.getFakeConnection(storage_uuid, storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address) self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
...@@ -189,7 +179,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -189,7 +179,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
# Asking again to pack will cause an immediate error # Asking again to pack will cause an immediate error
storage_uuid = self.identifyToMasterNode(port=10022) storage_uuid = self.identifyToMasterNode(port=10022)
storage_conn = self.getFakeConnection(storage_uuid, storage_conn = self.getFakeConnection(storage_uuid,
self.storage_address) self.storage_address, is_server=True)
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
......
...@@ -225,13 +225,13 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -225,13 +225,13 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def _tearDown(self, success): def _tearDown(self, success):
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
# restore environnement # restore environment
del ClientConnection._addPacket del ClientConnection._addPacket
def test_requestIdentification1(self): def test_requestIdentification1(self):
""" A non-master node request identification """ """ A non-master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.NotReadyError, self.assertRaises(protocol.NotReadyError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.CLIENT, *args) conn, NodeTypes.CLIENT, *args)
...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -240,7 +240,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
""" A broken master node request identification """ """ A broken master node request identification """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
node.setBroken() node.setBroken()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.assertRaises(protocol.BrokenNodeDisallowedError, self.assertRaises(protocol.BrokenNodeDisallowedError,
self.election.requestIdentification, self.election.requestIdentification,
conn, NodeTypes.MASTER, *args) conn, NodeTypes.MASTER, *args)
...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -248,7 +248,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
def test_requestIdentification4(self): def test_requestIdentification4(self):
""" No conflict """ """ No conflict """
node, conn = self.identifyToMasterNode() node, conn = self.identifyToMasterNode()
args = (node.getUUID(), node.getAddress(), self.app.name) args = node.getUUID(), node.getAddress(), self.app.name, None
self.election.requestIdentification(conn, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -280,11 +280,12 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
conn = self.__getClient() conn = self.__getClient()
self.checkNotReadyErrorRaised( self.checkNotReadyErrorRaised(
self.election.requestIdentification, self.election.requestIdentification,
conn=conn, conn,
node_type=NodeTypes.CLIENT, NodeTypes.CLIENT,
uuid=conn.getUUID(), conn.getUUID(),
address=conn.getAddress(), conn.getAddress(),
name=self.app.name self.app.name,
None,
) )
def _requestIdentification(self): def _requestIdentification(self):
...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -297,6 +298,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
peer_uuid, peer_uuid,
address, address,
self.app.name, self.app.name,
None,
) )
node_type, uuid, partitions, replicas, _peer_uuid, primary, \ node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True) master_list = self.checkAcceptIdentification(conn, decode=True)
......
...@@ -41,8 +41,8 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -41,8 +41,8 @@ class MasterAppTests(NeoUnitTestBase):
client = self.app.nm.createClient(uuid=client_uuid) client = self.app.nm.createClient(uuid=client_uuid)
# create conn and patch em # create conn and patch em
master_conn = self.getFakeConnection() master_conn = self.getFakeConnection()
storage_conn = self.getFakeConnection() storage_conn = self.getFakeConnection(is_server=True)
client_conn = self.getFakeConnection() client_conn = self.getFakeConnection(is_server=True)
master.setConnection(master_conn) master.setConnection(master_conn)
storage.setConnection(storage_conn) storage.setConnection(storage_conn)
client.setConnection(client_conn) client.setConnection(client_conn)
......
...@@ -21,7 +21,6 @@ from .. import NeoUnitTestBase ...@@ -21,7 +21,6 @@ from .. import NeoUnitTestBase
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.master.pt import PartitionTable from neo.master.pt import PartitionTable
from neo.lib.node import StorageNode
class MasterPartitionTableTests(NeoUnitTestBase): class MasterPartitionTableTests(NeoUnitTestBase):
...@@ -55,19 +54,19 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -55,19 +54,19 @@ class MasterPartitionTableTests(NeoUnitTestBase):
# create nodes # create nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19002) server2 = ("127.0.0.2", 19002)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19003) server3 = ("127.0.0.3", 19003)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19004) server4 = ("127.0.0.4", 19004)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
uuid5 = self.getStorageUUID() uuid5 = self.getStorageUUID()
server5 = ("127.0.0.5", 19005) server5 = ("127.0.0.5", 19005)
sn5 = StorageNode(Mock(), server5, uuid5) sn5 = self.createStorage(server5, uuid5)
# create partition table # create partition table
num_partitions = 5 num_partitions = 5
num_replicas = 3 num_replicas = 3
...@@ -117,7 +116,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -117,7 +116,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.assertEqual(cell.getState(), CellStates.UP_TO_DATE) self.assertEqual(cell.getState(), CellStates.UP_TO_DATE)
def test_15_dropNodeList(self): def test_15_dropNodeList(self):
sn = [StorageNode(Mock(), None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(3)] for i in xrange(3)]
pt = PartitionTable(3, 0) pt = PartitionTable(3, 0)
pt.setCell(0, sn[0], CellStates.OUT_OF_DATE) pt.setCell(0, sn[0], CellStates.OUT_OF_DATE)
...@@ -153,22 +152,22 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -153,22 +152,22 @@ class MasterPartitionTableTests(NeoUnitTestBase):
# add nodes # add nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1, NodeStates.RUNNING) sn1 = self.createStorage(server1, uuid1, NodeStates.RUNNING)
# add not running node # add not running node
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
sn2.setState(NodeStates.TEMPORARILY_DOWN) sn2.setState(NodeStates.TEMPORARILY_DOWN)
# add node without uuid # add node without uuid
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, None, NodeStates.RUNNING) sn3 = self.createStorage(server3, None, NodeStates.RUNNING)
# add clear node # add clear node
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4, NodeStates.RUNNING) sn4 = self.createStorage(server4, uuid4, NodeStates.RUNNING)
uuid5 = self.getStorageUUID() uuid5 = self.getStorageUUID()
server5 = ("127.0.0.5", 1900) server5 = ("127.0.0.5", 1900)
sn5 = StorageNode(Mock(), server5, uuid5, NodeStates.RUNNING) sn5 = self.createStorage(server5, uuid5, NodeStates.RUNNING)
# make the table # make the table
pt.make([sn1, sn2, sn3, sn4, sn5]) pt.make([sn1, sn2, sn3, sn4, sn5])
# check it's ok, only running nodes and node with uuid # check it's ok, only running nodes and node with uuid
...@@ -231,7 +230,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -231,7 +230,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
return change_list return change_list
def test_17_tweak(self): def test_17_tweak(self):
sn = [StorageNode(Mock(), None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(5)] for i in xrange(5)]
pt = PartitionTable(5, 2) pt = PartitionTable(5, 2)
# part 0 # part 0
......
...@@ -63,7 +63,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -63,7 +63,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
uuid = self.getNewUUID(node_type) uuid = self.getNewUUID(node_type)
node = nm.createFromNodeType(node_type, address=(ip, port), node = nm.createFromNodeType(node_type, address=(ip, port),
uuid=uuid) uuid=uuid)
conn = self.getFakeConnection(node.getUUID(), node.getAddress()) conn = self.getFakeConnection(node.getUUID(), node.getAddress(), True)
node.setConnection(conn) node.setConnection(conn)
return (node, conn) return (node, conn)
...@@ -160,7 +160,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -160,7 +160,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.assertEqual(lptid, self.app.pt.getID()) self.assertEqual(lptid, self.app.pt.getID())
def test_answerPack(self): def test_answerPack(self):
# Note: incomming status has no meaning here, so it's left to False. # Note: incoming status has no meaning here, so it's left to False.
node1, conn1 = self._getStorage() node1, conn1 = self._getStorage()
node2, conn2 = self._getStorage() node2, conn2 = self._getStorage()
self.app.packing = None self.app.packing = None
......
...@@ -169,7 +169,7 @@ class testTransactionManager(NeoUnitTestBase): ...@@ -169,7 +169,7 @@ class testTransactionManager(NeoUnitTestBase):
""" """
Transaction lock is present to ensure invalidation TIDs are sent in Transaction lock is present to ensure invalidation TIDs are sent in
strictly increasing order. strictly increasing order.
Note: this implementation might change later, to allow more paralelism. Note: this implementation might change later, for more parallelism.
""" """
client_uuid, client = self.makeNode(NodeTypes.CLIENT) client_uuid, client = self.makeNode(NodeTypes.CLIENT)
tm = TransactionManager(lambda tid, txn: None) tm = TransactionManager(lambda tid, txn: None)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NotReadyError, \ from neo.lib.protocol import NodeStates, NodeTypes, NotReadyError, \
BrokenNodeDisallowedError BrokenNodeDisallowedError
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.storage.app import Application from neo.storage.app import Application
...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -50,6 +50,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getClientUUID(), self.getClientUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
self.app.ready = True self.app.ready = True
self.assertRaises( self.assertRaises(
...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -60,6 +61,7 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
self.getStorageUUID(), self.getStorageUUID(),
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification3(self): def test_requestIdentification3(self):
...@@ -75,19 +77,20 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase): ...@@ -75,19 +77,20 @@ class StorageIdentificationHandlerTests(NeoUnitTestBase):
uuid, uuid,
None, None,
self.app.name, self.app.name,
None,
) )
def test_requestIdentification2(self): def test_requestIdentification2(self):
""" accepted client must be connected and running """ """ accepted client must be connected and running """
uuid = self.getClientUUID() uuid = self.getClientUUID()
conn = self.getFakeConnection(uuid=uuid) conn = self.getFakeConnection(uuid=uuid)
node = self.app.nm.createClient(uuid=uuid) node = self.app.nm.createClient(uuid=uuid, state=NodeStates.RUNNING)
master = (self.local_ip, 3000) master = (self.local_ip, 3000)
self.app.master_node = Mock({ self.app.master_node = Mock({
'getAddress': master, 'getAddress': master,
}) })
self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid, self.identification.requestIdentification(conn, NodeTypes.CLIENT, uuid,
None, self.app.name) None, self.app.name, None)
self.assertTrue(node.isRunning()) self.assertTrue(node.isRunning())
self.assertTrue(node.isConnected()) self.assertTrue(node.isConnected())
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
......
...@@ -167,17 +167,17 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -167,17 +167,17 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one non-commited version # one non-committed version
with self.commitTransaction(tid1, objs1, txn1): with self.commitTransaction(tid1, objs1, txn1):
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid1, tid1), None) self.assertEqual(self.db.getObject(oid1, tid1), None)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None) self.assertEqual(self.db.getObject(oid1, before_tid=tid1), None)
# one commited version # one committed version
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
# two version available, one non-commited # two version available, one non-committed
with self.commitTransaction(tid2, objs2, txn2): with self.commitTransaction(tid2, objs2, txn2):
self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1), OBJECT_T1_NO_NEXT)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NO_NEXT)
...@@ -187,7 +187,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -187,7 +187,7 @@ class StorageDBTests(NeoUnitTestBase):
FOUND_BUT_NOT_VISIBLE) FOUND_BUT_NOT_VISIBLE)
self.assertEqual(self.db.getObject(oid1, before_tid=tid2), self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NO_NEXT) OBJECT_T1_NO_NEXT)
# two commited versions # two committed versions
self.assertEqual(self.db.getObject(oid1), OBJECT_T2) self.assertEqual(self.db.getObject(oid1), OBJECT_T2)
self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT) self.assertEqual(self.db.getObject(oid1, tid1), OBJECT_T1_NEXT)
self.assertEqual(self.db.getObject(oid1, before_tid=tid1), self.assertEqual(self.db.getObject(oid1, before_tid=tid1),
......
...@@ -187,7 +187,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -187,7 +187,7 @@ class TransactionManagerTests(NeoUnitTestBase):
ttid1, serial, *obj) ttid1, serial, *obj)
def testResolvableConflict(self): def testResolvableConflict(self):
""" Try to store an object with the lastest revision """ """ Try to store an object with the latest revision """
uuid = self.getClientUUID() uuid = self.getClientUUID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
serial, obj = self._getObject(1) serial, obj = self._getObject(1)
......
...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase): ...@@ -28,7 +28,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getStorageConfiguration() config = self.getStorageConfiguration()
self.app = Application(config) self.app = Application(config)
self.bootstrap = BootstrapManager(self.app, 'main', NodeTypes.STORAGE) self.bootstrap = BootstrapManager(self.app, NodeTypes.STORAGE)
# define some variable to simulate client and storage node # define some variable to simulate client and storage node
self.master_port = 10010 self.master_port = 10010
self.storage_port = 10020 self.storage_port = 10020
......
...@@ -330,7 +330,7 @@ class HandlerSwitcherTests(NeoUnitTestBase): ...@@ -330,7 +330,7 @@ class HandlerSwitcherTests(NeoUnitTestBase):
r2 = self._makeRequest(2) r2 = self._makeRequest(2)
a2 = self._makeAnswer(2) a2 = self._makeAnswer(2)
h = self._makeHandler() h = self._makeHandler()
# emit requests aroung state setHandler # emit requests around state setHandler
self._handlers.emit(r1, 0, None) self._handlers.emit(r1, 0, None)
applied = self._handlers.setHandler(h) applied = self._handlers.setHandler(h)
self.assertFalse(applied) self.assertFalse(applied)
......
...@@ -18,8 +18,7 @@ import unittest ...@@ -18,8 +18,7 @@ import unittest
from mock import Mock from mock import Mock
from neo.lib import protocol from neo.lib import protocol
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterNode, StorageNode, \ from neo.lib.node import Node, NodeManager, MasterDB
ClientNode, AdminNode, NodeManager, MasterDB
from . import NeoUnitTestBase, getTempDirectory from . import NeoUnitTestBase, getTempDirectory
from time import time from time import time
from os import chmod, mkdir, rmdir, unlink from os import chmod, mkdir, rmdir, unlink
...@@ -29,15 +28,15 @@ class NodesTests(NeoUnitTestBase): ...@@ -29,15 +28,15 @@ class NodesTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.manager = Mock() self.nm = Mock()
def _updatedByAddress(self, node, index=0): def _updatedByAddress(self, node, index=0):
calls = self.manager.mockGetNamedCalls('_updateAddress') calls = self.nm.mockGetNamedCalls('_updateAddress')
self.assertEqual(len(calls), index + 1) self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node) self.assertEqual(calls[index].getParam(0), node)
def _updatedByUUID(self, node, index=0): def _updatedByUUID(self, node, index=0):
calls = self.manager.mockGetNamedCalls('_updateUUID') calls = self.nm.mockGetNamedCalls('_updateUUID')
self.assertEqual(len(calls), index + 1) self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node) self.assertEqual(calls[index].getParam(0), node)
...@@ -45,7 +44,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -45,7 +44,7 @@ class NodesTests(NeoUnitTestBase):
""" Check the node initialization """ """ Check the node initialization """
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
uuid = self.getNewUUID(None) uuid = self.getNewUUID(None)
node = Node(self.manager, address=address, uuid=uuid) node = Node(self.nm, address=address, uuid=uuid)
self.assertEqual(node.getState(), NodeStates.UNKNOWN) self.assertEqual(node.getState(), NodeStates.UNKNOWN)
self.assertEqual(node.getAddress(), address) self.assertEqual(node.getAddress(), address)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
...@@ -53,7 +52,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -53,7 +52,7 @@ class NodesTests(NeoUnitTestBase):
def testState(self): def testState(self):
""" Check if the last changed time is updated when state is changed """ """ Check if the last changed time is updated when state is changed """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getState(), NodeStates.UNKNOWN) self.assertEqual(node.getState(), NodeStates.UNKNOWN)
self.assertTrue(time() - 1 < node.getLastStateChange() < time()) self.assertTrue(time() - 1 < node.getLastStateChange() < time())
previous_time = node.getLastStateChange() previous_time = node.getLastStateChange()
...@@ -64,7 +63,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -64,7 +63,7 @@ class NodesTests(NeoUnitTestBase):
def testAddress(self): def testAddress(self):
""" Check if the node is indexed by address """ """ Check if the node is indexed by address """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getAddress(), None) self.assertEqual(node.getAddress(), None)
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
node.setAddress(address) node.setAddress(address)
...@@ -72,107 +71,51 @@ class NodesTests(NeoUnitTestBase): ...@@ -72,107 +71,51 @@ class NodesTests(NeoUnitTestBase):
def testUUID(self): def testUUID(self):
""" As for Address but UUID """ """ As for Address but UUID """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getAddress(), None) self.assertEqual(node.getAddress(), None)
uuid = self.getNewUUID(None) uuid = self.getNewUUID(None)
node.setUUID(uuid) node.setUUID(uuid)
self._updatedByUUID(node) self._updatedByUUID(node)
def testTypes(self):
""" Check that the abstract node has no type """
node = Node(self.manager)
self.assertRaises(NotImplementedError, node.getType)
self.assertFalse(node.isStorage())
self.assertFalse(node.isMaster())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testMaster(self):
""" Check Master sub class """
node = MasterNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.MASTER)
self.assertTrue(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testStorage(self):
""" Check Storage sub class """
node = StorageNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.STORAGE)
self.assertTrue(node.isStorage())
self.assertFalse(node.isMaster())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testClient(self):
""" Check Client sub class """
node = ClientNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.CLIENT)
self.assertTrue(node.isClient())
self.assertFalse(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isAdmin())
def testAdmin(self):
""" Check Admin sub class """
node = AdminNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.ADMIN)
self.assertTrue(node.isAdmin())
self.assertFalse(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isClient())
class NodeManagerTests(NeoUnitTestBase): class NodeManagerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.manager = NodeManager()
def _addStorage(self): def _addStorage(self):
self.storage = StorageNode(self.manager, ('127.0.0.1', 1000), self.getStorageUUID()) self.storage = self.nm.createStorage(
address=('127.0.0.1', 1000), uuid=self.getStorageUUID())
def _addMaster(self): def _addMaster(self):
self.master = MasterNode(self.manager, ('127.0.0.1', 2000), self.getMasterUUID()) self.master = self.nm.createMaster(
address=('127.0.0.1', 2000), uuid=self.getMasterUUID())
def _addClient(self): def _addClient(self):
self.client = ClientNode(self.manager, None, self.getClientUUID()) self.client = self.nm.createClient(uuid=self.getClientUUID())
def _addAdmin(self): def _addAdmin(self):
self.admin = AdminNode(self.manager, ('127.0.0.1', 4000), self.getAdminUUID()) self.admin = self.nm.createAdmin(
address=('127.0.0.1', 4000), uuid=self.getAdminUUID())
def checkNodes(self, node_list): def checkNodes(self, node_list):
manager = self.manager self.assertEqual(sorted(self.nm.getList()), sorted(node_list))
self.assertEqual(sorted(manager.getList()), sorted(node_list))
def checkMasters(self, master_list): def checkMasters(self, master_list):
manager = self.manager self.assertEqual(self.nm.getMasterList(), master_list)
self.assertEqual(manager.getMasterList(), master_list)
def checkStorages(self, storage_list): def checkStorages(self, storage_list):
manager = self.manager self.assertEqual(self.nm.getStorageList(), storage_list)
self.assertEqual(manager.getStorageList(), storage_list)
def checkClients(self, client_list): def checkClients(self, client_list):
manager = self.manager self.assertEqual(self.nm.getClientList(), client_list)
self.assertEqual(manager.getClientList(), client_list)
def checkByServer(self, node): def checkByServer(self, node):
node_found = self.manager.getByAddress(node.getAddress()) self.assertEqual(node, self.nm.getByAddress(node.getAddress()))
self.assertEqual(node_found, node)
def checkByUUID(self, node): def checkByUUID(self, node):
node_found = self.manager.getByUUID(node.getUUID()) self.assertEqual(node, self.nm.getByUUID(node.getUUID()))
self.assertEqual(node_found, node)
def checkIdentified(self, node_list, pool_set=None):
identified_node_list = self.manager.getIdentifiedList(pool_set)
self.assertEqual(set(identified_node_list), set(node_list))
def testInit(self): def testInit(self):
""" Check the manager is empty when started """ """ Check the manager is empty when started """
manager = self.manager manager = self.nm
self.checkNodes([]) self.checkNodes([])
self.checkMasters([]) self.checkMasters([])
self.checkStorages([]) self.checkStorages([])
...@@ -186,7 +129,7 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -186,7 +129,7 @@ class NodeManagerTests(NeoUnitTestBase):
def testAdd(self): def testAdd(self):
""" Check if new nodes are registered in the manager """ """ Check if new nodes are registered in the manager """
manager = self.manager manager = self.nm
self.checkNodes([]) self.checkNodes([])
# storage # storage
self._addStorage() self._addStorage()
...@@ -225,7 +168,7 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -225,7 +168,7 @@ class NodeManagerTests(NeoUnitTestBase):
def testUpdate(self): def testUpdate(self):
""" Check manager content update """ """ Check manager content update """
# set up four nodes # set up four nodes
manager = self.manager manager = self.nm
self._addMaster() self._addMaster()
self._addStorage() self._addStorage()
self._addClient() self._addClient()
...@@ -240,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -240,15 +183,15 @@ class NodeManagerTests(NeoUnitTestBase):
old_uuid = self.storage.getUUID() old_uuid = self.storage.getUUID()
new_uuid = self.getStorageUUID() new_uuid = self.getStorageUUID()
node_list = ( node_list = (
(NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN), (NodeTypes.CLIENT, None, self.client.getUUID(), NodeStates.DOWN, None),
(NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING), (NodeTypes.MASTER, new_address, self.master.getUUID(), NodeStates.RUNNING, None),
(NodeTypes.STORAGE, self.storage.getAddress(), new_uuid, (NodeTypes.STORAGE, self.storage.getAddress(), new_uuid,
NodeStates.RUNNING), NodeStates.RUNNING, None),
(NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(), (NodeTypes.ADMIN, self.admin.getAddress(), self.admin.getUUID(),
NodeStates.UNKNOWN), NodeStates.UNKNOWN, None),
) )
# update manager content # update manager content
manager.update(node_list) manager.update(Mock(), node_list)
# - the client gets down # - the client gets down
self.checkClients([]) self.checkClients([])
# - master change it's address # - master change it's address
...@@ -266,31 +209,6 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -266,31 +209,6 @@ class NodeManagerTests(NeoUnitTestBase):
self.checkNodes([self.master, self.admin, new_storage]) self.checkNodes([self.master, self.admin, new_storage])
self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN) self.assertEqual(self.admin.getState(), NodeStates.UNKNOWN)
def testIdentified(self):
# set up four nodes
manager = self.manager
self._addMaster()
self._addStorage()
self._addClient()
self._addAdmin()
# switch node to connected
self.checkIdentified([])
self.master.setConnection(Mock())
self.checkIdentified([self.master])
self.storage.setConnection(Mock())
self.checkIdentified([self.master, self.storage])
self.client.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client])
self.admin.setConnection(Mock())
self.checkIdentified([self.master, self.storage, self.client, self.admin])
# check the pool_set attribute
self.checkIdentified([self.master], pool_set=[self.master.getUUID()])
self.checkIdentified([self.storage], pool_set=[self.storage.getUUID()])
self.checkIdentified([self.client], pool_set=[self.client.getUUID()])
self.checkIdentified([self.admin], pool_set=[self.admin.getUUID()])
self.checkIdentified([self.master, self.storage], pool_set=[
self.master.getUUID(), self.storage.getUUID()])
class MasterDBTests(NeoUnitTestBase): class MasterDBTests(NeoUnitTestBase):
def _checkMasterDB(self, path, expected_master_list): def _checkMasterDB(self, path, expected_master_list):
db = list(MasterDB(path)) db = list(MasterDB(path))
...@@ -301,7 +219,7 @@ class MasterDBTests(NeoUnitTestBase): ...@@ -301,7 +219,7 @@ class MasterDBTests(NeoUnitTestBase):
def testInitialAccessRights(self): def testInitialAccessRights(self):
""" """
Verify MasterDB raises immediately on instanciation if it cannot Verify MasterDB raises immediately on instantiation if it cannot
create a non-existing database. This does not guarantee any later create a non-existing database. This does not guarantee any later
open will succeed, but makes the simple error case obvious. open will succeed, but makes the simple error case obvious.
""" """
......
...@@ -18,7 +18,6 @@ import unittest ...@@ -18,7 +18,6 @@ import unittest
from mock import Mock from mock import Mock
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import Cell, PartitionTable, PartitionTableException from neo.lib.pt import Cell, PartitionTable, PartitionTableException
from neo.lib.node import StorageNode
from . import NeoUnitTestBase from . import NeoUnitTestBase
class PartitionTableTests(NeoUnitTestBase): class PartitionTableTests(NeoUnitTestBase):
...@@ -26,7 +25,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -26,7 +25,7 @@ class PartitionTableTests(NeoUnitTestBase):
def test_01_Cell(self): def test_01_Cell(self):
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
server = ("127.0.0.1", 19001) server = ("127.0.0.1", 19001)
sn = StorageNode(Mock(), server, uuid) sn = self.createStorage(server, uuid)
cell = Cell(sn) cell = Cell(sn)
self.assertEqual(cell.node, sn) self.assertEqual(cell.node, sn)
self.assertEqual(cell.state, CellStates.UP_TO_DATE) self.assertEqual(cell.state, CellStates.UP_TO_DATE)
...@@ -50,7 +49,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -50,7 +49,7 @@ class PartitionTableTests(NeoUnitTestBase):
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
# add a cell to an empty row # add a cell to an empty row
...@@ -65,9 +64,9 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -65,9 +64,9 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertEqual(cell.getState(), CellStates.UP_TO_DATE) self.assertEqual(cell.getState(), CellStates.UP_TO_DATE)
else: else:
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
# try to add to an unexistant partition # try to add to a nonexistent partition
self.assertRaises(IndexError, pt.setCell, 10, sn1, CellStates.UP_TO_DATE) self.assertRaises(IndexError, pt.setCell, 10, sn1, CellStates.UP_TO_DATE)
# if we add in discardes state, must be removed # if we add in discards state, must be removed
pt.setCell(0, sn1, CellStates.DISCARDED) pt.setCell(0, sn1, CellStates.DISCARDED)
for x in xrange(num_partitions): for x in xrange(num_partitions):
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
...@@ -131,7 +130,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -131,7 +130,7 @@ class PartitionTableTests(NeoUnitTestBase):
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
# add a cell to an empty row # add a cell to an empty row
...@@ -168,22 +167,22 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -168,22 +167,22 @@ class PartitionTableTests(NeoUnitTestBase):
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
# add two kind of node, usable and unsable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.OUT_OF_DATE) pt.setCell(0, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.FEEDING) pt.setCell(0, sn3, CellStates.FEEDING)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added
# now checks result # now checks result
self.assertEqual(len(pt.partition_list[0]), 3) self.assertEqual(len(pt.partition_list[0]), 3)
...@@ -214,18 +213,18 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -214,18 +213,18 @@ class PartitionTableTests(NeoUnitTestBase):
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
# add two kind of node, usable and unsable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(1, sn2, CellStates.OUT_OF_DATE) pt.setCell(1, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(2, sn3, CellStates.FEEDING) pt.setCell(2, sn3, CellStates.FEEDING)
# now checks result # now checks result
self.assertEqual(len(pt.partition_list[0]), 1) self.assertEqual(len(pt.partition_list[0]), 1)
...@@ -244,22 +243,22 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -244,22 +243,22 @@ class PartitionTableTests(NeoUnitTestBase):
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
# add two kind of node, usable and unsable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.OUT_OF_DATE) pt.setCell(0, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.FEEDING) pt.setCell(0, sn3, CellStates.FEEDING)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added
# must get only two node as feeding and discarded not taken # must get only two node as feeding and discarded not taken
# into account # into account
...@@ -276,7 +275,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -276,7 +275,7 @@ class PartitionTableTests(NeoUnitTestBase):
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.UP_TO_DATE) pt.setCell(x, sn1, CellStates.UP_TO_DATE)
self.assertEqual(pt.num_filled_rows, num_partitions) self.assertEqual(pt.num_filled_rows, num_partitions)
...@@ -286,27 +285,28 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -286,27 +285,28 @@ class PartitionTableTests(NeoUnitTestBase):
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
# add two kind of node, usable and unsable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
# now test # now test
self.assertTrue(pt.hasOffset(0)) self.assertTrue(pt.hasOffset(0))
self.assertFalse(pt.hasOffset(1)) self.assertFalse(pt.hasOffset(1))
# unknonw partition # unknown partition
self.assertFalse(pt.hasOffset(50)) self.assertFalse(pt.hasOffset(50))
def test_10_operational(self): def test_10_operational(self):
def createStorage():
uuid = self.getStorageUUID()
return self.createStorage(("127.0.0.1", uuid), uuid)
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.UP_TO_DATE) pt.setCell(x, sn1, CellStates.UP_TO_DATE)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -318,9 +318,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -318,9 +318,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.FEEDING) pt.setCell(x, sn1, CellStates.FEEDING)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -333,9 +331,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -333,9 +331,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
sn1.setState(NodeStates.TEMPORARILY_DOWN) sn1.setState(NodeStates.TEMPORARILY_DOWN)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.FEEDING) pt.setCell(x, sn1, CellStates.FEEDING)
...@@ -348,9 +344,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -348,9 +344,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.OUT_OF_DATE) pt.setCell(x, sn1, CellStates.OUT_OF_DATE)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -364,18 +358,18 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -364,18 +358,18 @@ class PartitionTableTests(NeoUnitTestBase):
# add nodes # add nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
pt.setCell(1, sn1, CellStates.UP_TO_DATE) pt.setCell(1, sn1, CellStates.UP_TO_DATE)
pt.setCell(2, sn1, CellStates.UP_TO_DATE) pt.setCell(2, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.UP_TO_DATE) pt.setCell(0, sn2, CellStates.UP_TO_DATE)
pt.setCell(1, sn2, CellStates.UP_TO_DATE) pt.setCell(1, sn2, CellStates.UP_TO_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.UP_TO_DATE) pt.setCell(0, sn3, CellStates.UP_TO_DATE)
# test # test
row_0 = pt.getRow(0) row_0 = pt.getRow(0)
...@@ -397,7 +391,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -397,7 +391,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertEqual(len(row_3), 0) self.assertEqual(len(row_3), 0)
row_4 = pt.getRow(4) row_4 = pt.getRow(4)
self.assertEqual(len(row_4), 0) self.assertEqual(len(row_4), 0)
# unknwon row # unknown row
self.assertRaises(IndexError, pt.getRow, 5) self.assertRaises(IndexError, pt.getRow, 5)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,7 +22,7 @@ from neo.lib.util import ReadBuffer, parseNodeAddress ...@@ -22,7 +22,7 @@ from neo.lib.util import ReadBuffer, parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
def test_parseNodeAddress(self): def test_parseNodeAddress(self):
""" Parsing of addesses """ """ Parsing of addresses """
def test(parsed, *args): def test(parsed, *args):
self.assertEqual(parsed, parseNodeAddress(*args)) self.assertEqual(parsed, parseNodeAddress(*args))
http_port = socket.getservbyname('http') http_port = socket.getservbyname('http')
......
...@@ -27,14 +27,14 @@ from ZODB import DB, POSException ...@@ -27,14 +27,14 @@ from ZODB import DB, POSException
from ZODB.DB import TransactionalUndo from ZODB.DB import TransactionalUndo
from neo.storage.transactions import TransactionManager, \ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import MTClientConnection from neo.lib.connection import ServerConnection, MTClientConnection
from neo.lib.exception import DatabaseFailure, StoppedOperation from neo.lib.exception import DatabaseFailure, StoppedOperation
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_TID ZERO_OID, ZERO_TID
from .. import expectedFailure, Patch from .. import expectedFailure, Patch
from . import LockLock, NEOCluster, NEOThreadedTest from . import LockLock, NEOCluster, NEOThreadedTest
from neo.lib.util import add64, makeChecksum, p64, u64 from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
...@@ -970,7 +970,7 @@ class Test(NEOThreadedTest): ...@@ -970,7 +970,7 @@ class Test(NEOThreadedTest):
self.assertFalse(storage.tm._transaction_dict) self.assertFalse(storage.tm._transaction_dict)
finally: finally:
db.close() db.close()
# Check we did't get an invalidation, which would cause an # Check we didn't get an invalidation, which would cause an
# assertion failure in the cache. Connection does the same check in # assertion failure in the cache. Connection does the same check in
# _setstate_noncurrent so this could be also done by starting a # _setstate_noncurrent so this could be also done by starting a
# transaction before the last one, and clearing the cache before # transaction before the last one, and clearing the cache before
...@@ -1061,17 +1061,40 @@ class Test(NEOThreadedTest): ...@@ -1061,17 +1061,40 @@ class Test(NEOThreadedTest):
cluster.stop() cluster.stop()
def testClientFailureDuringTpcFinish(self): def testClientFailureDuringTpcFinish(self):
def delayAskLockInformation(conn, packet): """
if isinstance(packet, Packets.AskLockInformation): Third scenario:
C M S | TID known by
---- Finish -----> |
---- Disconnect -- ----- Lock ------> |
----- C down ----> |
---- Connect ----> | M
----- C up ------> |
<---- Locked ----- |
------------------------------------------------+--------------
-- unlock ... |
---- FinalTID ---> | S (TM)
---- Connect + FinalTID --------------> |
... unlock ---> |
------------------------------------------------+--------------
| S (DM)
"""
def delayAnswerLockInformation(conn, packet):
if isinstance(packet, Packets.AnswerInformationLocked):
cluster.client.master_conn.close() cluster.client.master_conn.close()
return True return True
def askFinalTID(orig, *args): def askFinalTID(orig, *args):
m2s.remove(delayAskLockInformation) s2m.remove(delayAnswerLockInformation)
orig(*args) orig(*args)
def _getFinalTID(orig, ttid): def _getFinalTID(orig, ttid):
m2s.remove(delayAskLockInformation) s2m.remove(delayAnswerLockInformation)
self.tic() self.tic()
return orig(ttid) return orig(ttid)
def _connectToPrimaryNode(orig):
conn = orig()
self.tic()
s2m.remove(delayAnswerLockInformation)
return conn
cluster = NEOCluster() cluster = NEOCluster()
try: try:
cluster.start() cluster.start()
...@@ -1079,25 +1102,30 @@ class Test(NEOThreadedTest): ...@@ -1079,25 +1102,30 @@ class Test(NEOThreadedTest):
r = c.root() r = c.root()
r['x'] = PCounter() r['x'] = PCounter()
tid0 = r._p_serial tid0 = r._p_serial
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.storage.filterConnection(cluster.master) as s2m:
m2s.add(delayAskLockInformation, s2m.add(delayAnswerLockInformation,
Patch(ClientServiceHandler, askFinalTID=askFinalTID)) Patch(ClientServiceHandler, askFinalTID=askFinalTID))
t.commit() # the final TID is returned by the master t.commit() # the final TID is returned by the master
t.begin() t.begin()
r['x'].value += 1 r['x'].value += 1
tid1 = r._p_serial tid1 = r._p_serial
self.assertTrue(tid0 < tid1) self.assertTrue(tid0 < tid1)
with cluster.master.filterConnection(cluster.storage) as m2s: with cluster.storage.filterConnection(cluster.master) as s2m:
m2s.add(delayAskLockInformation, s2m.add(delayAnswerLockInformation,
Patch(cluster.client, _getFinalTID=_getFinalTID)) Patch(cluster.client, _getFinalTID=_getFinalTID))
t.commit() # the final TID is returned by the storage backend t.commit() # the final TID is returned by the storage backend
t.begin() t.begin()
r['x'].value += 1 r['x'].value += 1
tid2 = r['x']._p_serial tid2 = r['x']._p_serial
self.assertTrue(tid1 < tid2) self.assertTrue(tid1 < tid2)
with cluster.master.filterConnection(cluster.storage) as m2s: # The whole test would be simpler if we always delayed the
m2s.add(delayAskLockInformation, # AskLockInformation packet. However, it would also delay
Patch(cluster.client, _getFinalTID=_getFinalTID)) # NotifyNodeInformation and the client would fail to connect
# to the storage node.
with cluster.storage.filterConnection(cluster.master) as s2m, \
cluster.master.filterConnection(cluster.storage) as m2s:
s2m.add(delayAnswerLockInformation, Patch(cluster.client,
_connectToPrimaryNode=_connectToPrimaryNode))
m2s.add(lambda conn, packet: m2s.add(lambda conn, packet:
isinstance(packet, Packets.NotifyUnlockInformation)) isinstance(packet, Packets.NotifyUnlockInformation))
t.commit() # the final TID is returned by the storage (tm) t.commit() # the final TID is returned by the storage (tm)
...@@ -1292,6 +1320,8 @@ class Test(NEOThreadedTest): ...@@ -1292,6 +1320,8 @@ class Test(NEOThreadedTest):
m2c, = cluster.master.getConnectionList(cluster.client) m2c, = cluster.master.getConnectionList(cluster.client)
cluster.client._cache.clear() cluster.client._cache.clear()
c.cacheMinimize() c.cacheMinimize()
# Make the master disconnects the client when the latter is about
# to send a AskObject packet to the storage node.
with cluster.client.filterConnection(cluster.storage) as c2s: with cluster.client.filterConnection(cluster.storage) as c2s:
c2s.add(disconnect) c2s.add(disconnect)
# Storages are currently notified of clients that get # Storages are currently notified of clients that get
...@@ -1299,9 +1329,75 @@ class Test(NEOThreadedTest): ...@@ -1299,9 +1329,75 @@ class Test(NEOThreadedTest):
# Should it change, the clients would have to disconnect on # Should it change, the clients would have to disconnect on
# their own. # their own.
self.assertRaises(TransientError, getattr, c, "root") self.assertRaises(TransientError, getattr, c, "root")
with Patch(ClientOperationHandler, uuid = cluster.client.uuid
askObject=lambda orig, self, conn, *args: conn.close()): # Let's use a second client to steal the node id of the first one.
self.assertRaises(NEOStorageError, getattr, c, "root") client = cluster.newClient()
try:
client.sync()
self.assertEqual(uuid, client.uuid)
# The client reconnects successfully to the master and storage,
# with a different node id. This time, we get a different error
# if it's only disconnected from the storage.
with Patch(ClientOperationHandler,
askObject=lambda orig, self, conn, *args: conn.close()):
self.assertRaises(NEOStorageError, getattr, c, "root")
self.assertNotEqual(uuid, cluster.client.uuid)
# Second reconnection, for a successful load.
c.root
finally:
client.close()
finally:
cluster.stop()
def testIdTimestamp(self):
"""
Given a master M, a storage S, and 2 clients Ca and Cb.
While Ca(id=1) is being identified by S:
1. connection between Ca and M breaks
2. M -> S: C1 down
3. Cb connect to M: id=1
4. M -> S: C1 up
5. S processes RequestIdentification from Ca with id=1
At 5, S must reject Ca, otherwise Cb can't connect to S. This is where
id timestamps come into play: with C1 up since t2, S rejects Ca due to
a request with t1 < t2.
To avoid issues with clocks that are out of sync, the client gets its
connection timestamp by being notified about itself from the master.
"""
s2c = []
def __init__(orig, self, *args, **kw):
orig(self, *args, **kw)
self.readable = bool
s2c.append(self)
ll()
def connectToStorage(client):
next(client.cp.iterateForObject(0))
cluster = NEOCluster()
try:
cluster.start()
Ca = cluster.client
Ca.pt # only connect to the master
# In a separate thread, connect to the storage but suspend the
# processing of the RequestIdentification packet, until the
# storage is notified about the existence of the other client.
with LockLock() as ll, Patch(ServerConnection, __init__=__init__):
t = self.newThread(connectToStorage, Ca)
ll()
s2c, = s2c
m2c, = cluster.master.getConnectionList(cluster.client)
m2c.close()
Cb = cluster.newClient()
try:
Cb.pt # only connect to the master
del s2c.readable
self.assertRaises(NEOPrimaryMasterLost, t.join)
self.assertTrue(s2c.isClosed())
connectToStorage(Cb)
finally:
Cb.close()
finally: finally:
cluster.stop() cluster.stop()
......
...@@ -302,7 +302,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -302,7 +302,7 @@ class ReplicationTests(NEOThreadedTest):
More generally, this checks that when a handler raises when a connection More generally, this checks that when a handler raises when a connection
is closed voluntarily, the connection is in a consistent state and can is closed voluntarily, the connection is in a consistent state and can
be, for example, closed again after the exception is catched, without be, for example, closed again after the exception is caught, without
assertion failure. assertion failure.
""" """
conn, = backup.master.getConnectionList(backup.upstream.master) conn, = backup.master.getConnectionList(backup.upstream.master)
......
...@@ -3,7 +3,7 @@ for COV in coverage python-coverage ...@@ -3,7 +3,7 @@ for COV in coverage python-coverage
do type $COV && break do type $COV && break
done >/dev/null 2>&1 || exit done >/dev/null 2>&1 || exit
$COV html $COV html "$@"
# https://bitbucket.org/ned/coveragepy/issues/474/javascript-in-html-captures-all-keys # https://bitbucket.org/ned/coveragepy/issues/474/javascript-in-html-captures-all-keys
sed -i " sed -i "
/assign_shortkeys *=/s/$/return;/ /assign_shortkeys *=/s/$/return;/
......
...@@ -94,18 +94,18 @@ class ReplicationBenchmark(BenchmarkRunner): ...@@ -94,18 +94,18 @@ class ReplicationBenchmark(BenchmarkRunner):
return self.buildReport(p_time, r_time), content return self.buildReport(p_time, r_time), content
def replicate(self, neo): def replicate(self, neo):
def number_of_oudated_cell(): def number_of_outdated_cell():
row_list = neo.neoctl.getPartitionRowList()[1] row_list = neo.neoctl.getPartitionRowList()[1]
number_of_oudated = 0 number_of_outdated = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row[1]:
if cell[1] == CellStates.OUT_OF_DATE: if cell[1] == CellStates.OUT_OF_DATE:
number_of_oudated += 1 number_of_outdated += 1
return number_of_oudated return number_of_outdated
end_time = time.time() + 3600 end_time = time.time() + 3600
while time.time() <= end_time and number_of_oudated_cell() > 0: while time.time() <= end_time and number_of_outdated_cell() > 0:
time.sleep(1) time.sleep(1)
if number_of_oudated_cell() > 0: if number_of_outdated_cell() > 0:
raise Exception('Replication takes too long') raise Exception('Replication takes too long')
def buildReport(self, p_time, r_time): def buildReport(self, p_time, r_time):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment