Commit 5c6fd74a authored by Grégory Wisniewski's avatar Grégory Wisniewski

No more use None value with PTIDs, all in one commit since it could set neo in an

unstable state. Update tests according to changes.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@451 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent baf47d86
...@@ -28,7 +28,7 @@ from neo.client.mq import MQ ...@@ -28,7 +28,7 @@ from neo.client.mq import MQ
from neo.node import NodeManager, MasterNode, StorageNode from neo.node import NodeManager, MasterNode, StorageNode
from neo.connection import MTClientConnection from neo.connection import MTClientConnection
from neo.protocol import Packet, INVALID_UUID, INVALID_TID, INVALID_PARTITION, \ from neo.protocol import Packet, INVALID_UUID, INVALID_TID, INVALID_PARTITION, \
STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \ INVALID_PTID, STORAGE_NODE_TYPE, CLIENT_NODE_TYPE, \
RUNNING_STATE, TEMPORARILY_DOWN_STATE, \ RUNNING_STATE, TEMPORARILY_DOWN_STATE, \
UP_TO_DATE_STATE, FEEDING_STATE, INVALID_SERIAL UP_TO_DATE_STATE, FEEDING_STATE, INVALID_SERIAL
from neo.client.handler import ClientEventHandler, ClientAnswerEventHandler from neo.client.handler import ClientEventHandler, ClientAnswerEventHandler
...@@ -203,7 +203,7 @@ class Application(object): ...@@ -203,7 +203,7 @@ class Application(object):
self.uuid = INVALID_UUID self.uuid = INVALID_UUID
self.mq_cache = MQ() self.mq_cache = MQ()
self.new_oid_list = [] self.new_oid_list = []
self.ptid = None self.ptid = INVALID_PTID
self.num_replicas = 0 self.num_replicas = 0
self.num_partitions = 0 self.num_partitions = 0
self.handler = ClientEventHandler(self, self.dispatcher) self.handler = ClientEventHandler(self, self.dispatcher)
......
...@@ -340,15 +340,14 @@ class Application(object): ...@@ -340,15 +340,14 @@ class Application(object):
for conn in em.getConnectionList(): for conn in em.getConnectionList():
conn.setHandler(handler) conn.setHandler(handler)
prev_lptid = None
self.loid = INVALID_OID self.loid = INVALID_OID
self.ltid = INVALID_TID self.ltid = INVALID_TID
self.lptid = None self.lptid = INVALID_PTID
while 1: while 1:
self.target_uuid = None self.target_uuid = None
self.pt.clear() self.pt.clear()
if self.lptid is not None: if self.lptid != INVALID_PTID:
# I need to retrieve last ids again. # I need to retrieve last ids again.
logging.info('resending Ask Last IDs') logging.info('resending Ask Last IDs')
for conn in em.getConnectionList(): for conn in em.getConnectionList():
...@@ -391,7 +390,7 @@ class Application(object): ...@@ -391,7 +390,7 @@ class Application(object):
if self.lptid == INVALID_PTID: if self.lptid == INVALID_PTID:
# This looks like the first time. So make a fresh table. # This looks like the first time. So make a fresh table.
logging.debug('creating a new partition table') logging.debug('creating a new partition table')
self.getNextPartitionTableID() self.lptid = pack('!Q', 1) # ptid != INVALID_PTID
self.pt.make(nm.getStorageNodeList()) self.pt.make(nm.getStorageNodeList())
else: else:
# Obtain a partition table. It is necessary to split this # Obtain a partition table. It is necessary to split this
...@@ -740,7 +739,7 @@ class Application(object): ...@@ -740,7 +739,7 @@ class Application(object):
em.poll(1) em.poll(1)
def getNextPartitionTableID(self): def getNextPartitionTableID(self):
if self.lptid is None: if self.lptid == INVALID_PTID:
raise RuntimeError, 'I do not know the last Partition Table ID' raise RuntimeError, 'I do not know the last Partition Table ID'
ptid = unpack('!Q', self.lptid)[0] ptid = unpack('!Q', self.lptid)[0]
......
...@@ -226,9 +226,9 @@ class RecoveryEventHandler(MasterEventHandler): ...@@ -226,9 +226,9 @@ class RecoveryEventHandler(MasterEventHandler):
p.askLastIDs(msg_id) p.askLastIDs(msg_id)
conn.addPacket(p) conn.addPacket(p)
conn.expectMessage(msg_id) conn.expectMessage(msg_id)
elif node.getNodeType() == ADMIN_NODE_TYPE and app.lptid not in (INVALID_PTID, None): elif node.getNodeType() == ADMIN_NODE_TYPE and app.lptid != INVALID_PTID:
# send partition table if exists # send partition table if exists
logging.info('sending partition table %s to %s' %(app.lptid, logging.info('sending partition table %s to %s' % (app.lptid,
conn.getAddress())) conn.getAddress()))
# Split the packet if too huge. # Split the packet if too huge.
p = Packet() p = Packet()
...@@ -331,7 +331,7 @@ class RecoveryEventHandler(MasterEventHandler): ...@@ -331,7 +331,7 @@ class RecoveryEventHandler(MasterEventHandler):
app.loid = loid app.loid = loid
if app.ltid < ltid: if app.ltid < ltid:
app.ltid = ltid app.ltid = ltid
if app.lptid is None or app.lptid < lptid: if app.lptid == INVALID_PTID or app.lptid < lptid:
app.lptid = lptid app.lptid = lptid
# I need to use the node which has the max Partition Table ID. # I need to use the node which has the max Partition Table ID.
app.target_uuid = uuid app.target_uuid = uuid
......
...@@ -91,7 +91,7 @@ server: 127.0.0.1:10023 ...@@ -91,7 +91,7 @@ server: 127.0.0.1:10023
def test_01_getNextPartitionTableID(self): def test_01_getNextPartitionTableID(self):
# must raise as we don"t have one # must raise as we don"t have one
self.assertEqual(self.app.lptid, INVALID_PTID) self.assertEqual(self.app.lptid, INVALID_PTID)
self.app.lptid = None self.app.lptid = INVALID_PTID
self.assertRaises(RuntimeError, self.app.getNextPartitionTableID) self.assertRaises(RuntimeError, self.app.getNextPartitionTableID)
# set one # set one
self.app.lptid = p64(23) self.app.lptid = p64(23)
......
...@@ -100,6 +100,7 @@ server: 127.0.0.1:10023 ...@@ -100,6 +100,7 @@ server: 127.0.0.1:10023
tmp_file.close() tmp_file.close()
self.app = Application(self.tmp_path, "mastertest") self.app = Application(self.tmp_path, "mastertest")
self.app.pt.clear() self.app.pt.clear()
self.app.lptid = pack('!Q', 1)
self.app.em = Mock({"getConnectionList" : []}) self.app.em = Mock({"getConnectionList" : []})
self.app.finishing_transaction_dict = {} self.app.finishing_transaction_dict = {}
for server in self.app.master_node_list: for server in self.app.master_node_list:
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import logging import logging
import os import os
from time import time from time import time
from struct import unpack from struct import unpack, pack
from collections import deque from collections import deque
from neo.config import ConfigurationManager from neo.config import ConfigurationManager
...@@ -84,13 +84,7 @@ class Application(object): ...@@ -84,13 +84,7 @@ class Application(object):
dm.setName(self.name) dm.setName(self.name)
elif name != self.name: elif name != self.name:
raise RuntimeError('name does not match with the database') raise RuntimeError('name does not match with the database')
self.ptid = dm.getPTID() # return ptid or INVALID_PTID
ptid = dm.getPTID()
if ptid is None:
self.ptid = INVALID_PTID
dm.setPTID(self.ptid)
else:
self.ptid = ptid
def loadPartitionTable(self): def loadPartitionTable(self):
"""Load a partition table from the database.""" """Load a partition table from the database."""
......
...@@ -26,7 +26,7 @@ from struct import pack, unpack ...@@ -26,7 +26,7 @@ from struct import pack, unpack
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.util import dump from neo.util import dump
from neo.protocol import DISCARDED_STATE from neo.protocol import DISCARDED_STATE, INVALID_PTID
def p64(n): def p64(n):
return pack('!Q', n) return pack('!Q', n)
...@@ -224,7 +224,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -224,7 +224,10 @@ class MySQLDatabaseManager(DatabaseManager):
self.commit() self.commit()
def getPTID(self): def getPTID(self):
return self.getConfiguration('ptid') ptid = self.getConfiguration('ptid')
if ptid is None:
return INVALID_PTID
return ptid
def setPTID(self, ptid): def setPTID(self, ptid):
self.begin() self.begin()
......
...@@ -220,10 +220,13 @@ class StorageMySQSLdbTests(unittest.TestCase): ...@@ -220,10 +220,13 @@ class StorageMySQSLdbTests(unittest.TestCase):
value='TEST_NAME') value='TEST_NAME')
def test_15_PTID(self): def test_15_PTID(self):
self.checkConfigEntry( test = '\x01' * 8
get_call=self.db.getPTID, self.db.setup()
set_call=self.db.setPTID, self.assertEquals(self.db.getPTID(), INVALID_PTID)
value="PTID") self.db.setPTID(test)
self.assertEquals(self.db.getPTID(), test)
self.assertRaises(MySQLdb.IntegrityError, self.db.setPTID, test * 2)
self.assertEquals(self.db.getPTID(), test)
def test_16_getPartitionTable(self): def test_16_getPartitionTable(self):
# insert an entry and check it # insert an entry and check it
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment