#
# Copyright (C) 2009-2010  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import unittest
from mock import Mock
from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler
from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable
from neo.protocol import CellStates, ProtocolError
from neo.protocol import INVALID_TID, INVALID_OID

class StorageMasterHandlerTests(NeoTestBase):
    """Unit tests for the storage node's MasterOperationHandler.

    Each test calls one handler method with a fake master connection (or a
    bare Mock) and then checks the exception raised or the calls recorded on
    the mocked application components.
    """

    def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True,
            **kwargs):
        """Run checkUnexpectedPacketRaised for _call on a master connection.

        _msg_type and _listening are kept in the signature for backward
        compatibility but are not used: getMasterConnection() takes no
        arguments, so forwarding _listening to it as a keyword raised
        TypeError before the check could even start.
        Extra keyword arguments are passed on to checkUnexpectedPacketRaised.
        """
        conn = self.getMasterConnection()
        # hook: make peerBroken call peerBrokendCalled() on the connection it
        # receives instead of running the real handler code
        self.operation.peerBroken = lambda c: c.peerBrokendCalled()
        self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)

    def setUp(self):
        """Create the application, the handler under test and the primary
        master node used by the tests."""
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(config)
        self.app.transaction_dict = {}
        self.app.store_lock_dict = {}
        self.app.load_lock_dict = {}
        self.app.event_queue = deque()
        # handler
        self.operation = MasterOperationHandler(self.app)
        # set pmn: the first known master becomes the primary master node
        self.master_uuid = self.getNewUUID()
        pmn = self.app.nm.getMasterList()[0]
        pmn.setUUID(self.master_uuid)
        self.app.primary_master_node = pmn
        self.master_port = 10010

    def tearDown(self):
        """Delegate clean-up to NeoTestBase."""
        NeoTestBase.tearDown(self)

    def getMasterConnection(self):
        """Return a fake connection built with the primary master's UUID and
        the address 127.0.0.1:<self.master_port>."""
        address = ("127.0.0.1", self.master_port)
        return self.getFakeConnection(uuid=self.master_uuid, address=address)

    def test_06_timeoutExpired(self):
        # client connection
        # a timeout on the master connection must raise PrimaryFailure
        conn = self.getMasterConnection()
        self.assertRaises(PrimaryFailure, self.operation.timeoutExpired, conn)
        self.checkNoPacketSent(conn)

    def test_07_connectionClosed2(self):
        # primary has closed the connection
        conn = self.getMasterConnection()
        self.assertRaises(PrimaryFailure, self.operation.connectionClosed, conn)
        self.checkNoPacketSent(conn)

    def test_08_peerBroken(self):
        # client connection
        # a broken master peer must raise PrimaryFailure
        conn = self.getMasterConnection()
        self.assertRaises(PrimaryFailure, self.operation.peerBroken, conn)
        self.checkNoPacketSent(conn)

    def test_14_notifyPartitionChanges1(self):
        # old partition change -> do nothing
        app = self.app
        conn = self.getMasterConnection()
        app.replicator = Mock({})
        app.pt = Mock({'getID': 1})
        count = len(app.nm.getList())
        # notify with ptid 0 while the mocked table reports ID 1
        self.operation.notifyPartitionChanges(conn, 0, ())
        self.assertEqual(app.pt.getID(), 1)
        # no node was added or removed
        self.assertEqual(len(app.nm.getList()), count)
        # the replicator was left untouched
        calls = app.replicator.mockGetNamedCalls('removePartition')
        self.assertEqual(len(calls), 0)
        calls = app.replicator.mockGetNamedCalls('addPartition')
        self.assertEqual(len(calls), 0)

    def test_14_notifyPartitionChanges2(self):
        # partition change that must be applied: the table ID is updated and
        # the cells are passed to app.dm.changePartitionTable
        uuid1, uuid2, uuid3 = [self.getNewUUID() for i in range(3)]
        cells = (
            (0, uuid1, CellStates.UP_TO_DATE),
            (1, uuid2, CellStates.DISCARDED),
            (2, uuid3, CellStates.OUT_OF_DATE),
        )
        # context
        conn = self.getMasterConnection()
        app = self.app
        # register nodes
        app.nm.createStorage(uuid=uuid1)
        app.nm.createStorage(uuid=uuid2)
        app.nm.createStorage(uuid=uuid3)
        ptid = 2
        app.pt = PartitionTable(3, 1)
        app.dm = Mock({ })
        app.replicator = Mock({})
        self.operation.notifyPartitionChanges(conn, ptid, cells)
        # ptid set
        self.assertEqual(app.pt.getID(), ptid)
        # dm call
        calls = app.dm.mockGetNamedCalls('changePartitionTable')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(ptid, cells)

    def test_16_stopOperation1(self):
        # OperationFailure
        conn = Mock({ 'isServer': False })
        self.assertRaises(OperationFailure, self.operation.stopOperation, conn)

    def _getConnection(self):
        """Return a bare Mock used as connection by the lock/unlock tests."""
        return Mock({})

    def test_askLockInformation1(self):
        """ Unknown transaction """
        # the transaction manager reports that it does not contain the tid
        self.app.tm = Mock({'__contains__': False})
        conn = self._getConnection()
        tid = self.getNextTID()
        handler = self.operation
        self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid)

    def test_askLockInformation2(self):
        """ Lock transaction """
        self.app.tm = Mock({'__contains__': True})
        conn = self._getConnection()
        tid = self.getNextTID()
        self.operation.askLockInformation(conn, tid)
        # lock() must have been called exactly once, with the tid
        calls = self.app.tm.mockGetNamedCalls('lock')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(tid)
        self.checkAnswerInformationLocked(conn)

    def test_notifyUnlockInformation1(self):
        """ Unknown transaction """
        # the transaction manager reports that it does not contain the tid
        self.app.tm = Mock({'__contains__': False})
        conn = self._getConnection()
        tid = self.getNextTID()
        handler = self.operation
        self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
                conn, tid)

    def test_notifyUnlockInformation2(self):
        """ Unlock transaction """
        self.app.tm = Mock({'__contains__': True})
        conn = self._getConnection()
        tid = self.getNextTID()
        self.operation.notifyUnlockInformation(conn, tid)
        # unlock() must have been called exactly once, with the tid
        calls = self.app.tm.mockGetNamedCalls('unlock')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(tid)
        self.checkNoPacketSent(conn)

    def test_30_answerLastIDs(self):
        # set critical TID on replicator
        conn = Mock()
        self.app.replicator = Mock()
        self.operation.answerLastIDs(
            conn=conn,
            loid=INVALID_OID,
            ltid=INVALID_TID,
            lptid=INVALID_TID,
        )
        # setCriticalTID must receive the connection's UUID and the last TID
        calls = self.app.replicator.mockGetNamedCalls('setCriticalTID')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(conn.getUUID(), INVALID_TID)

    def test_31_answerUnfinishedTransactions(self):
        # set unfinished TID on replicator
        conn = Mock()
        self.app.replicator = Mock()
        self.operation.answerUnfinishedTransactions(
            conn=conn,
            tid_list=(INVALID_TID, ),
        )
        # the TID list must be handed over unchanged
        calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs((INVALID_TID, ))

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()