Commit 460cf445 authored by Yoshinori Okuji

Rewrite step one

git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@22 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent eab036aa
......@@ -2,76 +2,92 @@ import socket
import errno
import logging
from select import select
from time import time
from protocol import Packet, ProtocolError
from event import IdleEvent
class IdleEvent:
"""This class represents an event called when a connection is waiting for
a message too long."""
def __init__(self, conn, msg_id, timeout, additional_timeout):
self._conn = conn
self._id = msg_id
t = time()
self._time = t + timeout
self._critical_time = t + timeout + additional_timeout
self._additional_timeout = additional_timeout
def getId(self):
return self._id
def getTime(self):
return self._time
def getCriticalTime(self):
return self._critical_time
def __call__(self, t):
conn = self._conn
if t > self._critical_time:
logging.info('timeout with %s:%d', conn.ip_address, conn.port)
self._conn.timeoutExpired(self)
return True
elif t > self._time:
if self._additional_timeout > 10:
self._additional_timeout -= 10
conn.expectMessage(self._id, 10, self._additional_timeout)
# Start a keep-alive packet.
logging.info('sending a ping to %s:%d', conn.ip_address, conn.port)
msg_id = conn.getNextId()
conn.addPacket(Packet().ping(msg_id))
conn.expectMessage(msg_id, 10, 0)
else:
conn.expectMessage(self._id, self._additional_timeout, 0)
return True
return False
class BaseConnection(object):
"""A base connection."""
def __init__(self, event_manager, handler, s = None, addr = None):
self.em = event_manager
self.s = s
self.addr = addr
self.handler = handler
if s is not None:
event_manager.register(self)
def getSocket(self):
return self.s
class Connection:
"""A connection."""
def setSocket(self, s):
if self.s is not None:
raise RuntimeError, 'cannot overwrite a socket in a connection'
if s is not None:
self.s = s
self.em.register(self)
connecting = False
from_self = False
aborted = False
def getAddress(self):
return self.addr
def __init__(self, connection_manager, s = None, addr = None):
self.s = s
def readable(self):
raise NotImplementedError
def writable(self):
raise NotImplementedError
def getHandler(self):
return self.handler
def setHandler(self):
self.handler = handler
def getEventManager(self):
return self.em
class ListeningConnection(BaseConnection):
    """A listening connection.

    Creates a non-blocking TCP socket bound to ``addr``, registers it with
    the event manager for reading, and hands every accepted socket to the
    handler through ``connectionAccepted``.
    """

    def __init__(self, event_manager, handler, addr = None, **kw):
        """Bind and listen on ``addr``, an ``(ip_address, port)`` pair.

        Extra keyword arguments are accepted and ignored. Any failure while
        configuring the socket closes it before the exception propagates.
        """
        logging.info('listening to %s:%d', *addr)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.setblocking(0)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(addr)
            s.listen(5)
        except:
            # Cleanup only: never leak the descriptor, always re-raise.
            s.close()
            raise
        BaseConnection.__init__(self, event_manager, handler, s = s, addr = addr)
        self.em.addReader(self)

    def readable(self):
        """Accept one pending connection and pass it to the handler.

        Returns silently when nothing is pending on the non-blocking socket;
        any other socket error propagates to the caller.
        """
        # Only accept() is guarded, so that an EAGAIN raised by the handler
        # callback is not mistaken for "nothing to accept" and swallowed.
        try:
            new_s, addr = self.s.accept()
        except socket.error as m:
            if m.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
                return
            raise
        logging.info('accepted a connection from %s:%d', *addr)
        self.handler.connectionAccepted(self, new_s, addr)
class Connection(BaseConnection):
"""A connection."""
def __init__(self, event_manager, handler, s = None, addr = None):
BaseConnection.__init__(self, handler, event_manager, s = s, addr = addr)
if s is not None:
connection_manager.addReader(s)
self.cm = connection_manager
event_manager.addReader(self)
self.read_buf = []
self.write_buf = []
self.cur_id = 0
self.event_dict = {}
if addr is None:
self.ip_address = None
self.port = None
else:
self.ip_address, self.port = addr
self.aborted = False
self.uuid = None
def getSocket(self):
return self.s
def getUUID(self):
return self.uuid
def setUUID(self, uuid):
self.uuid = uuid
def getNextId(self):
next_id = self.cur_id
......@@ -80,46 +96,15 @@ class Connection:
self.cur_id = 0
return next_id
def connect(self, ip_address, port):
"""Connect to another node."""
if self.s is not None:
raise RuntimeError, 'already connected'
self.ip_address = ip_address
self.port = port
self.from_self = True
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.setblocking(0)
s.connect((ip_address, port))
except socket.error, m:
if m[0] == errno.EINPROGRESS:
self.connecting = True
self.cm.addWriter(s)
else:
s.close()
raise
else:
self.connectionCompleted()
self.cm.addReader(s)
except socket.error:
self.connectionFailed()
return
self.s = s
return s
def close(self):
"""Close the connection."""
s = self.s
em = self.em
if s is not None:
logging.debug('closing a socket for %s:%d', self.ip_address, self.port)
self.cm.removeReader(s)
self.cm.removeWriter(s)
self.cm.unregister(self)
logging.debug('closing a socket for %s:%d', *(self.addr))
em.removeReader(self)
em.removeWriter(self)
em.unregister(self)
try:
# This may fail if the socket is not connected.
s.shutdown(socket.SHUT_RDWR)
......@@ -128,34 +113,22 @@ class Connection:
s.close()
self.s = None
for event in self.event_dict.itervalues():
self.cm.removeIdleEvent(event)
em.removeIdleEvent(event)
self.event_dict.clear()
def abort(self):
"""Abort dealing with this connection."""
logging.debug('aborting a socket for %s:%d', self.ip_address, self.port)
logging.debug('aborting a socket for %s:%d', *(self.addr))
self.aborted = True
def writable(self):
"""Called when self is writable."""
if self.connecting:
err = self.s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err:
self.connectionFailed()
self.close()
return
else:
self.connecting = False
self.connectionCompleted()
self.cm.addReader(self.s)
else:
self.send()
if not self.pending():
if self.aborted:
self.close()
else:
self.cm.removeWriter(self.s)
self.em.removeWriter(self)
def readable(self):
"""Called when self is readable."""
......@@ -163,7 +136,7 @@ class Connection:
self.analyse()
if self.aborted:
self.cm.removeReader(self.s)
self.em.removeReader(self)
def analyse(self):
"""Analyse received data."""
......@@ -177,7 +150,7 @@ class Connection:
try:
packet = Packet.parse(msg)
except ProtocolError, m:
self.packetMalformed(*m)
self.handler.packetMalformed(self, *m)
return
if packet is None:
......@@ -188,11 +161,11 @@ class Connection:
try:
event = self.event_dict[msg_id]
del self.event_dict[msg_id]
self.cm.removeIdleEvent(event)
self.em.removeIdleEvent(event)
except KeyError:
pass
self.packetReceived(packet)
self.handler.packetReceived(self, packet)
msg = msg[len(packet):]
if msg:
......@@ -210,19 +183,17 @@ class Connection:
r = s.recv(4096)
if not r:
logging.error('cannot read')
self.connectionClosed()
self.handler.connectionClosed(self)
self.close()
else:
self.read_buf.append(r)
except socket.error, m:
if m[0] == errno.EAGAIN:
pass
elif m[0] == errno.ECONNRESET:
logging.error('cannot read')
self.connectionClosed()
self.close()
else:
raise
logging.error('%s', m[1])
self.handler.connectionClosed(self)
self.close()
def send(self):
"""Send data to a socket."""
......@@ -236,7 +207,7 @@ class Connection:
r = s.send(msg)
if not r:
logging.error('cannot write')
self.connectionClosed()
self.handler.connectionClosed(self)
self.close()
elif r == len(msg):
del self.write_buf[:]
......@@ -245,21 +216,24 @@ class Connection:
except socket.error, m:
if m[0] == errno.EAGAIN:
return
raise
else:
logging.error('%s', m[1])
self.handler.connectionClosed(self)
self.close()
def addPacket(self, packet):
"""Add a packet into the write buffer."""
try:
self.write_buf.append(str(packet))
self.write_buf.append(packet.encode())
except ProtocolError, m:
logging.critical('trying to send a too big message')
return self.addPacket(Packet().internalError(packet.getId(), m[1]))
return self.addPacket(packet.internalError(packet.getId(), m[1]))
# If this is the first time, enable polling for writing.
if len(self.write_buf) == 1:
self.cm.addWriter(self.s)
self.em.addWriter(self.s)
def expectMessage(self, msg_id = None, timeout = 10, additional_timeout = 100):
def expectMessage(self, msg_id = None, timeout = 5, additional_timeout = 30):
"""Expect a message for a reply to a given message ID or any message.
The purpose of this method is to define how much amount of time is
......@@ -281,139 +255,49 @@ class Connection:
the callback is executed immediately."""
event = IdleEvent(self, msg_id, timeout, additional_timeout)
self.event_dict[msg_id] = event
self.cm.addIdleEvent(event)
# Hooks.
def connectionFailed(self):
"""Called when a connection fails."""
pass
self.em.addIdleEvent(event)
def connectionCompleted(self):
"""Called when a connection is completed."""
pass
def connectionAccepted(self):
"""Called when a connection is accepted."""
# A request for a node identification should arrive.
self.expectMessage(timeout = 10, additional_timeout = 0)
def connectionClosed(self):
"""Called when a connection is closed."""
pass
def timeoutExpired(self):
"""Called when a timeout event occurs."""
self.close()
def peerBroken(self):
"""Called when a peer is broken."""
pass
def packetReceived(self, packet):
"""Called when a packet is received."""
pass
def packetMalformed(self, packet, error_message):
"""Called when a packet is malformed."""
logging.info('malformed packet: %s', error_message)
self.addPacket(Packet().protocolError(packet.getId(), error_message))
self.abort()
self.peerBroken()
class ConnectionManager:
"""This class manages connections and sockets."""
def __init__(self, app = None, connection_klass = Connection):
self.listening_socket = None
self.connection_dict = {}
self.reader_set = set([])
self.writer_set = set([])
self.exc_list = []
self.app = app
self.klass = connection_klass
self.event_list = []
self.prev_time = time()
def listen(self, ip_address, port):
logging.info('listening to %s:%d', ip_address, port)
class ClientConnection(Connection):
"""A connection from this node to a remote node."""
def __init__(self, event_manager, handler, addr = None, **kw):
Connection.__init__(self, event_manager, handler, addr = addr)
self.connecting = False
handler.connectionStarted(self)
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setblocking(0)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((ip_address, port))
s.listen(5)
self.listening_socket = s
self.reader_set.add(s)
def getConnectionList(self):
return self.connection_dict.values()
def register(self, conn):
self.connection_dict[conn.getSocket()] = conn
def unregister(self, conn):
del self.connection_dict[conn.getSocket()]
self.setSocket(s)
def connect(self, ip_address, port):
logging.info('connecting to %s:%d', ip_address, port)
conn = self.klass(self)
if conn.connect(ip_address, port) is not None:
self.register(conn)
def poll(self, timeout = 1):
rlist, wlist, xlist = select(self.reader_set, self.writer_set, self.exc_list,
timeout)
for s in rlist:
if s == self.listening_socket:
try:
new_s, addr = s.accept()
logging.info('accepted a connection from %s:%d', addr[0], addr[1])
conn = self.klass(self, new_s, addr)
self.register(conn)
conn.connectionAccepted()
s.setblocking(0)
s.connect(addr)
except socket.error, m:
if m[0] == errno.EAGAIN:
continue
raise
if m[0] == errno.EINPROGRESS:
self.connecting = True
event_manager.addWriter(self)
else:
conn = self.connection_dict[s]
conn.readable()
for s in wlist:
conn = self.connection_dict[s]
conn.writable()
# Check idle events. Do not check them out too often, because this
# is somehow heavy.
event_list = self.event_list
if event_list:
t = time()
if t - self.prev_time >= 1:
self.prev_time = t
event_list.sort(key = lambda event: event.getTime())
for event in tuple(event_list):
if event(t):
event_list.pop(0)
raise
else:
break
self.handler.connectionCompleted()
event_manager.addReader(self)
except:
handler.connectionFailed(self)
self.close()
def addIdleEvent(self, event):
self.event_list.append(event)
def writable(self):
"""Called when self is writable."""
if self.connecting:
err = self.s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err:
self.connectionFailed()
self.close()
return
else:
self.connecting = False
self.handler.connectionCompleted(self)
self.cm.addReader(self.s)
else:
Connection.writable(self)
def removeIdleEvent(self, event):
try:
self.event_list.remove(event)
except ValueError:
class ServerConnection(Connection):
"""A connection from a remote node to this node."""
pass
def addReader(self, s):
self.reader_set.add(s)
def removeReader(self, s):
self.reader_set.discard(s)
def addWriter(self, s):
self.writer_set.add(s)
def removeWriter(self, s):
self.writer_set.discard(s)
import logging
from select import select
from time import time
class IdleEvent(object):
    """An event fired when a connection has waited too long for a message.

    Two deadlines are computed at construction time: a soft one
    (``timeout`` seconds from now) and a critical one (``timeout +
    additional_timeout`` seconds from now).
    """

    def __init__(self, conn, msg_id, timeout, additional_timeout):
        """Arm the event for ``conn`` and the message ID ``msg_id``."""
        self._conn = conn
        self._id = msg_id
        t = time()
        self._time = t + timeout
        self._critical_time = t + timeout + additional_timeout
        self._additional_timeout = additional_timeout

    def getId(self):
        """Return the message ID this event is waiting for."""
        return self._id

    def getTime(self):
        """Return the soft deadline (seconds since the epoch)."""
        return self._time

    def getCriticalTime(self):
        """Return the critical deadline (seconds since the epoch)."""
        return self._critical_time

    def __call__(self, t):
        """Check the event against the current time ``t``.

        Returns True when the event has been consumed and must be dropped by
        the caller, False when no deadline has passed yet.
        """
        conn = self._conn
        if t > self._critical_time:
            # Critical deadline passed: notify the handler and give up.
            logging.info('timeout with %s:%d', *(conn.getAddress()))
            conn.getHandler().timeoutExpired(conn)
            conn.close()
            return True
        elif t > self._time:
            if self._additional_timeout > 5:
                # Imported here because the module-level imports of this
                # file do not provide Packet; without it the ping below
                # raises NameError.
                from protocol import Packet
                # Re-arm the wait with a shorter remaining grace period.
                self._additional_timeout -= 5
                conn.expectMessage(self._id, 5, self._additional_timeout)
                # Start a keep-alive packet.
                logging.info('sending a ping to %s:%d', *(conn.getAddress()))
                msg_id = conn.getNextId()
                conn.addPacket(Packet().ping(msg_id))
                conn.expectMessage(msg_id, 5, 0)
            else:
                # Too little grace time left for a ping; just wait it out.
                conn.expectMessage(self._id, self._additional_timeout, 0)
            return True
        return False
class EventManager(object):
    """This class manages connections and events.

    It keeps the registered connections keyed by their socket, the sockets
    to poll for reading and for writing, and the list of pending idle
    events.
    """

    def __init__(self):
        self.connection_dict = {}
        self.reader_set = set([])
        self.writer_set = set([])
        self.exc_list = []
        self.event_list = []
        self.prev_time = time()

    def getConnectionList(self):
        """Return the registered connections."""
        return self.connection_dict.values()

    def register(self, conn):
        """Register ``conn`` under its socket."""
        self.connection_dict[conn.getSocket()] = conn

    def unregister(self, conn):
        """Forget ``conn``; raises KeyError if it was not registered."""
        del self.connection_dict[conn.getSocket()]

    def poll(self, timeout = 1):
        """Wait up to ``timeout`` seconds for I/O, then run due idle events.

        Readable and writable connections get their ``readable()`` /
        ``writable()`` callbacks. Idle events are examined at most once per
        second, in order of their soft deadline.
        """
        rlist, wlist, xlist = select(self.reader_set, self.writer_set,
                                     self.exc_list, timeout)
        for s in rlist:
            # A callback run earlier in this very poll may have closed (and
            # thus unregistered) this connection; skip it instead of
            # failing with KeyError.
            conn = self.connection_dict.get(s)
            if conn is not None:
                conn.readable()

        for s in wlist:
            conn = self.connection_dict.get(s)
            if conn is not None:
                conn.writable()

        # Check idle events. Do not check them too often, because sorting
        # and scanning the list is relatively expensive.
        event_list = self.event_list
        if event_list:
            t = time()
            if t - self.prev_time >= 1:
                self.prev_time = t
                event_list.sort(key = lambda event: event.getTime())
                # Iterate over a snapshot: callbacks add and remove events.
                for event in tuple(event_list):
                    if event(t):
                        # Remove this very event, tolerating the case where
                        # the callback already removed it (e.g. by closing
                        # its connection). Popping index 0 would drop an
                        # unrelated event or fail on an empty list.
                        self.removeIdleEvent(event)
                    else:
                        # Sorted by deadline: nothing later is due yet.
                        break

    def addIdleEvent(self, event):
        """Schedule an idle event."""
        self.event_list.append(event)

    def removeIdleEvent(self, event):
        """Unschedule an idle event; a missing event is ignored."""
        try:
            self.event_list.remove(event)
        except ValueError:
            pass

    def addReader(self, conn):
        """Start polling the socket of ``conn`` for reading."""
        self.reader_set.add(conn.getSocket())

    def removeReader(self, conn):
        """Stop polling the socket of ``conn`` for reading."""
        self.reader_set.discard(conn.getSocket())

    def addWriter(self, conn):
        """Start polling the socket of ``conn`` for writing."""
        self.writer_set.add(conn.getSocket())

    def removeWriter(self, conn):
        """Stop polling the socket of ``conn`` for writing."""
        self.writer_set.discard(conn.getSocket())
import logging
from protocol import Packet, ProtocolError
from connection import ServerConnection
from protocol import ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
PING, PONG, ASK_PRIMARY_MASTER, ANSWER_PRIMARY_MASTER, ANNOUNCE_PRIMARY_MASTER, \
REELECT_PRIMARY_MASTER, NOTIFY_NODE_INFORMATION, START_OPERATION, \
STOP_OPERATION, ASK_FINISHING_TRANSACTIONS, ANSWER_FINISHING_TRANSACTIONS, \
FINISH_TRANSACTIONS, \
NOT_READY_CODE, OID_NOT_FOUND_CODE, SERIAL_NOT_FOUND_CODE, TID_NOT_FOUND_CODE, \
PROTOCOL_ERROR_CODE, TIMEOUT_ERROR_CODE, BROKEN_NODE_DISALLOWED_CODE, \
INTERNAL_ERROR_CODE
class EventHandler(object):
    """This class handles events.

    Connection-level notifications and received packets are delivered to
    the methods below. The packet and error handlers defined here mostly
    reject the packet as unexpected; subclasses override the ones they
    support.
    """

    def __init__(self):
        self.initPacketDispatchTable()
        self.initErrorDispatchTable()

    def connectionStarted(self, conn):
        """Called when a connection is started."""
        pass

    def connectionCompleted(self, conn):
        """Called when a connection is completed."""
        pass

    def connectionFailed(self, conn):
        """Called when a connection failed."""
        pass

    def connectionAccepted(self, conn, s, addr):
        """Called when a connection is accepted.

        Wraps the accepted socket ``s`` in a ServerConnection sharing the
        event manager and handler of the listening connection ``conn``.
        """
        new_conn = ServerConnection(conn.getEventManager(), conn.getHandler(),
                                    s = s, addr = addr)
        # A request for a node identification should arrive.
        new_conn.expectMessage(timeout = 10, additional_timeout = 0)

    def timeoutExpired(self, conn):
        """Called when a timeout event occurs."""
        pass

    def connectionClosed(self, conn):
        """Called when a connection is closed by the peer."""
        pass

    def packetReceived(self, conn, packet):
        """Called when a packet is received."""
        self.dispatch(conn, packet)

    def packetMalformed(self, conn, packet, error_message):
        """Called when a packet is malformed.

        Queues a protocol error reply, aborts the connection and reports
        the peer as broken.
        """
        logging.info('malformed packet: %s', error_message)
        conn.addPacket(Packet().protocolError(packet.getId(), error_message))
        conn.abort()
        self.peerBroken(conn)

    def peerBroken(self, conn):
        """Called when a peer is broken."""
        logging.error('%s:%d is broken', *(conn.getAddress()))

    def dispatch(self, conn, packet):
        """This is a helper method to handle various packet types.

        Looks up the handler for the packet type, decodes the packet and
        calls the handler with the decoded fields. Unknown types are
        treated as unexpected packets; a ProtocolError is treated as a
        malformed packet.
        """
        t = packet.getType()
        # The table is a dict, so an unknown type raises KeyError (not
        # ValueError). The lookup is kept outside the second try so that a
        # KeyError raised inside a handler is not misreported.
        try:
            method = self.packet_dispatch_table[t]
        except KeyError:
            self.handleUnexpectedPacket(conn, packet)
            return
        try:
            args = packet.decode()
            method(conn, packet, *args)
        except ValueError:
            self.handleUnexpectedPacket(conn, packet)
        except ProtocolError as m:
            self.packetMalformed(conn, packet, m.args[1])

    def handleUnexpectedPacket(self, conn, packet, message = None):
        """Handle an unexpected packet.

        Queues a protocol error reply, aborts the connection and reports
        the peer as broken.
        """
        if message is None:
            message = 'unexpected packet type %d' % packet.getType()
        else:
            message = 'unexpected packet: ' + message
        logging.info('%s', message)
        conn.addPacket(Packet().protocolError(packet.getId(), message))
        conn.abort()
        self.peerBroken(conn)

    # Packet handlers.

    def handleError(self, conn, packet, code, message):
        """Dispatch an error packet to the handler of its error code."""
        # Same reasoning as in dispatch(): a missing code is a KeyError.
        try:
            method = self.error_dispatch_table[code]
        except KeyError:
            self.handleUnexpectedPacket(conn, packet, message)
            return
        try:
            method(conn, packet, message)
        except ValueError:
            self.handleUnexpectedPacket(conn, packet, message)

    def handleRequestNodeIdentification(self, conn, packet, node_type,
                                        uuid, ip_address, port, name):
        self.handleUnexpectedPacket(conn, packet)

    def handleAcceptNodeIdentification(self, conn, packet, node_type,
                                       uuid, ip_address, port):
        self.handleUnexpectedPacket(conn, packet)

    def handlePing(self, conn, packet):
        """Answer a ping with a pong carrying the same message ID."""
        logging.info('got a ping packet; am I overloaded?')
        conn.addPacket(Packet().pong(packet.getId()))

    def handlePong(self, conn, packet):
        pass

    # These two are named after the ASK_PRIMARY_MASTER and
    # ANSWER_PRIMARY_MASTER packet types, which is how the dispatch table
    # refers to them.
    def handleAskPrimaryMaster(self, conn, packet):
        self.handleUnexpectedPacket(conn, packet)

    def handleAnswerPrimaryMaster(self, conn, packet, primary_uuid, known_master_list):
        self.handleUnexpectedPacket(conn, packet)

    # Former names, kept so that existing callers keep working.
    handleAskPrimaryNode = handleAskPrimaryMaster
    handleAnswerPrimaryNode = handleAnswerPrimaryMaster

    def handleAnnouncePrimaryMaster(self, conn, packet):
        self.handleUnexpectedPacket(conn, packet)

    def handleReelectPrimaryMaster(self, conn, packet):
        self.handleUnexpectedPacket(conn, packet)

    def handleNotifyNodeInformation(self, conn, packet, node_list):
        self.handleUnexpectedPacket(conn, packet)

    # Error packet handlers.

    handleNotReady = handleUnexpectedPacket
    handleOidNotFound = handleUnexpectedPacket
    handleSerialNotFound = handleUnexpectedPacket
    handleTidNotFound = handleUnexpectedPacket

    def handleProtocolError(self, conn, packet, message):
        raise RuntimeError('protocol error: %s' % (message,))

    def handleTimeoutError(self, conn, packet, message):
        raise RuntimeError('timeout error: %s' % (message,))

    def handleBrokenNodeDisallowedError(self, conn, packet, message):
        raise RuntimeError('broken node disallowed error: %s' % (message,))

    def handleInternalError(self, conn, packet, message):
        self.peerBroken(conn)
        conn.close()

    def initPacketDispatchTable(self):
        """Build the packet type -> handler method table."""
        d = {}

        d[ERROR] = self.handleError
        d[REQUEST_NODE_IDENTIFICATION] = self.handleRequestNodeIdentification
        d[ACCEPT_NODE_IDENTIFICATION] = self.handleAcceptNodeIdentification
        d[PING] = self.handlePing
        d[PONG] = self.handlePong
        d[ASK_PRIMARY_MASTER] = self.handleAskPrimaryMaster
        d[ANSWER_PRIMARY_MASTER] = self.handleAnswerPrimaryMaster
        d[ANNOUNCE_PRIMARY_MASTER] = self.handleAnnouncePrimaryMaster
        d[REELECT_PRIMARY_MASTER] = self.handleReelectPrimaryMaster
        d[NOTIFY_NODE_INFORMATION] = self.handleNotifyNodeInformation

        self.packet_dispatch_table = d

    def initErrorDispatchTable(self):
        """Build the error code -> handler method table."""
        d = {}

        d[NOT_READY_CODE] = self.handleNotReady
        d[OID_NOT_FOUND_CODE] = self.handleOidNotFound
        d[SERIAL_NOT_FOUND_CODE] = self.handleSerialNotFound
        d[TID_NOT_FOUND_CODE] = self.handleTidNotFound
        d[PROTOCOL_ERROR_CODE] = self.handleProtocolError
        d[TIMEOUT_ERROR_CODE] = self.handleTimeoutError
        d[BROKEN_NODE_DISALLOWED_CODE] = self.handleBrokenNodeDisallowedError
        d[INTERNAL_ERROR_CODE] = self.handleInternalError

        self.error_dispatch_table = d
......@@ -5,66 +5,18 @@ from socket import inet_aton
from time import time
from connection import ConnectionManager
from connection import Connection as BaseConnection
from database import DatabaseManager
from config import ConfigurationManager
from protocol import Packet, ProtocolError, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
INVALID_UUID, INVALID_TID, INVALID_OID, \
PROTOCOL_ERROR_CODE, TIMEOUT_ERROR_CODE, BROKEN_NODE_DISALLOWED_CODE, \
INTERNAL_ERROR_CODE, \
ERROR, REQUEST_NODE_IDENTIFICATION, ACCEPT_NODE_IDENTIFICATION, \
PING, PONG, ASK_PRIMARY_MASTER, ANSWER_PRIMARY_MASTER, \
ANNOUNCE_PRIMARY_MASTER, REELECT_PRIMARY_MASTER
from node import NodeManager, MasterNode, StorageNode, ClientNode, \
RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE
from node import NodeManager, MasterNode, StorageNode, ClientNode
from handler import EventHandler
from event import EventManager
from util import dump
class NeoException(Exception): pass
class ElectionFailure(NeoException): pass
class PrimaryFailure(NeoException): pass
class RecoveryFailure(NeoException): pass
class Connection(BaseConnection):
"""This class provides a master-specific connection."""
_uuid = None
def setUUID(self, uuid):
self._uuid = uuid
def getUUID(self):
return self._uuid
# Feed callbacks to the master node.
def connectionFailed(self):
self.cm.app.connectionFailed(self)
BaseConnection.connectionFailed(self)
def connectionCompleted(self):
self.cm.app.connectionCompleted(self)
BaseConnection.connectionCompleted(self)
def connectionAccepted(self):
self.cm.app.connectionAccepted(self)
BaseConnection.connectionAccepted(self)
def connectionClosed(self):
self.cm.app.connectionClosed(self)
BaseConnection.connectionClosed(self)
def packetReceived(self, packet):
self.cm.app.packetReceived(self, packet)
BaseConnection.packetReceived(self, packet)
def timeoutExpired(self):
self.cm.app.timeoutExpired(self)
BaseConnection.timeoutExpired(self)
def peerBroken(self):
self.cm.app.peerBroken(self)
BaseConnection.peerBroken(self)
class Application(object):
"""The master node application."""
......@@ -72,137 +24,51 @@ class Application(object):
def __init__(self, file, section):
config = ConfigurationManager(file, section)
self.database = config.getDatabase()
self.user = config.getUser()
self.password = config.getPassword()
logging.debug('database is %s, user is %s, password is %s',
self.database, self.user, self.password)
self.num_replicas = config.getReplicas()
self.num_partitions = config.getPartitions()
self.name = config.getName()
logging.debug('the number of replicas is %d, the number of partitions is %d, the name is %s',
self.num_replicas, self.num_partitions, self.name)
self.ip_address, self.port = config.getServer()
logging.debug('IP address is %s, port is %d', self.ip_address, self.port)
self.server = config.getServer()
logging.debug('IP address is %s, port is %d', *(self.server))
# Exclude itself from the list.
self.master_node_list = [n for n in config.getMasterNodeList()
if n != (self.ip_address, self.port)]
self.master_node_list = [n for n in config.getMasterNodeList() if n != self.server]
logging.debug('master nodes are %s', self.master_node_list)
# Internal attributes.
self.dm = DatabaseManager(self.database, self.user, self.password)
self.cm = ConnectionManager(app = self, connection_klass = Connection)
self.em = EventManager()
self.nm = NodeManager()
self.primary = None
self.primary_master_node = None
self.ready = False
# Co-operative threads. Simulated by generators.
self.thread_dict = {}
self.server_thread_method = None
self.event = None
def initializeDatabase(self):
"""Initialize a database by recreating all the tables.
In master nodes, the database is used only to make
some data persistent. All operations are executed on memory.
Thus it is not necessary to make indices on the tables."""
q = self.dm.query
e = MySQLdb.escape_string
q("""DROP TABLE IF EXISTS loid, ltid, self, stn, part""")
q("""CREATE TABLE loid (
oid BINARY(8) NOT NULL
) ENGINE = InnoDB COMMENT = 'Last Object ID'""")
q("""INSERT loid VALUES ('%s')""" % e(INVALID_OID))
q("""CREATE TABLE ltid (
tid BINARY(8) NOT NULL
) ENGINE = InnoDB COMMENT = 'Last Transaction ID'""")
q("""INSERT ltid VALUES ('%s')""" % e(INVALID_TID))
q("""CREATE TABLE self (
uuid BINARY(16) NOT NULL
) ENGINE = InnoDB COMMENT = 'UUID'""")
# XXX Generate an UUID for self. For now, just use a random string.
# Avoid an invalid UUID.
while 1:
uuid = os.urandom(16)
if uuid != INVALID_UUID:
break
self.uuid = uuid
q("""INSERT self VALUES ('%s')""" % e(uuid))
q("""CREATE TABLE stn (
nid INT UNSIGNED NOT NULL UNIQUE,
uuid BINARY(16) NOT NULL UNIQUE,
state CHAR(1) NOT NULL
) ENGINE = InnoDB COMMENT = 'Storage Nodes'""")
q("""CREATE TABLE part (
pid INT UNSIGNED NOT NULL,
nid INT UNSIGNED NOT NULL,
state CHAR(1) NOT NULL
) ENGINE = InnoDB COMMENT = 'Partition Table'""")
def loadData(self):
"""Load persistent data from a database."""
logging.info('loading data from MySQL')
q = self.dm.query
result = q("""SELECT oid FROM loid""")
if len(result) != 1:
raise RuntimeError, 'the table loid has %d rows' % len(result)
self.loid = result[0][0]
logging.info('the last OID is %r' % dump(self.loid))
result = q("""SELECT tid FROM ltid""")
if len(result) != 1:
raise RuntimeError, 'the table ltid has %d rows' % len(result)
self.ltid = result[0][0]
logging.info('the last TID is %r' % dump(self.ltid))
result = q("""SELECT uuid FROM self""")
if len(result) != 1:
raise RuntimeError, 'the table self has %d rows' % len(result)
self.uuid = result[0][0]
logging.info('the UUID is %r' % dump(self.uuid))
# FIXME load storage and partition information here.
self.loid = INVALID_OID
self.ltid = INVALID_TID
def run(self):
"""Make sure that the status is sane and start a loop."""
# Sanity checks.
logging.info('checking the database')
result = self.dm.query("""SHOW TABLES""")
table_set = set([r[0] for r in result])
existing_table_list = [t for t in ('loid', 'ltid', 'self', 'stn', 'part')
if t in table_set]
if len(existing_table_list) == 0:
# Possibly this is the first time to launch...
self.initializeDatabase()
elif len(existing_table_list) != 5:
raise RuntimeError, 'database inconsistent'
# XXX More tests are necessary (e.g. check the table structures,
# check the number of partitions, etc.).
# Now ready to load persistent data from the database.
self.loadData()
for ip_address, port in self.master_node_list:
self.nm.add(MasterNode(ip_address = ip_address, port = port))
if self.num_replicas <= 0:
raise RuntimeError, 'replicas must be more than zero'
if self.num_partitions <= 0:
raise RuntimeError, 'partitions must be more than zero'
if len(self.name) == 0:
raise RuntimeError, 'cluster name must be non-empty'
for server in self.master_node_list:
self.nm.add(MasterNode(server = server))
# Make a listening port.
self.cm.listen(self.ip_address, self.port)
ListeningConnection(self.em, None, addr = self.server)
# Start the election of a primary master node.
self.electPrimary()
......@@ -212,95 +78,18 @@ class Application(object):
try:
if self.primary:
while 1:
try:
self.startRecovery()
except RecoveryFailure:
logging.critical('unable to recover the system; use full recovery')
raise
self.playPrimaryRole()
else:
self.playSecondaryRole()
raise RuntimeError, 'should not reach here'
except (ElectionFailure, PrimaryFailure):
# Forget all connections.
for conn in cm.getConnectionList():
for conn in self.em.getConnectionList():
conn.close()
self.thread_dict.clear()
# Reelect a new primary master.
self.electPrimary(bootstrap = False)
CONNECTION_FAILED = 'connection failed'
def connectionFailed(self, conn):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.CONNECTION_FAILED, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
CONNECTION_COMPLETED = 'connection completed'
def connectionCompleted(self, conn):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.CONNECTION_COMPLETED, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
CONNECTION_CLOSED = 'connection closed'
def connectionClosed(self, conn):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.CONNECTION_CLOSED, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
CONNECTION_ACCEPTED = 'connection accepted'
def connectionAccepted(self, conn):
addr = (conn.ip_address, conn.port)
logging.debug('making a server thread for %s:%d', conn.ip_address, conn.port)
t = self.server_thread_method()
self.thread_dict[addr] = t
self.event = (self.CONNECTION_ACCEPTED, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
TIMEOUT_EXPIRED = 'timeout expired'
def timeoutExpired(self, conn):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.TIMEOUT_EXPIRED, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
PEER_BROKEN = 'peer broken'
def peerBroken(self, conn):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.PEER_BROKEN, conn)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
PACKET_RECEIVED = 'packet received'
def packetReceived(self, conn, packet):
addr = (conn.ip_address, conn.port)
t = self.thread_dict[addr]
self.event = (self.PACKET_RECEIVED, conn, packet)
try:
t.next()
except StopIteration:
del self.thread_dict[addr]
def electPrimaryClientIterator(self):
"""Handle events for a client connection."""
# The first event. This must be a connection failure or connection completion.
......
# Default parameters.
[DEFAULT]
# The list of master nodes.
master_nodes: 127.0.0.1:10010 127.0.0.1:10011 127.0.0.1:10012
#replicas: 1
#partitions: 1009
#name: main
# The number of replicas.
replicas: 1
# The number of partitions.
partitions: 1009
# The name of this cluster.
name: main
# The user name for the database.
user: neo
# The password for the database.
password: neo
# The first master.
[master1]
database: master1
user: neo
#password:
server: 127.0.0.1:10010
# The second master.
[master2]
database: master2
user: neo
server: 127.0.0.1:10011
# The third master.
[master3]
database: master3
user: neo
server: 127.0.0.1:10012
# The first storage.
[storage1]
database: neo1
server: 127.0.0.1:10020
......@@ -19,15 +19,13 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from optparse import OptionParser
from master import Application
from master.app import Application
import logging
# FIXME should be configurable
logging.basicConfig(level = logging.DEBUG)
parser = OptionParser()
parser.add_option('-i', '--initialize', action = 'store_true',
help = 'initialize the database')
parser.add_option('-v', '--verbose', action = 'store_true',
help = 'print verbose messages')
parser.add_option('-c', '--config', help = 'specify a configuration file')
parser.add_option('-s', '--section', help = 'specify a configuration section')
......@@ -36,9 +34,10 @@ parser.add_option('-s', '--section', help = 'specify a configuration section')
config = options.config or 'neo.conf'
section = options.section or 'master'
app = Application(config, section)
if options.initialize:
app.initializeDatabase()
if options.verbose:
logging.basicConfig(level = logging.DEBUG)
else:
logging.basicConfig(level = logging.ERROR)
app = Application(config, section)
app.run()
from time import time
from protocol import RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE
from protocol import RUNNING_STATE, TEMPORARILY_DOWN_STATE, DOWN_STATE, BROKEN_STATE, \
MASTER_NODE_TYPE, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE
class Node(object):
"""This class represents a node."""
def __init__(self, ip_address = None, port = None, uuid = None):
def __init__(self, server = None, uuid = None):
self.state = RUNNING_STATE
self.ip_address = ip_address
self.port = port
self.server = server
self.uuid = uuid
self.manager = None
self.last_state_change = time()
......@@ -27,16 +27,15 @@ class Node(object):
self.state = new_state
self.last_state_change = time()
def setServer(self, ip_address, port):
if self.ip_address is not None:
def setServer(self, server):
if self.server is not None:
self.manager.unregisterServer(self)
self.ip_address = ip_address
self.port = port
self.server = server
self.manager.registerServer(self)
def getServer(self):
return self.ip_address, self.port
return self.server
def setUUID(self, uuid):
if self.uuid is not None:
......@@ -48,17 +47,23 @@ class Node(object):
def getUUID(self):
return self.uuid
def getNodeType(self):
raise NotImplementedError
class MasterNode(Node):
"""This class represents a master node."""
pass
def getNodeType(self):
return MASTER_NODE_TYPE
class StorageNode(Node):
"""This class represents a storage node."""
pass
def getNodeType(self):
return STORAGE_NODE_TYPE
class ClientNode(Node):
"""This class represents a client node."""
pass
def getNodeType(self):
return CLIENT_NODE_TYPE
class NodeManager(object):
"""This class manages node status."""
......@@ -71,7 +76,7 @@ class NodeManager(object):
def add(self, node):
node.setManager(self)
self.node_list.append(node)
if node.getServer()[0] is not None:
if node.getServer() is not None:
self.registerServer(node)
if node.getUUID() is not None:
self.registerUUID(node)
......@@ -113,8 +118,8 @@ class NodeManager(object):
def getClientNodeList(self):
return self.getNodeList(filter = lambda node: isinstance(node, ClientNode))
def getNodeByServer(self, ip_address, port):
return self.server_dict.get((ip_address, port))
def getNodeByServer(self, server):
return self.server_dict.get(server)
def getNodeByUUID(self, uuid):
return self.uuid_dict.get(uuid)
......@@ -21,14 +21,12 @@ ASK_PRIMARY_MASTER = 0x0003
ANSWER_PRIMARY_MASTER = 0x8003
ANNOUNCE_PRIMARY_MASTER = 0x0004
REELECT_PRIMARY_MASTER = 0x0005
NOTIFY_NODE_STATE_CHANGE = 0x0006
SEND_NODE_INFORMATION = 0x0007
START_OPERATION = 0x0008
STOP_OPERATION = 0x0009
ASK_FINISHING_TRANSACTIONS = 0x000a
ANSWER_FINISHING_TRANSACTIONS = 0x800a
FINISH_TRANSACTIONS = 0x000b
NOTIFY_NODE_INFORMATION = 0x0006
START_OPERATION = 0x0007
STOP_OPERATION = 0x0008
ASK_FINISHING_TRANSACTIONS = 0x0009
ANSWER_FINISHING_TRANSACTIONS = 0x8009
FINISH_TRANSACTIONS = 0x000a
# Error codes.
NOT_READY_CODE = 1
......@@ -63,7 +61,7 @@ INVALID_OID = '\0\0\0\0\0\0\0\0'
class ProtocolError(Exception):
    """Raised by the packet decoders when a message body is malformed or
    carries an invalid value."""
class Packet:
class Packet(object):
"""A packet."""
_id = None
......@@ -152,16 +150,16 @@ class Packet:
self._body = pack('!H16s4sH', node_type, uuid, inet_aton(ip_address), port)
return self
def askPrimaryMaster(self, msg_id, ltid, loid):
def askPrimaryMaster(self, msg_id):
self._id = msg_id
self._type = ASK_PRIMARY_MASTER
self._body = ltid + loid
self._body = ''
return self
def answerPrimaryMaster(self, msg_id, ltid, loid, primary_uuid, known_master_list):
def answerPrimaryMaster(self, msg_id, primary_uuid, known_master_list):
self._id = msg_id
self._type = ANSWER_PRIMARY_MASTER
body = [ltid, loid, primary_uuid, pack('!L', len(known_master_list))]
body = [primary_uuid, pack('!L', len(known_master_list))]
for master in known_master_list:
body.append(pack('!4sH16s', inet_aton(master[0]), master[1], master[2]))
self._body = ''.join(body)
......@@ -179,21 +177,9 @@ class Packet:
self._body = ''
return self
def notifyNodeStateChange(self, msg_id, node_type, ip_address, port, uuid, state):
self._id = msg_id
self._type = NOTIFY_NODE_STATE_CHANGE
self._body = pack('!H4sH16sH', node_type, inet_aton(ip_address), port, uuid, state)
return self
def askNodeInformation(self, msg_id):
def notifyNodeInformation(self, msg_id, node_list):
self._id = msg_id
self._type = ASK_NODE_INFORMATION
self._body = ''
return self
def answerNodeInformation(self, msg_id, node_list):
self._id = msg_id
self._type = ANSWER_NODE_INFORMATION
self._type = NOTIFY_NODE_INFORMATION
body = [pack('!L', len(node_list))]
for node_type, ip_address, port, uuid, state in node_list:
body.append(pack('!H4sH16sH', node_type, inet_aton(ip_address), port,
......@@ -261,16 +247,12 @@ class Packet:
decode_table[ACCEPT_NODE_IDENTIFICATION] = _decodeAcceptNodeIdentification
def _decodeAskPrimaryMaster(self):
try:
ltid, loid = unpack('!8s8s', self._body)
except:
raise ProtocolError(self, 'invalid ask primary master')
return ltid, loid
pass
decode_table[ASK_PRIMARY_MASTER] = _decodeAskPrimaryMaster
def _decodeAnswerPrimaryMaster(self):
try:
ltid, loid, primary_uuid, n = unpack('!8s8s16sL', self._body[:36])
primary_uuid, n = unpack('!16sL', self._body[:36])
known_master_list = []
for i in xrange(n):
ip_address, port, uuid = unpack('!4sH16s', self._body[36+i*22:58+i*22])
......@@ -278,7 +260,7 @@ class Packet:
known_master_list.append((ip_address, port, uuid))
except:
raise ProtocolError(self, 'invalid answer primary master')
return ltid, loid, primary_uuid, known_master_list
return primary_uuid, known_master_list
decode_table[ANSWER_PRIMARY_MASTER] = _decodeAnswerPrimaryMaster
def _decodeAnnouncePrimaryMaster(self):
......@@ -289,24 +271,7 @@ class Packet:
pass
decode_table[REELECT_PRIMARY_MASTER] = _decodeReelectPrimaryMaster
def _decodeNotifyNodeStateChange(self):
try:
node_type, ip_address, port, uuid, state = unpack('!H4sH16sH', self._body[:26])
ip_address = inet_ntoa(ip_address)
except:
raise ProtocolError(self, 'invalid notify node state change')
if node_type not in VALID_NODE_TYPE_LIST:
raise ProtocolError(self, 'invalid node type %d' % node_type)
if state not in VALID_NODE_STATE_LIST:
raise ProtocolError(self, 'invalid node state %d' % state)
return node_type, ip_address, port, uuid, state
decode_table[NOTIFY_NODE_STATE_CHANGE] = _decodeNotifyNodeStateChange
def _decodeAskNodeInformation(self):
pass
decode_table[ASK_NODE_INFORMATION] = _decodeAskNodeInformation
def _decodeAnswerNodeInformation(self):
def _decodeNotifyNodeInformation(self):
try:
n = unpack('!L', self._body[:4])
node_list = []
......@@ -324,4 +289,4 @@ class Packet:
except:
raise ProtocolError(self, 'invalid answer node information')
return node_list
decode_table[ANSWER_NODE_INFORMATION] = _decodeAnswerNodeInformation
decode_table[NOTIFY_NODE_INFORMATION] = _decodeNotifyNodeInformation
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment