Commit 00141976 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Make iterateForObject always yield a connection for each node.

Don't block on a non-ready node but do not skip it
Fix test method name and update it (it will never raise StopIteration)

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2591 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 71a433f0
...@@ -622,31 +622,28 @@ class Application(object): ...@@ -622,31 +622,28 @@ class Application(object):
def _loadFromStorage(self, oid, at_tid, before_tid): def _loadFromStorage(self, oid, at_tid, before_tid):
self.local_var.asked_object = 0 self.local_var.asked_object = 0
packet = Packets.AskObject(oid, at_tid, before_tid) packet = Packets.AskObject(oid, at_tid, before_tid)
while self.local_var.asked_object == 0: for node, conn in self.cp.iterateForObject(oid, readable=True):
# try without waiting for a node to be ready try:
for node, conn in self.cp.iterateForObject(oid, readable=True, self._askStorage(conn, packet)
wait_ready=False): except ConnectionClosed:
try: continue
self._askStorage(conn, packet)
except ConnectionClosed:
continue
# Check data # Check data
noid, tid, next_tid, compression, checksum, data \ noid, tid, next_tid, compression, checksum, data \
= self.local_var.asked_object = self.local_var.asked_object
if noid != oid: if noid != oid:
# Oops, try with next node # Oops, try with next node
neo.logging.error('got wrong oid %s instead of %s from %s', neo.logging.error('got wrong oid %s instead of %s from %s',
noid, dump(oid), conn) noid, dump(oid), conn)
self.local_var.asked_object = -1 self.local_var.asked_object = -1
continue continue
elif checksum != makeChecksum(data): elif checksum != makeChecksum(data):
# Check checksum. # Check checksum.
neo.logging.error('wrong checksum from %s for oid %s', neo.logging.error('wrong checksum from %s for oid %s',
conn, dump(oid)) conn, dump(oid))
self.local_var.asked_object = -1 self.local_var.asked_object = -1
continue continue
break break
if self.local_var.asked_object == -1: if self.local_var.asked_object == -1:
raise NEOStorageError('inconsistent data') raise NEOStorageError('inconsistent data')
...@@ -735,8 +732,7 @@ class Application(object): ...@@ -735,8 +732,7 @@ class Application(object):
add_involved_nodes = self.local_var.involved_nodes.add add_involved_nodes = self.local_var.involved_nodes.add
packet = Packets.AskStoreObject(oid, serial, compression, packet = Packets.AskStoreObject(oid, serial, compression,
checksum, compressed_data, data_serial, self.local_var.tid) checksum, compressed_data, data_serial, self.local_var.tid)
for node, conn in self.cp.iterateForObject(oid, writable=True, for node, conn in self.cp.iterateForObject(oid, writable=True):
wait_ready=True):
try: try:
conn.ask(packet, on_timeout=on_timeout, queue=queue) conn.ask(packet, on_timeout=on_timeout, queue=queue)
add_involved_nodes(node) add_involved_nodes(node)
...@@ -865,8 +861,7 @@ class Application(object): ...@@ -865,8 +861,7 @@ class Application(object):
str(transaction.description), dumps(transaction._extension), str(transaction.description), dumps(transaction._extension),
local_var.data_list) local_var.data_list)
add_involved_nodes = self.local_var.involved_nodes.add add_involved_nodes = self.local_var.involved_nodes.add
for node, conn in self.cp.iterateForObject(tid, writable=True, for node, conn in self.cp.iterateForObject(tid, writable=True):
wait_ready=False):
neo.logging.debug("voting object %s on %s", dump(tid), neo.logging.debug("voting object %s on %s", dump(tid),
dump(conn.getUUID())) dump(conn.getUUID()))
try: try:
...@@ -1011,7 +1006,7 @@ class Application(object): ...@@ -1011,7 +1006,7 @@ class Application(object):
cell_list = getCellList(partition, readable=True) cell_list = getCellList(partition, readable=True)
shuffle(cell_list) shuffle(cell_list)
cell_list.sort(key=getCellSortKey) cell_list.sort(key=getCellSortKey)
storage_conn = getConnForCell(cell_list[0], wait_ready=False) storage_conn = getConnForCell(cell_list[0])
storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid, storage_conn.ask(Packets.AskObjectUndoSerial(self.local_var.tid,
snapshot_tid, undone_tid, oid_list), queue=queue) snapshot_tid, undone_tid, oid_list), queue=queue)
...@@ -1064,8 +1059,7 @@ class Application(object): ...@@ -1064,8 +1059,7 @@ class Application(object):
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
packet = Packets.AskTransactionInformation(tid) packet = Packets.AskTransactionInformation(tid)
for node, conn in self.cp.iterateForObject(tid, readable=True, for node, conn in self.cp.iterateForObject(tid, readable=True):
wait_ready=False):
try: try:
self._askStorage(conn, packet) self._askStorage(conn, packet)
except ConnectionClosed: except ConnectionClosed:
...@@ -1162,8 +1156,7 @@ class Application(object): ...@@ -1162,8 +1156,7 @@ class Application(object):
def history(self, oid, version=None, size=1, filter=None): def history(self, oid, version=None, size=1, filter=None):
# Get history informations for object first # Get history informations for object first
packet = Packets.AskObjectHistory(oid, 0, size) packet = Packets.AskObjectHistory(oid, 0, size)
for node, conn in self.cp.iterateForObject(oid, readable=True, for node, conn in self.cp.iterateForObject(oid, readable=True):
wait_ready=False):
# FIXME: we keep overwriting self.local_var.history here, we # FIXME: we keep overwriting self.local_var.history here, we
# should aggregate it instead. # should aggregate it instead.
self.local_var.history = None self.local_var.history = None
...@@ -1299,8 +1292,7 @@ class Application(object): ...@@ -1299,8 +1292,7 @@ class Application(object):
data_dict[oid] = None data_dict[oid] = None
local_var.data_list.append(oid) local_var.data_list.append(oid)
packet = Packets.AskCheckCurrentSerial(local_var.tid, serial, oid) packet = Packets.AskCheckCurrentSerial(local_var.tid, serial, oid)
for node, conn in self.cp.iterateForObject(oid, writable=True, for node, conn in self.cp.iterateForObject(oid, writable=True):
wait_ready=False):
try: try:
conn.ask(packet, queue=queue) conn.ask(packet, queue=queue)
except ConnectionClosed: except ConnectionClosed:
......
...@@ -139,30 +139,34 @@ class ConnectionPool(object): ...@@ -139,30 +139,34 @@ class ConnectionPool(object):
return result return result
@profiler_decorator @profiler_decorator
def getConnForCell(self, cell, wait_ready=False): def getConnForCell(self, cell):
return self.getConnForNode(cell.getNode(), wait_ready=wait_ready) return self.getConnForNode(cell.getNode())
def iterateForObject(self, object_id, readable=False, writable=False, def iterateForObject(self, object_id, readable=False, writable=False):
wait_ready=False): """ Iterate over nodes managing an object """
""" Iterate over nodes responsible of a object by it's ID """
pt = self.app.getPartitionTable() pt = self.app.getPartitionTable()
cell_list = pt.getCellListForOID(object_id, readable, writable) cell_list = pt.getCellListForOID(object_id, readable, writable)
yielded = 0 if not cell_list:
if cell_list: raise NEOStorageError('no storage available')
getConnForNode = self.getConnForNode
while cell_list:
new_cell_list = []
shuffle(cell_list) shuffle(cell_list)
cell_list.sort(key=self.getCellSortKey) cell_list.sort(key=self.getCellSortKey)
getConnForNode = self.getConnForNode
for cell in cell_list: for cell in cell_list:
node = cell.getNode() node = cell.getNode()
conn = getConnForNode(node, wait_ready=wait_ready) conn = getConnForNode(node)
if conn is not None: if conn is not None:
yielded += 1
yield (node, conn) yield (node, conn)
if not yielded: else:
raise NEOStorageError('no storage available') new_cell_list.append(cell)
cell_list = new_cell_list
if new_cell_list:
# wait a bit to avoid a busy loop
time.sleep(1)
@profiler_decorator @profiler_decorator
def getConnForNode(self, node, wait_ready=True): def getConnForNode(self, node):
"""Return a locked connection object to a given node """Return a locked connection object to a given node
If no connection exists, create a new one""" If no connection exists, create a new one"""
if not node.isRunning(): if not node.isRunning():
...@@ -180,9 +184,6 @@ class ConnectionPool(object): ...@@ -180,9 +184,6 @@ class ConnectionPool(object):
# Create new connection to node # Create new connection to node
while True: while True:
conn = self._initNodeConnection(node) conn = self._initNodeConnection(node)
if conn is NOT_READY and wait_ready:
time.sleep(1)
continue
if conn not in (None, NOT_READY): if conn not in (None, NOT_READY):
self.connection_dict[uuid] = conn self.connection_dict[uuid] = conn
return conn return conn
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest import unittest
from mock import Mock from mock import Mock, ReturnValues
from neo.tests import NeoUnitTestBase from neo.tests import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
...@@ -78,7 +78,7 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -78,7 +78,7 @@ class ConnectionPoolTests(NeoUnitTestBase):
self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next) self.assertRaises(NEOStorageError, pool.iterateForObject(oid).next)
def test_iterateForObject_connectionRefused(self): def test_iterateForObject_connectionRefused(self):
# connection refused # connection refused at the first try
oid = self.getOID(1) oid = self.getOID(1)
node = Mock({'__repr__': 'node'}) node = Mock({'__repr__': 'node'})
cell = Mock({'__repr__': 'cell', 'getNode': node}) cell = Mock({'__repr__': 'cell', 'getNode': node})
...@@ -86,11 +86,11 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -86,11 +86,11 @@ class ConnectionPoolTests(NeoUnitTestBase):
pt = Mock({'getCellListForOID': [cell]}) pt = Mock({'getCellListForOID': [cell]})
app = Mock({'getPartitionTable': pt}) app = Mock({'getPartitionTable': pt})
pool = ConnectionPool(app) pool = ConnectionPool(app)
pool.getConnForNode = Mock({'__call__': None}) pool.getConnForNode = Mock({'__call__': ReturnValues(None, conn)})
self.assertRaises(StopIteration, pool.iterateForObject(oid).next) self.assertEqual(list(pool.iterateForObject(oid)), [(node, conn)])
def test_iterateForObject_connectionRefused(self): def test_iterateForObject_connectionAccepted(self):
# connection refused # connection accepted
oid = self.getOID(1) oid = self.getOID(1)
node = Mock({'__repr__': 'node'}) node = Mock({'__repr__': 'node'})
cell = Mock({'__repr__': 'cell', 'getNode': node}) cell = Mock({'__repr__': 'cell', 'getNode': node})
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment