Commit ed50edca authored by Julien Muchembled's avatar Julien Muchembled

Simplify API to establish connections and accept mix of IPv4/IPv6

parent c2c97752
...@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection ...@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \ from .handler import AdminEventHandler, MasterEventHandler, \
MasterRequestEventHandler MasterRequestEventHandler
from neo.lib.connector import getConnectorHandler
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, \ from neo.lib.protocol import ClusterStates, Errors, \
...@@ -39,8 +38,7 @@ class Application(object): ...@@ -39,8 +38,7 @@ class Application(object):
self.name = config.getCluster() self.name = config.getCluster()
self.server = config.getBind() self.server = config.getBind()
self.master_addresses, connector_name = config.getMasters() self.master_addresses = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
# The partition table is initialized after getting the number of # The partition table is initialized after getting the number of
...@@ -87,8 +85,7 @@ class Application(object): ...@@ -87,8 +85,7 @@ class Application(object):
# Make a listening port. # Make a listening port.
handler = AdminEventHandler(self) handler = AdminEventHandler(self)
self.listening_conn = ListeningConnection(self.em, handler, self.listening_conn = ListeningConnection(self.em, handler, self.server)
addr=self.server, connector=self.connector_handler())
while self.cluster_state != ClusterStates.STOPPING: while self.cluster_state != ClusterStates.STOPPING:
self.connectToPrimary() self.connectToPrimary()
...@@ -120,7 +117,7 @@ class Application(object): ...@@ -120,7 +117,7 @@ class Application(object):
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN, bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN,
self.uuid, self.server) self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler) data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data (node, conn, uuid, num_partitions, num_replicas) = data
nm.update([(node.getType(), node.getAddress(), node.getUUID(), nm.update([(node.getType(), node.getAddress(), node.getUUID(),
NodeStates.RUNNING)]) NodeStates.RUNNING)])
......
...@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump ...@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Lock from neo.lib.locking import Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
from neo.lib.connector import getConnectorHandler
from .exception import NEOStorageError, NEOStorageCreationUndoneError from .exception import NEOStorageError, NEOStorageCreationUndoneError
from .exception import NEOStorageNotFoundError from .exception import NEOStorageNotFoundError
from .handlers import storage, master from .handlers import storage, master
...@@ -80,8 +79,6 @@ class Application(object): ...@@ -80,8 +79,6 @@ class Application(object):
# Internal Attributes common to all thread # Internal Attributes common to all thread
self._db = None self._db = None
self.name = name self.name = name
master_addresses, connector_name = parseMasterList(master_nodes)
self.connector_handler = getConnectorHandler(connector_name)
self.dispatcher = Dispatcher(self.poll_thread) self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list) self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self) self.cp = ConnectionPool(self)
...@@ -90,7 +87,7 @@ class Application(object): ...@@ -90,7 +87,7 @@ class Application(object):
self.trying_master_node = None self.trying_master_node = None
# load master node list # load master node list
for address in master_addresses: for address in parseMasterList(master_nodes):
self.nm.createMaster(address=address) self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one # no self-assigned UUID, primary master will supply us one
...@@ -290,7 +287,6 @@ class Application(object): ...@@ -290,7 +287,6 @@ class Application(object):
conn = MTClientConnection(self.em, conn = MTClientConnection(self.em,
self.notifications_handler, self.notifications_handler,
node=self.trying_master_node, node=self.trying_master_node,
connector=self.connector_handler(),
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
# Query for primary master node # Query for primary master node
if conn.getConnector() is None: if conn.getConnector() is None:
......
...@@ -54,7 +54,7 @@ class ConnectionPool(object): ...@@ -54,7 +54,7 @@ class ConnectionPool(object):
app = self.app app = self.app
logging.debug('trying to connect to %s - %s', node, node.getState()) logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app.em, app.storage_event_handler, node, conn = MTClientConnection(app.em, app.storage_event_handler, node,
connector=app.connector_handler(), dispatcher=app.dispatcher) dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name) app.uuid, None, app.name)
try: try:
......
...@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler): ...@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler):
logging.info('Got a new UUID: %s', uuid_str(self.uuid)) logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.accepted = True self.accepted = True
def getPrimaryConnection(self, connector_handler): def getPrimaryConnection(self):
""" """
Primary lookup/connection process. Primary lookup/connection process.
Returns when the connection is made. Returns when the connection is made.
...@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler): ...@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler):
sleep(1) sleep(1)
if conn is None: if conn is None:
# open the connection # open the connection
conn = ClientConnection(em, self, self.current, conn = ClientConnection(em, self, self.current)
connector_handler())
# still processing # still processing
em.poll(1) em.poll(1)
return (self.current, conn, self.uuid, self.num_partitions, return (self.current, conn, self.uuid, self.num_partitions,
......
...@@ -206,6 +206,7 @@ class BaseConnection(object): ...@@ -206,6 +206,7 @@ class BaseConnection(object):
Timeouts in HandlerSwitcher are only there to prioritize some packets. Timeouts in HandlerSwitcher are only there to prioritize some packets.
""" """
from .connector import SocketConnector as ConnectorClass
KEEP_ALIVE = 60 KEEP_ALIVE = 60
def __init__(self, event_manager, handler, connector, addr=None): def __init__(self, event_manager, handler, connector, addr=None):
...@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection) ...@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection)
class ListeningConnection(BaseConnection): class ListeningConnection(BaseConnection):
"""A listen connection.""" """A listen connection."""
def __init__(self, event_manager, handler, addr, connector, **kw): def __init__(self, event_manager, handler, addr):
logging.debug('listening to %s:%d', *addr) logging.debug('listening to %s:%d', *addr)
BaseConnection.__init__(self, event_manager, handler, connector = self.ConnectorClass(addr)
addr=addr, connector=connector) BaseConnection.__init__(self, event_manager, handler, connector, addr)
self.connector.makeListeningConnection(addr) connector.makeListeningConnection()
def readable(self): def readable(self):
try: try:
new_s, addr = self.connector.getNewConnection() connector, addr = self.connector.accept()
logging.debug('accepted a connection from %s:%d', *addr) logging.debug('accepted a connection from %s:%d', *addr)
handler = self.getHandler() handler = self.getHandler()
new_conn = ServerConnection(self.em, handler, new_conn = ServerConnection(self.em, handler, connector, addr)
connector=new_s, addr=addr)
handler.connectionAccepted(new_conn) handler.connectionAccepted(new_conn)
except ConnectorTryAgainException: except ConnectorTryAgainException:
pass pass
...@@ -668,14 +668,15 @@ class ClientConnection(Connection): ...@@ -668,14 +668,15 @@ class ClientConnection(Connection):
connecting = True connecting = True
client = True client = True
def __init__(self, event_manager, handler, node, connector): def __init__(self, event_manager, handler, node):
addr = node.getAddress() addr = node.getAddress()
connector = self.ConnectorClass(addr)
Connection.__init__(self, event_manager, handler, connector, addr) Connection.__init__(self, event_manager, handler, connector, addr)
node.setConnection(self) node.setConnection(self)
handler.connectionStarted(self) handler.connectionStarted(self)
try: try:
try: try:
self.connector.makeClientConnection(addr) connector.makeClientConnection()
except ConnectorInProgressException: except ConnectorInProgressException:
event_manager.addWriter(self) event_manager.addWriter(self)
else: else:
......
...@@ -19,52 +19,51 @@ import errno ...@@ -19,52 +19,51 @@ import errno
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
# Read by calling getConnectorHandler. # Read by calling SocketConnector.__new__
connector_registry = {} connector_registry = {}
DEFAULT_CONNECTOR = 'SocketConnectorIPv4'
def registerConnectorHandler(connector_handler): def registerConnectorHandler(connector_handler):
connector_registry[connector_handler.__name__] = connector_handler connector_registry[connector_handler.af_type] = connector_handler
def getConnectorHandler(connector=None): class SocketConnector(object):
if connector is None:
connector = DEFAULT_CONNECTOR
if isinstance(connector, basestring):
connector_handler = connector_registry.get(connector)
else:
# Allow to directly provide a handler class without requiring to
# register it first.
connector_handler = connector
return connector_handler
class SocketConnector:
""" This class is a wrapper for a socket """ """ This class is a wrapper for a socket """
is_listening = False is_closed = is_server = None
remote_addr = None
is_closed = None
def __init__(self, s=None, accepted_from=None): def __new__(cls, addr, s=None):
self.accepted_from = accepted_from if s is None:
if accepted_from is not None: host, port = addr
self.remote_addr = accepted_from for af_type, cls in connector_registry.iteritems():
self.is_listening = False try :
self.is_closed = False socket.inet_pton(af_type, host)
break
except socket.error:
pass
else:
raise ValueError("Unknown type of host", host)
self = object.__new__(cls)
self.addr = cls._normAddress(addr)
if s is None: if s is None:
self.socket = socket.socket(self.af_type, socket.SOCK_STREAM) s = socket.socket(af_type, socket.SOCK_STREAM)
else: else:
self.socket = s self.is_server = True
self.socket_fd = self.socket.fileno() self.is_closed = False
self.socket = s
self.socket_fd = s.fileno()
# always use non-blocking sockets # always use non-blocking sockets
self.socket.setblocking(0) s.setblocking(0)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return self
def makeClientConnection(self, addr): # Threaded tests monkey-patch the following 2 operations.
self.is_closed = False _connect = lambda self, addr: self.socket.connect(addr)
self.remote_addr = addr _bind = lambda self, addr: self.socket.bind(addr)
def makeClientConnection(self):
assert self.is_closed is None
self.is_server = self.is_closed = False
try: try:
self.socket.connect(addr) self._connect(self.addr)
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
if err == errno.EINPROGRESS: if err == errno.EINPROGRESS:
raise ConnectorInProgressException raise ConnectorInProgressException
...@@ -73,12 +72,12 @@ class SocketConnector: ...@@ -73,12 +72,12 @@ class SocketConnector:
raise ConnectorException, 'makeClientConnection to %s failed:' \ raise ConnectorException, 'makeClientConnection to %s failed:' \
' %s:%s' % (addr, err, errmsg) ' %s:%s' % (addr, err, errmsg)
def makeListeningConnection(self, addr): def makeListeningConnection(self):
assert self.is_closed is None
self.is_closed = False self.is_closed = False
self.is_listening = True
try: try:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(addr) self._bind(self.addr)
self.socket.listen(5) self.socket.listen(5)
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
self.socket.close() self.socket.close()
...@@ -94,15 +93,22 @@ class SocketConnector: ...@@ -94,15 +93,22 @@ class SocketConnector:
# in epoll # in epoll
return self.socket_fd return self.socket_fd
def getNewConnection(self): @staticmethod
def _normAddress(addr):
return addr
def getAddress(self):
return self._normAddress(self.socket.getsockname())
def accept(self):
try: try:
(new_s, addr) = self._accept() s, addr = self.socket.accept()
new_s = self.__class__(new_s, accepted_from=addr) s = self.__class__(addr, s)
return (new_s, addr) return s, s.addr
except socket.error, (err, errmsg): except socket.error, (err, errmsg):
if err == errno.EAGAIN: if err == errno.EAGAIN:
raise ConnectorTryAgainException raise ConnectorTryAgainException
raise ConnectorException, 'getNewConnection failed: %s:%s' % \ raise ConnectorException, 'accept failed: %s:%s' % \
(err, errmsg) (err, errmsg)
def receive(self): def receive(self):
...@@ -139,14 +145,14 @@ class SocketConnector: ...@@ -139,14 +145,14 @@ class SocketConnector:
state = 'closed ' state = 'closed '
else: else:
state = 'opened ' state = 'opened '
if self.is_listening: if self.is_server is None:
state += 'listening' state += 'listening'
else: else:
if self.accepted_from is None: if self.is_server:
state += 'to '
else:
state += 'from ' state += 'from '
state += str(self.remote_addr) else:
state += 'to '
state += str(self.addr)
return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__, return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__,
id(self), '?' if self.is_closed else self.socket_fd, id(self), '?' if self.is_closed else self.socket_fd,
self.getAddress(), state) self.getAddress(), state)
...@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector): ...@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector):
" Wrapper for IPv4 sockets" " Wrapper for IPv4 sockets"
af_type = socket.AF_INET af_type = socket.AF_INET
def _accept(self):
return self.socket.accept()
def getAddress(self):
return self.socket.getsockname()
class SocketConnectorIPv6(SocketConnector): class SocketConnectorIPv6(SocketConnector):
" Wrapper for IPv6 sockets" " Wrapper for IPv6 sockets"
af_type = socket.AF_INET6 af_type = socket.AF_INET6
def _accept(self): @staticmethod
new_s, addr = self.socket.accept() def _normAddress(addr):
return new_s, addr[:2] return addr[:2]
def getAddress(self):
return self.socket.getsockname()[:2]
registerConnectorHandler(SocketConnectorIPv4) registerConnectorHandler(SocketConnectorIPv4)
registerConnectorHandler(SocketConnectorIPv6) registerConnectorHandler(SocketConnectorIPv6)
......
...@@ -19,12 +19,8 @@ import sys ...@@ -19,12 +19,8 @@ import sys
import traceback import traceback
from cStringIO import StringIO from cStringIO import StringIO
from struct import Struct from struct import Struct
try:
from .util import getAddressType
except ImportError:
pass
PROTOCOL_VERSION = 2 PROTOCOL_VERSION = 3
# Size restrictions. # Size restrictions.
MIN_PACKET_SIZE = 10 MIN_PACKET_SIZE = 10
...@@ -449,65 +445,6 @@ class PEnum(PStructItem): ...@@ -449,65 +445,6 @@ class PEnum(PStructItem):
enum = self._enum.__class__.__name__ enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code) raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PAddressIPGeneric(PStructItem):
def __init__(self, name, format):
PStructItem.__init__(self, name, format)
def encode(self, writer, address):
host, port = address
host = socket.inet_pton(self.af_type, host)
writer(self.pack(host, port))
def decode(self, reader):
data = reader(self.size)
address = self.unpack(data)
host, port = address
host = socket.inet_ntop(self.af_type, host)
return (host, port)
class PAddressIPv4(PAddressIPGeneric):
af_type = socket.AF_INET
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!4sH')
class PAddressIPv6(PAddressIPGeneric):
af_type = socket.AF_INET6
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!16sH')
class PAddress(PStructItem):
"""
An host address (IPv4/IPv6)
"""
address_format_dict = {
socket.AF_INET: PAddressIPv4('ipv4'),
socket.AF_INET6: PAddressIPv6('ipv6'),
}
def __init__(self, name):
PStructItem.__init__(self, name, '!L')
def _encode(self, writer, address):
if address is None:
writer(self.pack(INVALID_ADDRESS_TYPE))
return
af_type = getAddressType(address)
writer(self.pack(af_type))
encoder = self.address_format_dict[af_type]
encoder.encode(writer, address)
def _decode(self, reader):
af_type = self.unpack(reader(self.size))[0]
if af_type == INVALID_ADDRESS_TYPE:
return None
decoder = self.address_format_dict[af_type]
host, port = decoder.decode(reader)
return (host, port)
class PString(PStructItem): class PString(PStructItem):
""" """
A variable-length string A variable-length string
...@@ -523,6 +460,29 @@ class PString(PStructItem): ...@@ -523,6 +460,29 @@ class PString(PStructItem):
length = self.unpack(reader(self.size))[0] length = self.unpack(reader(self.size))[0]
return reader(length) return reader(length)
class PAddress(PString):
"""
An host address (IPv4/IPv6)
"""
def __init__(self, name):
PString.__init__(self, name)
self._port = Struct('!H')
def _encode(self, writer, address):
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else:
PString._encode(self, writer, '')
def _decode(self, reader):
host = PString._decode(self, reader)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
class PBoolean(PStructItem): class PBoolean(PStructItem):
""" """
A boolean value, encoded as a single byte A boolean value, encoded as a single byte
......
...@@ -23,11 +23,6 @@ from Queue import deque ...@@ -23,11 +23,6 @@ from Queue import deque
from struct import pack, unpack from struct import pack, unpack
from time import gmtime from time import gmtime
SOCKET_CONNECTORS_DICT = {
socket.AF_INET : 'SocketConnectorIPv4',
socket.AF_INET6: 'SocketConnectorIPv6',
}
TID_LOW_OVERFLOW = 2**32 TID_LOW_OVERFLOW = 2**32
TID_LOW_MAX = TID_LOW_OVERFLOW - 1 TID_LOW_MAX = TID_LOW_OVERFLOW - 1
SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW
...@@ -125,25 +120,6 @@ def makeChecksum(s): ...@@ -125,25 +120,6 @@ def makeChecksum(s):
return sha1(s).digest() return sha1(s).digest()
def getAddressType(address):
"Return the type (IPv4 or IPv6) of an ip"
(host, port) = address
for af_type in SOCKET_CONNECTORS_DICT:
try :
socket.inet_pton(af_type, host)
except:
continue
else:
break
else:
raise ValueError("Unknown type of host", host)
return af_type
def getConnectorFromAddress(address):
address_type = getAddressType(address)
return SOCKET_CONNECTORS_DICT[address_type]
def parseNodeAddress(address, port_opt=None): def parseNodeAddress(address, port_opt=None):
if address[:1] == '[': if address[:1] == '[':
(host, port) = address[1:].split(']') (host, port) = address[1:].split(']')
...@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None): ...@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None):
def parseMasterList(masters, except_node=None): def parseMasterList(masters, except_node=None):
assert masters, 'At least one master must be defined' assert masters, 'At least one master must be defined'
# load master node list
socket_connector = None
master_node_list = [] master_node_list = []
for node in masters.split(' '): for node in masters.split():
if not node:
continue
address = parseNodeAddress(node) address = parseNodeAddress(node)
if address != except_node:
if (address != except_node):
master_node_list.append(address) master_node_list.append(address)
return master_node_list
socket_connector_temp = getConnectorFromAddress(address)
if socket_connector is None:
socket_connector = socket_connector_temp
elif socket_connector != socket_connector_temp:
raise TypeError("Wrong connector type : you're trying to use "
"ipv6 and ipv4 simultaneously")
return master_node_list, socket_connector
class ReadBuffer(object): class ReadBuffer(object):
......
...@@ -18,7 +18,6 @@ import sys, weakref ...@@ -18,7 +18,6 @@ import sys, weakref
from time import time from time import time
from neo.lib import logging from neo.lib import logging
from neo.lib.connector import getConnectorHandler
from neo.lib.debug import register as registerLiveDebugger from neo.lib.debug import register as registerLiveDebugger
from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
...@@ -59,9 +58,7 @@ class Application(object): ...@@ -59,9 +58,7 @@ class Application(object):
self.autostart = config.getAutostart() self.autostart = config.getAutostart()
self.storage_readiness = set() self.storage_readiness = set()
master_addresses, connector_name = config.getMasters() for master_address in config.getMasters():
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
...@@ -102,7 +99,7 @@ class Application(object): ...@@ -102,7 +99,7 @@ class Application(object):
raise ValueError("upstream cluster name must be" raise ValueError("upstream cluster name must be"
" different from cluster name") " different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster, self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters()) config.getUpstreamMasters())
self.administration_handler = administration.AdministrationHandler( self.administration_handler = administration.AdministrationHandler(
self) self)
...@@ -141,8 +138,7 @@ class Application(object): ...@@ -141,8 +138,7 @@ class Application(object):
def _run(self): def _run(self):
"""Make sure that the status is sane and start a loop.""" """Make sure that the status is sane and start a loop."""
# Make a listening port. # Make a listening port.
self.listening_conn = ListeningConnection(self.em, None, self.listening_conn = ListeningConnection(self.em, None, self.server)
addr=self.server, connector=self.connector_handler())
# Start a normal operation. # Start a normal operation.
while self.cluster_state != ClusterStates.STOPPING: while self.cluster_state != ClusterStates.STOPPING:
...@@ -196,8 +192,7 @@ class Application(object): ...@@ -196,8 +192,7 @@ class Application(object):
ClientConnection(self.em, client_handler, ClientConnection(self.em, client_handler,
# XXX: Ugly, but the whole election code will be # XXX: Ugly, but the whole election code will be
# replaced soon # replaced soon
node=getByAddress(addr), getByAddress(addr))
connector=self.connector_handler())
self.unconnected_master_node_set.clear() self.unconnected_master_node_set.clear()
self.em.poll(1) self.em.poll(1)
except ElectionFailure, m: except ElectionFailure, m:
...@@ -381,9 +376,7 @@ class Application(object): ...@@ -381,9 +376,7 @@ class Application(object):
# Reconnect to primary master node. # Reconnect to primary master node.
primary_handler = secondary.PrimaryHandler(self) primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self.em, primary_handler, ClientConnection(self.em, primary_handler, self.primary_master_node)
node=self.primary_master_node,
connector=self.connector_handler())
# and another for the future incoming connections # and another for the future incoming connections
self.listening_conn.setHandler( self.listening_conn.setHandler(
......
...@@ -19,7 +19,6 @@ from bisect import bisect ...@@ -19,7 +19,6 @@ from bisect import bisect
from collections import defaultdict from collections import defaultdict
from neo.lib import logging from neo.lib import logging
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
...@@ -67,11 +66,10 @@ class BackupApplication(object): ...@@ -67,11 +66,10 @@ class BackupApplication(object):
pt = None pt = None
def __init__(self, app, name, master_addresses, connector_name): def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app) self.app = weakref.proxy(app)
self.name = name self.name = name
self.nm = NodeManager() self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses: for master_address in master_addresses:
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
...@@ -107,7 +105,7 @@ class BackupApplication(object): ...@@ -107,7 +105,7 @@ class BackupApplication(object):
break break
poll(1) poll(1)
node, conn, uuid, num_partitions, num_replicas = \ node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler) bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
......
...@@ -14,11 +14,9 @@ ...@@ -14,11 +14,9 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.connector import getConnectorHandler
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets
from neo.lib.util import getConnectorFromAddress
from neo.lib.node import NodeManager from neo.lib.node import NodeManager
from .handler import CommandEventHandler from .handler import CommandEventHandler
...@@ -31,8 +29,6 @@ class NeoCTL(object): ...@@ -31,8 +29,6 @@ class NeoCTL(object):
connected = False connected = False
def __init__(self, address): def __init__(self, address):
connector_name = getConnectorFromAddress(address)
self.connector_handler = getConnectorHandler(connector_name)
self.nm = nm = NodeManager() self.nm = nm = NodeManager()
self.server = nm.createAdmin(address=address) self.server = nm.createAdmin(address=address)
self.em = EventManager() self.em = EventManager()
...@@ -47,7 +43,7 @@ class NeoCTL(object): ...@@ -47,7 +43,7 @@ class NeoCTL(object):
def __getConnection(self): def __getConnection(self):
if not self.connected: if not self.connected:
self.connection = ClientConnection(self.em, self.handler, self.connection = ClientConnection(self.em, self.handler,
node=self.server, connector=self.connector_handler()) self.server)
while not self.connected: while not self.connected:
self.em.poll(1) self.em.poll(1)
if self.connection is None: if self.connection is None:
......
...@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager ...@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure from neo.lib.exception import OperationFailure, PrimaryFailure
from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
...@@ -54,9 +53,7 @@ class Application(object): ...@@ -54,9 +53,7 @@ class Application(object):
) )
# load master nodes # load master nodes
master_addresses, connector_name = config.getMasters() for master_address in config.getMasters():
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses :
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
# set the bind address # set the bind address
...@@ -177,8 +174,7 @@ class Application(object): ...@@ -177,8 +174,7 @@ class Application(object):
# Make a listening port # Make a listening port
handler = identification.IdentificationHandler(self) handler = identification.IdentificationHandler(self)
self.listening_conn = ListeningConnection(self.em, handler, self.listening_conn = ListeningConnection(self.em, handler, self.server)
addr=self.server, connector=self.connector_handler())
self.server = self.listening_conn.getAddress() self.server = self.listening_conn.getAddress()
# Connect to a primary master node, verify data, and # Connect to a primary master node, verify data, and
...@@ -234,7 +230,7 @@ class Application(object): ...@@ -234,7 +230,7 @@ class Application(object):
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, bootstrap = BootstrapManager(self, self.name,
NodeTypes.STORAGE, self.uuid, self.server) NodeTypes.STORAGE, self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler) data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data (node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node self.master_node = node
self.master_conn = conn self.master_conn = conn
......
...@@ -46,7 +46,7 @@ class Checker(object): ...@@ -46,7 +46,7 @@ class Checker(object):
conn.asClient() conn.asClient()
else: else:
conn = ClientConnection(app.em, StorageOperationHandler(app), conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler()) node)
conn.ask(Packets.RequestIdentification( conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name)) NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
......
...@@ -254,8 +254,7 @@ class Replicator(object): ...@@ -254,8 +254,7 @@ class Replicator(object):
self.fetchTransactions() self.fetchTransactions()
else: else:
assert name or node.getUUID() != app.uuid, "loopback connection" assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app), conn = ClientConnection(app.em, StorageOperationHandler(app), node)
node=node, connector=app.connector_handler())
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name)) None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected(): if previous_node is not None and previous_node.isConnected():
......
...@@ -30,7 +30,6 @@ from functools import wraps ...@@ -30,7 +30,6 @@ from functools import wraps
from mock import Mock from mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import getAddressType
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({ return Mock({
'getCluster': cluster, 'getCluster': cluster,
'getBind': masters[0], 'getBind': masters[0],
'getMasters': (masters, getAddressType(( 'getMasters': masters,
self.local_ip, 0))),
'getReplicas': replicas, 'getReplicas': replicas,
'getPartitions': partitions, 'getPartitions': partitions,
'getUUID': uuid, 'getUUID': uuid,
...@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({ return Mock({
'getCluster': cluster, 'getCluster': cluster,
'getBind': (masters[0], 10020 + index), 'getBind': (masters[0], 10020 + index),
'getMasters': (masters, getAddressType(( 'getMasters': masters,
self.local_ip, 0))),
'getDatabase': db, 'getDatabase': db,
'getUUID': uuid, 'getUUID': uuid,
'getReset': False, 'getReset': False,
...@@ -554,29 +551,5 @@ class Patch(object): ...@@ -554,29 +551,5 @@ class Patch(object):
self.__del__() self.__del__()
connector_cpt = 0
class DoNothingConnector(Mock):
def __init__(self, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
Mock.__init__(self)
def getAddress(self):
return self.addr
def makeClientConnection(self, addr):
self.addr = addr
def makeListeningConnection(self, addr):
self.addr = addr
def getDescriptor(self):
return self.desc
__builtin__.pdb = lambda depth=0: \ __builtin__.pdb = lambda depth=0: \
debug.getPdb().set_trace(sys._getframe(depth+1)) debug.getPdb().set_trace(sys._getframe(depth+1))
...@@ -25,7 +25,7 @@ from neo.client.cache import test as testCache ...@@ -25,7 +25,7 @@ from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.lib.protocol import NodeTypes, Packets, Errors, \ from neo.lib.protocol import NodeTypes, Packets, Errors, \
INVALID_PARTITION, UUID_NAMESPACES INVALID_PARTITION, UUID_NAMESPACES
from neo.lib.util import makeChecksum, SOCKET_CONNECTORS_DICT from neo.lib.util import makeChecksum
import time import time
class Dispatcher(object): class Dispatcher(object):
...@@ -95,10 +95,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -95,10 +95,9 @@ class ClientApplicationTests(NeoUnitTestBase):
return txn_context return txn_context
def getApp(self, master_nodes=None, name='test', **kw): def getApp(self, master_nodes=None, name='test', **kw):
connector = SOCKET_CONNECTORS_DICT[ADDRESS_TYPE]
if master_nodes is None: if master_nodes is None:
master_nodes = '%s:10010' % buildUrlFromString(self.local_ip) master_nodes = '%s:10010' % buildUrlFromString(self.local_ip)
app = Application(master_nodes, name, connector, **kw) app = Application(master_nodes, name, **kw)
self._to_stop_list.append(app) self._to_stop_list.append(app)
app.dispatcher = Mock({ }) app.dispatcher = Mock({ })
return app return app
...@@ -750,7 +749,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -750,7 +749,6 @@ class ClientApplicationTests(NeoUnitTestBase):
# the third will not be ready # the third will not be ready
# after the third, the partition table will be operational # after the third, the partition table will be operational
# (as if it was connected to the primary master node) # (as if it was connected to the primary master node)
from .. import DoNothingConnector
# will raise IndexError at the third iteration # will raise IndexError at the third iteration
app = self.getApp('127.0.0.1:10010 127.0.0.1:10011') app = self.getApp('127.0.0.1:10010 127.0.0.1:10011')
# TODO: test more connection failure cases # TODO: test more connection failure cases
...@@ -797,7 +795,6 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -797,7 +795,6 @@ class ClientApplicationTests(NeoUnitTestBase):
app.nm.getByAddress(conn.getAddress())._connection = None app.nm.getByAddress(conn.getAddress())._connection = None
app._ask = _ask_base app._ask = _ask_base
# faked environnement # faked environnement
app.connector_handler = DoNothingConnector
app.em = Mock({'getConnectionList': []}) app.em = Mock({'getConnectionList': []})
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
app.master_conn = app._connectToPrimaryNode() app.master_conn = app._connectToPrimaryNode()
......
...@@ -17,17 +17,43 @@ ...@@ -17,17 +17,43 @@
import unittest import unittest
from time import time from time import time
from mock import Mock from mock import Mock
from neo.lib import connection from neo.lib import connection, logging
from neo.lib.connection import ListeningConnection, Connection, \ from neo.lib.connection import BaseConnection, ListeningConnection, \
ClientConnection, ServerConnection, MTClientConnection, \ Connection, ClientConnection, ServerConnection, MTClientConnection, \
HandlerSwitcher, CRITICAL_TIMEOUT HandlerSwitcher, CRITICAL_TIMEOUT
from neo.lib.connector import getConnectorHandler, registerConnectorHandler from neo.lib.connector import registerConnectorHandler
from . import DoNothingConnector
from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \ from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \
ConnectorInProgressException, ConnectorConnectionRefusedException ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT
from . import NeoUnitTestBase from . import NeoUnitTestBase, Patch
connector_cpt = 0
class DummyConnector(Mock):
def __init__(self, addr, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
self.addr = addr
Mock.__init__(self)
def getAddress(self):
return self.addr
def getDescriptor(self):
return self.desc
accept = getError = makeClientConnection = makeListeningConnection = \
receive = send = lambda *args, **kw: None
dummy_connector = Patch(BaseConnection,
ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw))
class ConnectionTests(NeoUnitTestBase): class ConnectionTests(NeoUnitTestBase):
...@@ -41,25 +67,23 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -41,25 +67,23 @@ class ConnectionTests(NeoUnitTestBase):
connection.connect_limit = 0 connection.connect_limit = 0
def _makeListeningConnection(self, addr): def _makeListeningConnection(self, addr):
# create instance after monkey patches with dummy_connector:
self.connector = DoNothingConnector() conn = ListeningConnection(self.em, self.handler, addr)
return ListeningConnection(event_manager=self.em, handler=self.handler, self.connector = conn.connector
connector=self.connector, addr=addr) return conn
def _makeConnection(self): def _makeConnection(self):
self.connector = DoNothingConnector() addr = self.address
return Connection(event_manager=self.em, handler=self.handler, self.connector = DummyConnector(addr)
connector=self.connector, addr=self.address) return Connection(self.em, self.handler, self.connector, addr)
def _makeClientConnection(self): def _makeClientConnection(self):
self.connector = DoNothingConnector() with dummy_connector:
return ClientConnection(event_manager=self.em, handler=self.handler, conn = ClientConnection(self.em, self.handler, self.node)
connector=self.connector, node=self.node) self.connector = conn.connector
return conn
def _makeServerConnection(self): _makeServerConnection = _makeConnection
self.connector = DoNothingConnector()
return ServerConnection(event_manager=self.em, handler=self.handler,
connector=self.connector, addr=self.address)
def _checkRegistered(self, n=1): def _checkRegistered(self, n=1):
self.assertEqual(len(self.em.mockGetNamedCalls("register")), n) self.assertEqual(len(self.em.mockGetNamedCalls("register")), n)
...@@ -82,8 +106,8 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -82,8 +106,8 @@ class ConnectionTests(NeoUnitTestBase):
def _checkClose(self, n=1): def _checkClose(self, n=1):
self.assertEqual(len(self.connector.mockGetNamedCalls("close")), n) self.assertEqual(len(self.connector.mockGetNamedCalls("close")), n)
def _checkGetNewConnection(self, n=1): def _checkAccept(self, n=1):
calls = self.connector.mockGetNamedCalls('getNewConnection') calls = self.connector.mockGetNamedCalls('accept')
self.assertEqual(len(calls), n) self.assertEqual(len(calls), n)
def _checkSend(self, n=1, data=None): def _checkSend(self, n=1, data=None):
...@@ -120,7 +144,6 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -120,7 +144,6 @@ class ConnectionTests(NeoUnitTestBase):
def _checkMakeClientConnection(self, n=1): def _checkMakeClientConnection(self, n=1):
calls = self.connector.mockGetNamedCalls("makeClientConnection") calls = self.connector.mockGetNamedCalls("makeClientConnection")
self.assertEqual(len(calls), n) self.assertEqual(len(calls), n)
self.assertEqual(calls[n-1].getParam(0), self.address)
def _checkPacketReceived(self, n=1): def _checkPacketReceived(self, n=1):
calls = self.handler.mockGetNamedCalls('packetReceived') calls = self.handler.mockGetNamedCalls('packetReceived')
...@@ -140,28 +163,17 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -140,28 +163,17 @@ class ConnectionTests(NeoUnitTestBase):
def _checkWriteBuf(self, bc, data): def _checkWriteBuf(self, bc, data):
self.assertEqual(''.join(bc.write_buf), data) self.assertEqual(''.join(bc.write_buf), data)
def test_01_BaseConnection1(self): def test_01_BaseConnection(self):
# init with connector
registerConnectorHandler(DoNothingConnector)
connector = getConnectorHandler("DoNothingConnector")()
self.assertFalse(connector is None)
bc = self._makeConnection()
self.assertFalse(bc.connector is None)
self._checkRegistered(1)
def test_01_BaseConnection2(self):
# init with address # init with address
bc = self._makeConnection() bc = self._makeConnection()
self.assertEqual(bc.getAddress(), self.address) self.assertEqual(bc.getAddress(), self.address)
self.assertIsNot(bc.connector, None)
self._checkRegistered(1) self._checkRegistered(1)
def test_02_ListeningConnection1(self): def test_02_ListeningConnection1(self):
# test init part # test init part
def getNewConnection(self):
return self, ('', 0)
DoNothingConnector.getNewConnection = getNewConnection
addr = ("127.0.0.7", 93413) addr = ("127.0.0.7", 93413)
try: with Patch(DummyConnector, accept=lambda orig, self: (self, ('', 0))):
bc = self._makeListeningConnection(addr=addr) bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr) self.assertEqual(bc.getAddress(), addr)
self._checkRegistered() self._checkRegistered()
...@@ -169,18 +181,15 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -169,18 +181,15 @@ class ConnectionTests(NeoUnitTestBase):
self._checkMakeListeningConnection() self._checkMakeListeningConnection()
# test readable # test readable
bc.readable() bc.readable()
self._checkGetNewConnection() self._checkAccept()
self._checkConnectionAccepted() self._checkConnectionAccepted()
finally:
del DoNothingConnector.getNewConnection
def test_02_ListeningConnection2(self): def test_02_ListeningConnection2(self):
# test with exception raise when getting new connection # test with exception raise when getting new connection
def getNewConnection(self): def accept(orig, self):
raise ConnectorTryAgainException raise ConnectorTryAgainException
DoNothingConnector.getNewConnection = getNewConnection
addr = ("127.0.0.7", 93413) addr = ("127.0.0.7", 93413)
try: with Patch(DummyConnector, accept=accept):
bc = self._makeListeningConnection(addr=addr) bc = self._makeListeningConnection(addr=addr)
self.assertEqual(bc.getAddress(), addr) self.assertEqual(bc.getAddress(), addr)
self._checkRegistered() self._checkRegistered()
...@@ -188,10 +197,8 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -188,10 +197,8 @@ class ConnectionTests(NeoUnitTestBase):
self._checkMakeListeningConnection() self._checkMakeListeningConnection()
# test readable # test readable
bc.readable() bc.readable()
self._checkGetNewConnection(1) self._checkAccept(1)
self._checkConnectionAccepted(0) self._checkConnectionAccepted(0)
finally:
del DoNothingConnector.getNewConnection
def test_03_Connection(self): def test_03_Connection(self):
bc = self._makeConnection() bc = self._makeConnection()
...@@ -229,38 +236,29 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -229,38 +236,29 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_recv1(self): def test_Connection_recv1(self):
# patch receive method to return data # patch receive method to return data
def receive(self): with Patch(DummyConnector, receive=lambda orig, self: "testdata"):
return "testdata"
DoNothingConnector.receive = receive
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
bc._recv() bc._recv()
self._checkReadBuf(bc, 'testdata') self._checkReadBuf(bc, 'testdata')
finally:
del DoNothingConnector.receive
def test_Connection_recv2(self): def test_Connection_recv2(self):
# patch receive method to raise try again # patch receive method to raise try again
def receive(self): def receive(orig, self):
raise ConnectorTryAgainException raise ConnectorTryAgainException
DoNothingConnector.receive = receive with Patch(DummyConnector, receive=receive):
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
bc._recv() bc._recv()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.receive
def test_Connection_recv3(self): def test_Connection_recv3(self):
# patch receive method to raise ConnectorConnectionRefusedException # patch receive method to raise ConnectorConnectionRefusedException
def receive(self): def receive(orig, self):
raise ConnectorConnectionRefusedException raise ConnectorConnectionRefusedException
DoNothingConnector.receive = receive with Patch(DummyConnector, receive=receive):
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
# fake client connection instance with connecting attribute # fake client connection instance with connecting attribute
...@@ -269,23 +267,18 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -269,23 +267,18 @@ class ConnectionTests(NeoUnitTestBase):
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkConnectionFailed(1) self._checkConnectionFailed(1)
self._checkUnregistered(1) self._checkUnregistered(1)
finally:
del DoNothingConnector.receive
def test_Connection_recv4(self): def test_Connection_recv4(self):
# patch receive method to raise any other connector error # patch receive method to raise any other connector error
def receive(self): def receive(orig, self):
raise ConnectorException raise ConnectorException
DoNothingConnector.receive = receive with Patch(DummyConnector, receive=receive):
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self.assertRaises(ConnectorException, bc._recv) self.assertRaises(ConnectorException, bc._recv)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
self._checkConnectionClosed(1) self._checkConnectionClosed(1)
self._checkUnregistered(1) self._checkUnregistered(1)
finally:
del DoNothingConnector.receive
def test_Connection_send1(self): def test_Connection_send1(self):
# no data, nothing done # no data, nothing done
...@@ -299,10 +292,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -299,10 +292,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_send2(self): def test_Connection_send2(self):
# send all data # send all data
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
return len(data)
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -311,15 +301,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -311,15 +301,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send3(self): def test_Connection_send3(self):
# send part of the data # send part of the data
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
return len(data)/2
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -328,15 +313,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -328,15 +313,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'data') self._checkWriteBuf(bc, 'data')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send4(self): def test_Connection_send4(self):
# send multiple packet # send multiple packet
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
return len(data)
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"] bc.write_buf = ["testdata", "second", "third"]
...@@ -345,15 +325,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -345,15 +325,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send5(self): def test_Connection_send5(self):
# send part of multiple packet # send part of multiple packet
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
return len(data)/2
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"] bc.write_buf = ["testdata", "second", "third"]
...@@ -362,15 +337,12 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -362,15 +337,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'econdthird') self._checkWriteBuf(bc, 'econdthird')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send6(self): def test_Connection_send6(self):
# raise try again # raise try again
def send(self, data): def send(orig, self, data):
raise ConnectorTryAgainException raise ConnectorTryAgainException
DoNothingConnector.send = send with Patch(DummyConnector, send=send):
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"] bc.write_buf = ["testdata", "second", "third"]
...@@ -379,15 +351,12 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -379,15 +351,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, 'testdatasecondthird') self._checkWriteBuf(bc, 'testdatasecondthird')
self._checkConnectionClosed(0) self._checkConnectionClosed(0)
self._checkUnregistered(0) self._checkUnregistered(0)
finally:
del DoNothingConnector.send
def test_Connection_send7(self): def test_Connection_send7(self):
# raise other error # raise other error
def send(self, data): def send(orig, self, data):
raise ConnectorException raise ConnectorException
DoNothingConnector.send = send with Patch(DummyConnector, send=send):
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata", "second", "third"] bc.write_buf = ["testdata", "second", "third"]
...@@ -397,8 +366,6 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -397,8 +366,6 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
self._checkConnectionClosed(1) self._checkConnectionClosed(1)
self._checkUnregistered(1) self._checkUnregistered(1)
finally:
del DoNothingConnector.send
def test_07_Connection_addPacket(self): def test_07_Connection_addPacket(self):
# new packet # new packet
...@@ -499,10 +466,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -499,10 +466,7 @@ class ConnectionTests(NeoUnitTestBase):
def test_Connection_writable1(self): def test_Connection_writable1(self):
# with pending operation after send # with pending operation after send
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
return len(data)/2
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -520,15 +484,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -520,15 +484,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(0) self._checkWriterRemoved(0)
self._checkReaderRemoved(0) self._checkReaderRemoved(0)
self._checkClose(0) self._checkClose(0)
finally:
del DoNothingConnector.send
def test_Connection_writable2(self): def test_Connection_writable2(self):
# without pending operation after send # without pending operation after send
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
return len(data)
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -546,15 +505,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -546,15 +505,10 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(1) self._checkWriterRemoved(1)
self._checkReaderRemoved(0) self._checkReaderRemoved(0)
self._checkClose(0) self._checkClose(0)
finally:
del DoNothingConnector.send
def test_Connection_writable3(self): def test_Connection_writable3(self):
# without pending operation after send and aborted set to true # without pending operation after send and aborted set to true
def send(self, data): with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
return len(data)
DoNothingConnector.send = send
try:
bc = self._makeConnection() bc = self._makeConnection()
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -571,18 +525,15 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -571,18 +525,15 @@ class ConnectionTests(NeoUnitTestBase):
# nothing else pending, so writer has been removed # nothing else pending, so writer has been removed
self.assertFalse(bc.pending()) self.assertFalse(bc.pending())
self._checkClose(1) self._checkClose(1)
finally:
del DoNothingConnector.send
def test_Connection_readable(self): def test_Connection_readable(self):
# With aborted set to false # With aborted set to false
# patch receive method to return data # patch receive method to return data
def receive(self): def receive(orig, self):
p = Packets.AnswerPrimary(self.getNewUUID(None)) p = Packets.AnswerPrimary(self.getNewUUID(None))
p.setId(1) p.setId(1)
return ''.join(p.encode()) return ''.join(p.encode())
DoNothingConnector.receive = receive with Patch(DummyConnector, receive=receive):
try:
bc = self._makeConnection() bc = self._makeConnection()
bc._queue = Mock({'__len__': 0}) bc._queue = Mock({'__len__': 0})
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
...@@ -602,8 +553,6 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -602,8 +553,6 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(0) self._checkWriterRemoved(0)
self._checkReaderRemoved(0) self._checkReaderRemoved(0)
self._checkClose(0) self._checkClose(0)
finally:
del DoNothingConnector.receive
def test_ClientConnection_init1(self): def test_ClientConnection_init1(self):
# create a good client connection # create a good client connection
...@@ -624,14 +573,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -624,14 +573,10 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_init2(self): def test_ClientConnection_init2(self):
# raise connection in progress # raise connection in progress
makeClientConnection_org = DoNothingConnector.makeClientConnection def makeClientConnection(orig, self):
def makeClientConnection(self, *args, **kw):
raise ConnectorInProgressException raise ConnectorInProgressException
DoNothingConnector.makeClientConnection = makeClientConnection with Patch(DummyConnector, makeClientConnection=makeClientConnection):
try:
bc = self._makeClientConnection() bc = self._makeClientConnection()
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
# check connector created and connection initialize # check connector created and connection initialize
self.assertTrue(bc.connecting) self.assertTrue(bc.connecting)
self.assertFalse(bc.isServer()) self.assertFalse(bc.isServer())
...@@ -648,14 +593,10 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -648,14 +593,10 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_init3(self): def test_ClientConnection_init3(self):
# raise another error, connection must fail # raise another error, connection must fail
makeClientConnection_org = DoNothingConnector.makeClientConnection def makeClientConnection(orig, self):
def makeClientConnection(self, *args, **kw):
raise ConnectorException raise ConnectorException
DoNothingConnector.makeClientConnection = makeClientConnection with Patch(DummyConnector, makeClientConnection=makeClientConnection):
try:
self.assertRaises(ConnectorException, self._makeClientConnection) self.assertRaises(ConnectorException, self._makeClientConnection)
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
# since the exception was raised, the connection is not created # since the exception was raised, the connection is not created
# check call to handler # check call to handler
self._checkConnectionStarted(1) self._checkConnectionStarted(1)
...@@ -667,18 +608,11 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -667,18 +608,11 @@ class ConnectionTests(NeoUnitTestBase):
def test_ClientConnection_writable1(self): def test_ClientConnection_writable1(self):
# with a non connecting connection, will call parent's method # with a non connecting connection, will call parent's method
def makeClientConnection(self, *args, **kw): with Patch(DummyConnector, send=lambda orig, self, data: len(data)), \
return "OK" Patch(DummyConnector,
def send(self, data): makeClientConnection=lambda orig, self: "OK") as p:
return len(data) bc = self._makeClientConnection()
makeClientConnection_org = DoNothingConnector.makeClientConnection p.revert()
DoNothingConnector.send = send
DoNothingConnector.makeClientConnection = makeClientConnection
try:
try:
bc = self._makeClientConnection()
finally:
DoNothingConnector.makeClientConnection = makeClientConnection_org
# check connector created and connection initialize # check connector created and connection initialize
self.assertFalse(bc.connecting) self.assertFalse(bc.connecting)
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
...@@ -701,19 +635,12 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -701,19 +635,12 @@ class ConnectionTests(NeoUnitTestBase):
self._checkWriterRemoved(1) self._checkWriterRemoved(1)
self._checkReaderRemoved(0) self._checkReaderRemoved(0)
self._checkClose(0) self._checkClose(0)
finally:
del DoNothingConnector.send
def test_ClientConnection_writable2(self): def test_ClientConnection_writable2(self):
# with a connecting connection, must not call parent's method # with a connecting connection, must not call parent's method
# with errors, close connection # with errors, close connection
def getError(self): with Patch(DummyConnector, getError=lambda orig, self: True):
return True
DoNothingConnector.getError = getError
try:
bc = self._makeClientConnection() bc = self._makeClientConnection()
finally:
del DoNothingConnector.getError
# check connector created and connection initialize # check connector created and connection initialize
self._checkWriteBuf(bc, '') self._checkWriteBuf(bc, '')
bc.write_buf = ["testdata"] bc.write_buf = ["testdata"]
...@@ -836,10 +763,11 @@ class MTConnectionTests(ConnectionTests): ...@@ -836,10 +763,11 @@ class MTConnectionTests(ConnectionTests):
self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'}) self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
def _makeClientConnection(self): def _makeClientConnection(self):
self.connector = DoNothingConnector() with dummy_connector:
return MTClientConnection(event_manager=self.em, handler=self.handler, conn = MTClientConnection(self.em, self.handler, self.node,
connector=self.connector, node=self.node, dispatcher=self.dispatcher)
dispatcher=self.dispatcher) self.connector = conn.connector
return conn
def test_MTClientConnectionQueueParameter(self): def test_MTClientConnectionQueueParameter(self):
ask = self._makeClientConnection().ask ask = self._makeClientConnection().ask
......
...@@ -15,35 +15,12 @@ ...@@ -15,35 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase, IP_VERSION_FORMAT_DICT from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, getAddressType, parseNodeAddress, \ from neo.lib.util import ReadBuffer, parseNodeAddress
getConnectorFromAddress, SOCKET_CONNECTORS_DICT
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
def test_getConnectorFromAddress(self):
""" Connector name must correspond to address type """
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET])
connector = getConnectorFromAddress((
IP_VERSION_FORMAT_DICT[socket.AF_INET6], 0))
self.assertEqual(connector, SOCKET_CONNECTORS_DICT[socket.AF_INET6])
self.assertRaises(ValueError, getConnectorFromAddress, ('', 0))
self.assertRaises(ValueError, getConnectorFromAddress, ('test', 0))
def test_getAddressType(self):
""" Get the type on an IP Address """
self.assertRaises(ValueError, getAddressType, ('', 0))
address_type = getAddressType(('::1', 0))
self.assertEqual(address_type, socket.AF_INET6)
address_type = getAddressType(('0.0.0.0', 0))
self.assertEqual(address_type, socket.AF_INET)
address_type = getAddressType(('127.0.0.1', 0))
self.assertEqual(address_type, socket.AF_INET)
def test_parseNodeAddress(self): def test_parseNodeAddress(self):
""" Parsing of addesses """ """ Parsing of addesses """
def test(parsed, *args): def test(parsed, *args):
......
...@@ -35,7 +35,7 @@ from neo.lib.connector import SocketConnector, \ ...@@ -35,7 +35,7 @@ from neo.lib.connector import SocketConnector, \
ConnectorConnectionRefusedException, ConnectorTryAgainException ConnectorConnectionRefusedException, ConnectorTryAgainException
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64 from neo.lib.util import parseMasterList, p64
from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, Patch, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
...@@ -166,7 +166,7 @@ class SerializedEventManager(EventManager): ...@@ -166,7 +166,7 @@ class SerializedEventManager(EventManager):
class Node(object): class Node(object):
def getConnectionList(self, *peers): def getConnectionList(self, *peers):
addr = lambda c: c and (c.accepted_from or c.getAddress()) addr = lambda c: c and (c.addr if c.is_server else c.getAddress())
addr_set = {addr(c.connector) for peer in peers addr_set = {addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues() for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection)} if isinstance(c, Connection)}
...@@ -467,10 +467,8 @@ class ConnectionFilter(object): ...@@ -467,10 +467,8 @@ class ConnectionFilter(object):
class NEOCluster(object): class NEOCluster(object):
BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout) BaseConnection_getTimeout = staticmethod(BaseConnection.getTimeout)
SocketConnector_makeClientConnection = staticmethod( SocketConnector_bind = staticmethod(SocketConnector._bind)
SocketConnector.makeClientConnection) SocketConnector_connect = staticmethod(SocketConnector._connect)
SocketConnector_makeListeningConnection = staticmethod(
SocketConnector.makeListeningConnection)
SocketConnector_receive = staticmethod(SocketConnector.receive) SocketConnector_receive = staticmethod(SocketConnector.receive)
SocketConnector_send = staticmethod(SocketConnector.send) SocketConnector_send = staticmethod(SocketConnector.send)
_patch_count = 0 _patch_count = 0
...@@ -489,12 +487,6 @@ class NEOCluster(object): ...@@ -489,12 +487,6 @@ class NEOCluster(object):
cls._patch_count += 1 cls._patch_count += 1
if cls._patch_count > 1: if cls._patch_count > 1:
return return
def makeClientConnection(self, addr):
real_addr = ServerNode.resolv(addr)
try:
return cls.SocketConnector_makeClientConnection(self, real_addr)
finally:
self.remote_addr = addr
def send(self, msg): def send(self, msg):
result = cls.SocketConnector_send(self, msg) result = cls.SocketConnector_send(self, msg)
if type(Serialized.pending) is not frozenset: if type(Serialized.pending) is not frozenset:
...@@ -518,9 +510,10 @@ class NEOCluster(object): ...@@ -518,9 +510,10 @@ class NEOCluster(object):
# safely started even if the cluster isn't. # safely started even if the cluster isn't.
bootstrap.sleep = lambda seconds: None bootstrap.sleep = lambda seconds: None
BaseConnection.getTimeout = lambda self: None BaseConnection.getTimeout = lambda self: None
SocketConnector.makeClientConnection = makeClientConnection SocketConnector._bind = lambda self, addr: \
SocketConnector.makeListeningConnection = lambda self, addr: \ cls.SocketConnector_bind(self, BIND)
cls.SocketConnector_makeListeningConnection(self, BIND) SocketConnector._connect = lambda self, addr: \
cls.SocketConnector_connect(self, ServerNode.resolv(addr))
SocketConnector.receive = receive SocketConnector.receive = receive
SocketConnector.send = send SocketConnector.send = send
Serialized.init() Serialized.init()
...@@ -534,10 +527,8 @@ class NEOCluster(object): ...@@ -534,10 +527,8 @@ class NEOCluster(object):
return return
bootstrap.sleep = time.sleep bootstrap.sleep = time.sleep
BaseConnection.getTimeout = cls.BaseConnection_getTimeout BaseConnection.getTimeout = cls.BaseConnection_getTimeout
SocketConnector.makeClientConnection = \ SocketConnector._bind = cls.SocketConnector_bind
cls.SocketConnector_makeClientConnection SocketConnector._connect = cls.SocketConnector_connect
SocketConnector.makeListeningConnection = \
cls.SocketConnector_makeListeningConnection
SocketConnector.receive = cls.SocketConnector_receive SocketConnector.receive = cls.SocketConnector_receive
SocketConnector.send = cls.SocketConnector_send SocketConnector.send = cls.SocketConnector_send
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment