#
# Copyright (C) 2009-2010  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from time import time
from mock import Mock
from neo.connection import ListeningConnection, Connection, \
     ClientConnection, ServerConnection, MTClientConnection, \
     HandlerSwitcher, Timeout
from neo.connector import getConnectorHandler, registerConnectorHandler
from neo.tests import DoNothingConnector
from neo.connector import ConnectorException, ConnectorTryAgainException, \
     ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.protocol import Packets
from neo.tests import NeoTestBase

class ConnectionTests(NeoTestBase):
    """Tests for the connection classes, run against a fake connector.

    Every connection is built on a DoNothingConnector, with Mock objects
    standing in for the event manager and the handler, and the tests then
    inspect which calls were recorded on those objects.

    Tests that need a specific connector behaviour replace methods on the
    DoNothingConnector class itself. Such replacements must go through
    _patchConnector() so that tearDown() can undo them; otherwise they would
    leak into every test executed afterwards.
    """

    # Marks a patched DoNothingConnector attribute that did not exist in the
    # class namespace before the test replaced it.
    _MISSING = object()

    def setUp(self):
        self.app = Mock({'__repr__': 'Fake App'})
        self.em = Mock({'__repr__': 'Fake Em'})
        self.handler = Mock({'__repr__': 'Fake Handler'})
        # NOTE(review): the port is outside the valid TCP range; this is
        # harmless only as long as the connector never opens a real socket.
        self.address = ("127.0.0.7", 93413)
        # (name, previous value) pairs recorded by _patchConnector()
        self._connector_patches = []

    def tearDown(self):
        # Undo the patches in reverse order so that an attribute patched
        # twice by the same test ends up in its initial state.
        patches = self._connector_patches
        while patches:
            name, original = patches.pop()
            if original is self._MISSING:
                delattr(DoNothingConnector, name)
            else:
                setattr(DoNothingConnector, name, original)

    def _patchConnector(self, name, func):
        """Replace DoNothingConnector.<name> by func until tearDown().

        Must be called before the connection (and thus its connector
        instance) is created.
        """
        original = DoNothingConnector.__dict__.get(name, self._MISSING)
        self._connector_patches.append((name, original))
        setattr(DoNothingConnector, name, func)

    def _makeListeningConnection(self, addr):
        """Return a ListeningConnection bound to a new fake connector."""
        # create instance after monkey patches
        self.connector = DoNothingConnector()
        return ListeningConnection(event_manager=self.em, handler=self.handler,
                connector=self.connector, addr=addr)

    def _makeConnection(self):
        """Return a Connection to self.address on a new fake connector."""
        self.connector = DoNothingConnector()
        return Connection(event_manager=self.em, handler=self.handler,
                connector=self.connector, addr=self.address)

    def _makeClientConnection(self):
        """Return a ClientConnection to self.address on a new fake connector."""
        self.connector = DoNothingConnector()
        return ClientConnection(event_manager=self.em, handler=self.handler,
                connector=self.connector, addr=self.address)

    def _makeServerConnection(self):
        """Return a ServerConnection to self.address on a new fake connector."""
        self.connector = DoNothingConnector()
        return ServerConnection(event_manager=self.em, handler=self.handler,
                connector=self.connector, addr=self.address)

    def _makeAnswerPrimary(self, msg_id):
        """Return an AnswerPrimary packet with the given id, used as sample
        data by the analyse/readable tests."""
        master_list = (
                (("127.0.0.1", 2135), self.getNewUUID()),
                (("127.0.0.1", 2135), self.getNewUUID()),
                (("127.0.0.1", 2235), self.getNewUUID()),
                (("127.0.0.1", 2134), self.getNewUUID()),
                (("127.0.0.1", 2335), self.getNewUUID()),
                (("127.0.0.1", 2133), self.getNewUUID()),
                (("127.0.0.1", 2435), self.getNewUUID()),
                (("127.0.0.1", 2132), self.getNewUUID()))
        packet = Packets.AnswerPrimary(self.getNewUUID(), master_list)
        packet.setId(msg_id)
        return packet

    def _checkCallCount(self, mock, name, n):
        """Assert that method `name` was called `n` times on `mock` and
        return the recorded calls."""
        calls = mock.mockGetNamedCalls(name)
        self.assertEqual(len(calls), n)
        return calls

    def _checkRegistered(self, n=1):
        self._checkCallCount(self.em, "register", n)

    def _checkUnregistered(self, n=1):
        self._checkCallCount(self.em, "unregister", n)

    def _checkReaderAdded(self, n=1):
        self._checkCallCount(self.em, "addReader", n)

    def _checkReaderRemoved(self, n=1):
        self._checkCallCount(self.em, "removeReader", n)

    def _checkWriterAdded(self, n=1):
        self._checkCallCount(self.em, "addWriter", n)

    def _checkWriterRemoved(self, n=1):
        self._checkCallCount(self.em, "removeWriter", n)

    def _checkShutdown(self, n=1):
        self._checkCallCount(self.connector, "shutdown", n)

    def _checkClose(self, n=1):
        self._checkCallCount(self.connector, "close", n)

    def _checkGetNewConnection(self, n=1):
        self._checkCallCount(self.connector, 'getNewConnection', n)

    def _checkSend(self, n=1, data=None):
        """Assert that the connector's send() was called `n` times and, when
        `data` is given, that the last call received exactly `data`."""
        calls = self._checkCallCount(self.connector, 'send', n)
        if n > 0 and data is not None:
            self.assertEqual(calls[n - 1].getParam(0), data)

    def _checkConnectionAccepted(self, n=1):
        self._checkCallCount(self.handler, 'connectionAccepted', n)

    def _checkConnectionFailed(self, n=1):
        self._checkCallCount(self.handler, 'connectionFailed', n)

    def _checkConnectionClosed(self, n=1):
        self._checkCallCount(self.handler, 'connectionClosed', n)

    def _checkConnectionStarted(self, n=1):
        self._checkCallCount(self.handler, 'connectionStarted', n)

    def _checkConnectionCompleted(self, n=1):
        self._checkCallCount(self.handler, 'connectionCompleted', n)

    def _checkMakeListeningConnection(self, n=1):
        self._checkCallCount(self.connector, 'makeListeningConnection', n)

    def _checkMakeClientConnection(self, n=1):
        """Assert that makeClientConnection() was called `n` times, the last
        time (if any) with self.address."""
        calls = self._checkCallCount(self.connector, "makeClientConnection", n)
        if n > 0:
            # without the guard, n=0 would index an empty list
            self.assertEqual(calls[n - 1].getParam(0), self.address)

    def _checkPacketReceived(self, n=1):
        self._checkCallCount(self.handler, 'packetReceived', n)

    def _checkReadBuf(self, bc, data):
        """Assert the read buffer holds `data`. This consumes the buffer."""
        content = bc.read_buf.read(len(bc.read_buf))
        self.assertEqual(''.join(content), data)

    def _appendToReadBuf(self, bc, data):
        bc.read_buf.append(data)

    def _appendPacketToReadBuf(self, bc, packet):
        data = ''.join(packet.encode())
        bc.read_buf.append(data)

    def _checkWriteBuf(self, bc, data):
        self.assertEqual(''.join(bc.write_buf), data)

    def _checkDecodedPacket(self, call, packet):
        """Assert that the packet recorded in `call` matches `packet`."""
        data = call.getParam(0)
        self.assertEqual(data.getType(), packet.getType())
        self.assertEqual(data.getId(), packet.getId())
        self.assertEqual(data.decode(), packet.decode())

    def _checkCommonConnectionBehaviour(self, bc):
        """Checks shared by the plain Connection and ServerConnection tests:
        initial state, uuid accessors, message id sequence and abort()."""
        self.assertEqual(bc.getAddress(), self.address)
        self._checkReaderAdded(1)
        self._checkReadBuf(bc, '')
        self._checkWriteBuf(bc, '')
        self.assertEqual(bc.cur_id, 0)
        self.assertEqual(bc.aborted, False)
        # test uuid
        self.assertEqual(bc.uuid, None)
        self.assertEqual(bc.getUUID(), None)
        uuid = self.getNewUUID()
        bc.setUUID(uuid)
        self.assertEqual(bc.getUUID(), uuid)
        # test next id
        cur_id = bc.cur_id
        next_id = bc._getNextId()
        self.assertEqual(next_id, cur_id)
        next_id = bc._getNextId()
        self.assertTrue(next_id > cur_id)
        # test overflow of next id
        bc.cur_id = 0xffffffff
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0xffffffff)
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0)
        # test abort
        bc.abort()
        self.assertEqual(bc.aborted, True)

    def test_01_BaseConnection1(self):
        # init with connector
        registerConnectorHandler(DoNothingConnector)
        connector = getConnectorHandler("DoNothingConnector")()
        self.assertNotEqual(connector, None)
        bc = self._makeConnection()
        self.assertNotEqual(bc.connector, None)
        self._checkRegistered(1)

    def test_01_BaseConnection2(self):
        # init with address
        bc = self._makeConnection()
        self.assertEqual(bc.getAddress(), self.address)
        self._checkRegistered(1)

    def test_02_ListeningConnection1(self):
        # test init part
        def getNewConnection(connector):
            return connector, "127.0.0.1"
        self._patchConnector('getNewConnection', getNewConnection)
        addr = self.address
        bc = self._makeListeningConnection(addr=addr)
        self.assertEqual(bc.getAddress(), addr)
        self._checkRegistered()
        self._checkReaderAdded()
        self._checkMakeListeningConnection()
        # test readable
        bc.readable()
        self._checkGetNewConnection()
        self._checkConnectionAccepted()

    def test_02_ListeningConnection2(self):
        # test with exception raise when getting new connection
        def getNewConnection(connector):
            raise ConnectorTryAgainException
        self._patchConnector('getNewConnection', getNewConnection)
        addr = self.address
        bc = self._makeListeningConnection(addr=addr)
        self.assertEqual(bc.getAddress(), addr)
        self._checkRegistered()
        self._checkReaderAdded()
        self._checkMakeListeningConnection()
        # test readable
        bc.readable()
        self._checkGetNewConnection(1)
        self._checkConnectionAccepted(0)

    def test_03_Connection(self):
        bc = self._makeConnection()
        self._checkCommonConnectionBehaviour(bc)
        self.assertFalse(bc.isServer())

    def test_Connection_pending(self):
        bc = self._makeConnection()
        self.assertEqual(''.join(bc.write_buf), '')
        self.assertFalse(bc.pending())
        bc.write_buf += '1'
        self.assertTrue(bc.pending())

    def test_Connection_recv1(self):
        # patch receive method to return data
        def receive(connector):
            return "testdata"
        self._patchConnector('receive', receive)
        bc = self._makeConnection()
        self._checkReadBuf(bc, '')
        bc._recv()
        self._checkReadBuf(bc, 'testdata')

    def test_Connection_recv2(self):
        # patch receive method to raise try again
        def receive(connector):
            raise ConnectorTryAgainException
        self._patchConnector('receive', receive)
        bc = self._makeConnection()
        self._checkReadBuf(bc, '')
        bc._recv()
        self._checkReadBuf(bc, '')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_recv3(self):
        # patch receive method to raise ConnectorConnectionRefusedException
        def receive(connector):
            raise ConnectorConnectionRefusedException
        self._patchConnector('receive', receive)
        bc = self._makeConnection()
        self._checkReadBuf(bc, '')
        # fake client connection instance with connecting attribute
        bc.connecting = True
        bc._recv()
        self._checkReadBuf(bc, '')
        self._checkConnectionFailed(1)
        self._checkUnregistered(1)

    def test_Connection_recv4(self):
        # patch receive method to raise any other connector error
        def receive(connector):
            raise ConnectorException
        self._patchConnector('receive', receive)
        bc = self._makeConnection()
        self._checkReadBuf(bc, '')
        self.assertRaises(ConnectorException, bc._recv)
        self._checkReadBuf(bc, '')
        self._checkConnectionClosed(1)
        self._checkUnregistered(1)

    def test_Connection_send1(self):
        # no data in the write buffer, nothing done
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._send()
        self._checkSend(0)
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send2(self):
        # send all data
        def send(connector, data):
            return len(data)
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        bc._send()
        self._checkSend(1, "testdata")
        self._checkWriteBuf(bc, '')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send3(self):
        # send part of the data
        def send(connector, data):
            return len(data) // 2
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        bc._send()
        self._checkSend(1, "testdata")
        self._checkWriteBuf(bc, 'data')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send4(self):
        # send multiple packet
        def send(connector, data):
            return len(data)
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata", "second", "third"]
        bc._send()
        self._checkSend(1, "testdatasecondthird")
        self._checkWriteBuf(bc, '')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send5(self):
        # send part of multiple packet
        def send(connector, data):
            return len(data) // 2
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata", "second", "third"]
        bc._send()
        self._checkSend(1, "testdatasecondthird")
        self._checkWriteBuf(bc, 'econdthird')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send6(self):
        # raise try again: data must stay in the write buffer
        def send(connector, data):
            raise ConnectorTryAgainException
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata", "second", "third"]
        bc._send()
        self._checkSend(1, "testdatasecondthird")
        self._checkWriteBuf(bc, 'testdatasecondthird')
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send7(self):
        # raise other error
        def send(connector, data):
            raise ConnectorException
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata", "second", "third"]
        self.assertRaises(ConnectorException, bc._send)
        self._checkSend(1, "testdatasecondthird")
        # connection closed -> buffers flushed
        self._checkWriteBuf(bc, '')
        self._checkReaderRemoved(1)
        self._checkConnectionClosed(1)
        self._checkUnregistered(1)

    def test_07_Connection_addPacket(self):
        # new packet
        p = Mock({"encode" : "testdata"})
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._addPacket(p)
        self._checkWriteBuf(bc, 'testdata')
        self._checkWriterAdded(1)

    def test_Connection_analyse1(self):
        # nothing to read, nothing is done
        bc = self._makeConnection()
        bc._queue = Mock()
        self._checkReadBuf(bc, '')
        bc.analyse()
        self._checkPacketReceived(0)
        self._checkReadBuf(bc, '')

        # give some data to analyse
        p = self._makeAnswerPrimary(1)
        self._appendPacketToReadBuf(bc, p)
        bc.analyse()
        # check packet decoded
        calls = self._checkCallCount(bc._queue, "append", 1)
        self._checkDecodedPacket(calls[0], p)
        self._checkReadBuf(bc, '')

    def test_Connection_analyse2(self):
        # give multiple packet
        bc = self._makeConnection()
        bc._queue = Mock()
        # packet 1
        p1 = self._makeAnswerPrimary(1)
        self._appendPacketToReadBuf(bc, p1)
        # packet 2
        p2 = self._makeAnswerPrimary(2)
        self._appendPacketToReadBuf(bc, p2)
        self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
        bc.analyse()
        # check two packets decoded, in the order they were received
        calls = self._checkCallCount(bc._queue, "append", 2)
        self._checkDecodedPacket(calls[0], p1)
        self._checkDecodedPacket(calls[1], p2)
        self._checkReadBuf(bc, '')

    def test_Connection_analyse3(self):
        # give a bad packet, won't be decoded
        bc = self._makeConnection()
        bc._queue = Mock()
        self._appendToReadBuf(bc, 'datadatadatadata')
        bc.analyse()
        self._checkCallCount(bc._queue, "append", 0)
        self._checkCallCount(self.handler, "_packetMalformed", 1)

    def test_Connection_analyse4(self):
        # give an expected packet
        bc = self._makeConnection()
        bc._queue = Mock()
        p = self._makeAnswerPrimary(1)
        self._appendPacketToReadBuf(bc, p)
        bc.analyse()
        # check packet decoded
        calls = self._checkCallCount(bc._queue, "append", 1)
        self._checkDecodedPacket(calls[0], p)
        self._checkReadBuf(bc, '')

    def test_Connection_writable1(self):
        # with  pending operation after send
        def send(connector, data):
            return len(data) // 2
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        bc.writable()
        # test send was called
        self._checkSend(1, "testdata")
        self.assertEqual(''.join(bc.write_buf), "data")
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)
        # pending, so nothing called
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        self._checkWriterRemoved(0)
        self._checkReaderRemoved(0)
        self._checkShutdown(0)
        self._checkClose(0)

    def test_Connection_writable2(self):
        # with no longer pending operation after send
        def send(connector, data):
            return len(data)
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        bc.writable()
        # test send was called
        self._checkSend(1, "testdata")
        self._checkWriteBuf(bc, '')
        self._checkClose(0)
        self._checkUnregistered(0)
        # nothing else pending, and aborted is false, so writer has been removed
        self.assertFalse(bc.pending())
        self.assertFalse(bc.aborted)
        self._checkWriterRemoved(1)
        self._checkReaderRemoved(0)
        self._checkShutdown(0)
        self._checkClose(0)

    def test_Connection_writable3(self):
        # with no longer pending operation after send and aborted set to true
        def send(connector, data):
            return len(data)
        self._patchConnector('send', send)
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        bc.abort()
        self.assertTrue(bc.aborted)
        bc.writable()
        # test send was called
        self._checkSend(1, "testdata")
        self._checkWriteBuf(bc, '')
        self._checkConnectionClosed(0)
        self._checkUnregistered(1)
        # nothing else pending, and aborted is true, so the connection has
        # been closed
        self.assertFalse(bc.pending())
        self.assertTrue(bc.aborted)
        self._checkWriterRemoved(1)
        self._checkReaderRemoved(1)
        self._checkShutdown(1)
        self._checkClose(1)

    def test_Connection_readable(self):
        # With aborted set to false
        # The packet is built here rather than inside the patched method:
        # there, the first argument is the connector, not this test case.
        encoded = ''.join(self._makeAnswerPrimary(1).encode())
        # patch receive method to return data
        def receive(connector):
            return encoded
        self._patchConnector('receive', receive)
        bc = self._makeConnection()
        bc._queue = Mock()
        self._checkReadBuf(bc, '')
        self.assertFalse(bc.aborted)
        bc.readable()
        # check packet decoded
        self._checkReadBuf(bc, '')
        calls = self._checkCallCount(bc._queue, "append", 1)
        data = calls[0].getParam(0)
        self.assertEqual(data.getType(), Packets.AnswerPrimary)
        self.assertEqual(data.getId(), 1)
        self._checkReadBuf(bc, '')
        # check not aborted
        self.assertFalse(bc.aborted)
        self._checkUnregistered(0)
        self._checkWriterRemoved(0)
        self._checkReaderRemoved(0)
        self._checkShutdown(0)
        self._checkClose(0)

    def test_ClientConnection_init1(self):
        # create a good client connection
        bc = self._makeClientConnection()
        # check connector created and connection initialize
        self.assertFalse(bc.connecting)
        self.assertFalse(bc.isServer())
        self._checkMakeClientConnection(1)
        # check call to handler
        self.assertNotEqual(bc.getHandler(), None)
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(0)
        # check call to event manager
        self.assertNotEqual(bc.getEventManager(), None)
        self._checkReaderAdded(1)
        self._checkWriterAdded(0)

    def test_ClientConnection_init2(self):
        # raise connection in progress
        # The original method is restored right after the connection is
        # built, so the patch only affects the constructor.
        makeClientConnection_org = DoNothingConnector.makeClientConnection
        def makeClientConnection(connector, *args, **kw):
            raise ConnectorInProgressException
        DoNothingConnector.makeClientConnection = makeClientConnection
        try:
            bc = self._makeClientConnection()
        finally:
            DoNothingConnector.makeClientConnection = makeClientConnection_org
        # check connector created and connection initialize
        self.assertTrue(bc.connecting)
        self.assertFalse(bc.isServer())
        self._checkMakeClientConnection(1)
        # check call to handler
        self.assertNotEqual(bc.getHandler(), None)
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(0)
        self._checkConnectionFailed(0)
        # check call to event manager
        self.assertNotEqual(bc.getEventManager(), None)
        self._checkReaderAdded(1)
        self._checkWriterAdded(1)

    def test_ClientConnection_init3(self):
        # raise another error, connection must fail
        makeClientConnection_org = DoNothingConnector.makeClientConnection
        def makeClientConnection(connector, *args, **kw):
            raise ConnectorException
        DoNothingConnector.makeClientConnection = makeClientConnection
        try:
            self.assertRaises(ConnectorException, self._makeClientConnection)
        finally:
            DoNothingConnector.makeClientConnection = makeClientConnection_org
        # since the exception was raised, the connection is not created
        # check call to handler
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(0)
        self._checkConnectionFailed(1)
        # check call to event manager
        self._checkReaderAdded(1)
        self._checkWriterAdded(0)

    def test_ClientConnection_writable1(self):
        # with a non connecting connection, will call parent's method
        def makeClientConnection(connector, *args, **kw):
            return "OK"
        def send(connector, data):
            return len(data)
        makeClientConnection_org = DoNothingConnector.makeClientConnection
        # send is needed until writable() is called: undone by tearDown
        self._patchConnector('send', send)
        DoNothingConnector.makeClientConnection = makeClientConnection
        try:
            bc = self._makeClientConnection()
        finally:
            DoNothingConnector.makeClientConnection = makeClientConnection_org
        # check connector created and connection initialize
        self.assertFalse(bc.connecting)
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        # call
        self._checkConnectionCompleted(1)
        self._checkReaderAdded(1)
        bc.writable()
        self.assertFalse(bc.pending())
        self.assertFalse(bc.aborted)
        self.assertFalse(bc.connecting)
        self._checkSend(1, "testdata")
        self._checkConnectionClosed(0)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(0)
        self._checkUnregistered(0)
        self._checkReaderAdded(1)
        self._checkWriterRemoved(1)
        self._checkReaderRemoved(0)
        self._checkShutdown(0)
        self._checkClose(0)

    def test_ClientConnection_writable2(self):
        # with a connecting connection, must not call parent's method
        # with errors, close connection
        def getError(connector):
            return True
        self._patchConnector('getError', getError)
        bc = self._makeClientConnection()
        # check connector created and connection initialize
        bc.connecting = True
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        # call
        self._checkConnectionCompleted(1)
        self._checkReaderAdded(1)
        bc.writable()
        self.assertTrue(bc.connecting)
        self.assertFalse(bc.pending())
        self.assertFalse(bc.aborted)
        self._checkWriteBuf(bc, '')
        self._checkConnectionClosed(0)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(1)
        self._checkUnregistered(1)
        self._checkReaderAdded(1)
        self._checkWriterRemoved(1)
        self._checkReaderRemoved(1)

    def test_14_ServerConnection(self):
        bc = self._makeServerConnection()
        self._checkCommonConnectionBehaviour(bc)
        self.assertTrue(bc.isServer())


class HandlerSwitcherTests(NeoTestBase):
    """Tests for HandlerSwitcher, using mock handlers and a mock connection.

    The tests emit requests, register new handlers with setHandler() and
    check which handler is current and which one received the packets.
    """

    def setUp(self):
        self._handler = handler = Mock({
            '__repr__': 'initial handler',
        })
        self._connection = connection = Mock({
            '__repr__': 'connection',
            'getAddress': ('127.0.0.1', 10000),
        })
        self._handlers = HandlerSwitcher(connection, handler)

    def _makeNotification(self, msg_id):
        """Return a StartOperation packet with the given id."""
        packet = Packets.StartOperation()
        packet.setId(msg_id)
        return packet

    def _makeRequest(self, msg_id):
        """Return an AskBeginTransaction packet with the given id."""
        packet = Packets.AskBeginTransaction(self.getNextTID())
        packet.setId(msg_id)
        return packet

    def _makeAnswer(self, msg_id):
        """Return an AnswerBeginTransaction packet with the given id."""
        packet = Packets.AnswerBeginTransaction(self.getNextTID())
        packet.setId(msg_id)
        return packet

    def _makeHandler(self):
        return Mock({'__repr__': 'handler'})

    def _checkPacketReceived(self, handler, packet, index=0):
        """Assert that `handler` received exactly `index` + 1 packets.

        NOTE(review): `packet` only documents which packet is expected at
        position `index`; it is not compared with the recorded call.
        """
        calls = handler.mockGetNamedCalls('packetReceived')
        self.assertEqual(len(calls), index + 1)

    def _checkCurrentHandler(self, handler):
        self.assertTrue(self._handlers.getHandler() is handler)

    def testInit(self):
        self._checkCurrentHandler(self._handler)
        self.assertFalse(self._handlers.isPending())

    def testEmit(self):
        self.assertFalse(self._handlers.isPending())
        request = self._makeRequest(1)
        self._handlers.emit(request)
        self.assertTrue(self._handlers.isPending())

    def testHandleNotification(self):
        # handle with current handler
        notif1 = self._makeNotification(1)
        self._handlers.handle(notif1)
        self._checkPacketReceived(self._handler, notif1)
        # emit a request and delay an handler
        request = self._makeRequest(2)
        self._handlers.emit(request)
        handler = self._makeHandler()
        self._handlers.setHandler(handler)
        # next notification fall into the current handler
        notif2 = self._makeNotification(3)
        self._handlers.handle(notif2)
        self._checkPacketReceived(self._handler, notif2, index=1)
        # handle with new handler
        answer = self._makeAnswer(2)
        self._handlers.handle(answer)
        notif3 = self._makeNotification(4)
        self._handlers.handle(notif3)
        # the new handler got a single packet: the last notification
        self._checkPacketReceived(handler, notif3)

    def testHandleAnswer1(self):
        # handle with current handler
        request = self._makeRequest(1)
        self._handlers.emit(request)
        answer = self._makeAnswer(1)
        self._handlers.handle(answer)
        self._checkPacketReceived(self._handler, answer)

    def testHandleAnswer2(self):
        # handle with blocking handler
        request = self._makeRequest(1)
        self._handlers.emit(request)
        handler = self._makeHandler()
        self._handlers.setHandler(handler)
        answer = self._makeAnswer(1)
        self._handlers.handle(answer)
        self._checkPacketReceived(self._handler, answer)
        self._checkCurrentHandler(handler)

    def testHandleAnswer3(self):
        # multiple setHandler
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h1 = self._makeHandler()
        h2 = self._makeHandler()
        h3 = self._makeHandler()
        # emit all requests, each one followed by a setHandler
        self._handlers.emit(r1)
        self._handlers.setHandler(h1)
        self._handlers.emit(r2)
        self._handlers.setHandler(h2)
        self._handlers.emit(r3)
        self._handlers.setHandler(h3)
        self._checkCurrentHandler(self._handler)
        self.assertTrue(self._handlers.isPending())
        # process answers: each one makes the next handler current
        self._handlers.handle(a1)
        self._checkCurrentHandler(h1)
        self._handlers.handle(a2)
        self._checkCurrentHandler(h2)
        self._handlers.handle(a3)
        self._checkCurrentHandler(h3)

    def testHandleAnswer4(self):
        # several requests pending for a single setHandler: the new handler
        # becomes current only once the last answer has been processed
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h = self._makeHandler()
        # emit all requests
        self._handlers.emit(r1)
        self._handlers.emit(r2)
        self._handlers.emit(r3)
        self._handlers.setHandler(h)
        # process answers
        self._handlers.handle(a1)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(a2)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(a3)
        self._checkCurrentHandler(h)

    def testHandleUnexpected(self):
        # an answer received while an earlier request is still pending for
        # the previous handler must abort the connection
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        a2 = self._makeAnswer(2)
        h = self._makeHandler()
        # emit requests around the setHandler call
        self._handlers.emit(r1)
        self._handlers.setHandler(h)
        self._handlers.emit(r2)
        # process answer for next state
        self._handlers.handle(a2)
        self.checkAborted(self._connection)


class TestTimeout(NeoTestBase):
    """Check soft/hard expiry of Timeout; expected values assume
    PING_DELAY=5."""

    def setUp(self):
        now = time()
        self.initial = now
        self.current = now
        self.timeout = Timeout()

    def checkAfter(self, n, soft, hard):
        # Query both expiry states `n` seconds after the last refresh (or
        # after the initial time if no refresh happened yet).
        moment = self.current + n
        timeout = self.timeout
        self.assertEqual(soft, timeout.softExpired(moment))
        self.assertEqual(hard, timeout.hardExpired(moment))

    def refreshAfter(self, n):
        # Move the reference time `n` seconds forward and refresh the
        # timeout at that time.
        refreshed_at = self.current + n
        self.current = refreshed_at
        self.timeout.refresh(refreshed_at)

    def testNoTimeout(self):
        self.timeout.update(self.initial, 5)
        for delay in (1, 4):
            self.checkAfter(delay, False, False)
        # answer received
        self.refreshAfter(4)
        self.checkAfter(1, False, False)

    def testSoftTimeout(self):
        self.timeout.update(self.initial, 5)
        for delay in (1, 4):
            self.checkAfter(delay, False, False)
        # ping
        self.checkAfter(6, True, True)
        # pong
        self.refreshAfter(8)
        self.checkAfter(1, False, False)
        self.checkAfter(4, False, True)

    def testHardTimeout(self):
        self.timeout.update(self.initial, 5)
        for delay in (1, 4):
            self.checkAfter(delay, False, False)
        # ping
        self.checkAfter(6, True, True)
        # pong
        self.refreshAfter(6)
        for delay in (1, 4):
            self.checkAfter(delay, False, False)
        # ping
        self.checkAfter(6, False, True)
        # pong
        self.refreshAfter(6)
        # too late
        for delay in (1, 5):
            self.checkAfter(delay, False, True)

# Run every test case of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()