Commit ed50edca authored by Julien Muchembled's avatar Julien Muchembled

Simplify API to establish connections and accept mix of IPv4/IPv6

parent c2c97752
......@@ -21,7 +21,6 @@ from neo.lib.connection import ListeningConnection
from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \
MasterRequestEventHandler
from neo.lib.connector import getConnectorHandler
from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, \
......@@ -39,8 +38,7 @@ class Application(object):
self.name = config.getCluster()
self.server = config.getBind()
self.master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
self.master_addresses = config.getMasters()
logging.debug('IP address is %s, port is %d', *self.server)
# The partition table is initialized after getting the number of
......@@ -87,8 +85,7 @@ class Application(object):
# Make a listening port.
handler = AdminEventHandler(self)
self.listening_conn = ListeningConnection(self.em, handler,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, handler, self.server)
while self.cluster_state != ClusterStates.STOPPING:
self.connectToPrimary()
......@@ -120,7 +117,7 @@ class Application(object):
# search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name, NodeTypes.ADMIN,
self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler)
data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
nm.update([(node.getType(), node.getAddress(), node.getUUID(),
NodeStates.RUNNING)])
......
......@@ -36,7 +36,6 @@ from neo.lib.util import makeChecksum, dump
from neo.lib.locking import Lock
from neo.lib.connection import MTClientConnection, ConnectionClosed
from neo.lib.node import NodeManager
from neo.lib.connector import getConnectorHandler
from .exception import NEOStorageError, NEOStorageCreationUndoneError
from .exception import NEOStorageNotFoundError
from .handlers import storage, master
......@@ -80,8 +79,6 @@ class Application(object):
# Internal Attributes common to all thread
self._db = None
self.name = name
master_addresses, connector_name = parseMasterList(master_nodes)
self.connector_handler = getConnectorHandler(connector_name)
self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager(dynamic_master_list)
self.cp = ConnectionPool(self)
......@@ -90,7 +87,7 @@ class Application(object):
self.trying_master_node = None
# load master node list
for address in master_addresses:
for address in parseMasterList(master_nodes):
self.nm.createMaster(address=address)
# no self-assigned UUID, primary master will supply us one
......@@ -290,7 +287,6 @@ class Application(object):
conn = MTClientConnection(self.em,
self.notifications_handler,
node=self.trying_master_node,
connector=self.connector_handler(),
dispatcher=self.dispatcher)
# Query for primary master node
if conn.getConnector() is None:
......
......@@ -54,7 +54,7 @@ class ConnectionPool(object):
app = self.app
logging.debug('trying to connect to %s - %s', node, node.getState())
conn = MTClientConnection(app.em, app.storage_event_handler, node,
connector=app.connector_handler(), dispatcher=app.dispatcher)
dispatcher=app.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT,
app.uuid, None, app.name)
try:
......
......@@ -116,7 +116,7 @@ class BootstrapManager(EventHandler):
logging.info('Got a new UUID: %s', uuid_str(self.uuid))
self.accepted = True
def getPrimaryConnection(self, connector_handler):
def getPrimaryConnection(self):
"""
Primary lookup/connection process.
Returns when the connection is made.
......@@ -140,8 +140,7 @@ class BootstrapManager(EventHandler):
sleep(1)
if conn is None:
# open the connection
conn = ClientConnection(em, self, self.current,
connector_handler())
conn = ClientConnection(em, self, self.current)
# still processing
em.poll(1)
return (self.current, conn, self.uuid, self.num_partitions,
......
......@@ -206,6 +206,7 @@ class BaseConnection(object):
Timeouts in HandlerSwitcher are only there to prioritize some packets.
"""
from .connector import SocketConnector as ConnectorClass
KEEP_ALIVE = 60
def __init__(self, event_manager, handler, connector, addr=None):
......@@ -318,19 +319,18 @@ attributeTracker.track(BaseConnection)
class ListeningConnection(BaseConnection):
"""A listen connection."""
def __init__(self, event_manager, handler, addr, connector, **kw):
def __init__(self, event_manager, handler, addr):
logging.debug('listening to %s:%d', *addr)
BaseConnection.__init__(self, event_manager, handler,
addr=addr, connector=connector)
self.connector.makeListeningConnection(addr)
connector = self.ConnectorClass(addr)
BaseConnection.__init__(self, event_manager, handler, connector, addr)
connector.makeListeningConnection()
def readable(self):
try:
new_s, addr = self.connector.getNewConnection()
connector, addr = self.connector.accept()
logging.debug('accepted a connection from %s:%d', *addr)
handler = self.getHandler()
new_conn = ServerConnection(self.em, handler,
connector=new_s, addr=addr)
new_conn = ServerConnection(self.em, handler, connector, addr)
handler.connectionAccepted(new_conn)
except ConnectorTryAgainException:
pass
......@@ -668,14 +668,15 @@ class ClientConnection(Connection):
connecting = True
client = True
def __init__(self, event_manager, handler, node, connector):
def __init__(self, event_manager, handler, node):
addr = node.getAddress()
connector = self.ConnectorClass(addr)
Connection.__init__(self, event_manager, handler, connector, addr)
node.setConnection(self)
handler.connectionStarted(self)
try:
try:
self.connector.makeClientConnection(addr)
connector.makeClientConnection()
except ConnectorInProgressException:
event_manager.addWriter(self)
else:
......
......@@ -19,52 +19,51 @@ import errno
# Global connector registry.
# Fill by calling registerConnectorHandler.
# Read by calling getConnectorHandler.
# Read by calling SocketConnector.__new__
connector_registry = {}
DEFAULT_CONNECTOR = 'SocketConnectorIPv4'
def registerConnectorHandler(connector_handler):
connector_registry[connector_handler.__name__] = connector_handler
def getConnectorHandler(connector=None):
if connector is None:
connector = DEFAULT_CONNECTOR
if isinstance(connector, basestring):
connector_handler = connector_registry.get(connector)
else:
# Allow to directly provide a handler class without requiring to
# register it first.
connector_handler = connector
return connector_handler
class SocketConnector:
connector_registry[connector_handler.af_type] = connector_handler
class SocketConnector(object):
""" This class is a wrapper for a socket """
is_listening = False
remote_addr = None
is_closed = None
is_closed = is_server = None
def __init__(self, s=None, accepted_from=None):
self.accepted_from = accepted_from
if accepted_from is not None:
self.remote_addr = accepted_from
self.is_listening = False
self.is_closed = False
def __new__(cls, addr, s=None):
if s is None:
host, port = addr
for af_type, cls in connector_registry.iteritems():
try :
socket.inet_pton(af_type, host)
break
except socket.error:
pass
else:
raise ValueError("Unknown type of host", host)
self = object.__new__(cls)
self.addr = cls._normAddress(addr)
if s is None:
self.socket = socket.socket(self.af_type, socket.SOCK_STREAM)
s = socket.socket(af_type, socket.SOCK_STREAM)
else:
self.socket = s
self.socket_fd = self.socket.fileno()
self.is_server = True
self.is_closed = False
self.socket = s
self.socket_fd = s.fileno()
# always use non-blocking sockets
self.socket.setblocking(0)
s.setblocking(0)
# disable Nagle algorithm to reduce latency
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return self
def makeClientConnection(self, addr):
self.is_closed = False
self.remote_addr = addr
# Threaded tests monkey-patch the following 2 operations.
_connect = lambda self, addr: self.socket.connect(addr)
_bind = lambda self, addr: self.socket.bind(addr)
def makeClientConnection(self):
assert self.is_closed is None
self.is_server = self.is_closed = False
try:
self.socket.connect(addr)
self._connect(self.addr)
except socket.error, (err, errmsg):
if err == errno.EINPROGRESS:
raise ConnectorInProgressException
......@@ -73,12 +72,12 @@ class SocketConnector:
raise ConnectorException, 'makeClientConnection to %s failed:' \
' %s:%s' % (addr, err, errmsg)
def makeListeningConnection(self, addr):
def makeListeningConnection(self):
assert self.is_closed is None
self.is_closed = False
self.is_listening = True
try:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(addr)
self._bind(self.addr)
self.socket.listen(5)
except socket.error, (err, errmsg):
self.socket.close()
......@@ -94,15 +93,22 @@ class SocketConnector:
# in epoll
return self.socket_fd
def getNewConnection(self):
@staticmethod
def _normAddress(addr):
return addr
def getAddress(self):
return self._normAddress(self.socket.getsockname())
def accept(self):
try:
(new_s, addr) = self._accept()
new_s = self.__class__(new_s, accepted_from=addr)
return (new_s, addr)
s, addr = self.socket.accept()
s = self.__class__(addr, s)
return s, s.addr
except socket.error, (err, errmsg):
if err == errno.EAGAIN:
raise ConnectorTryAgainException
raise ConnectorException, 'getNewConnection failed: %s:%s' % \
raise ConnectorException, 'accept failed: %s:%s' % \
(err, errmsg)
def receive(self):
......@@ -139,14 +145,14 @@ class SocketConnector:
state = 'closed '
else:
state = 'opened '
if self.is_listening:
if self.is_server is None:
state += 'listening'
else:
if self.accepted_from is None:
state += 'to '
else:
if self.is_server:
state += 'from '
state += str(self.remote_addr)
else:
state += 'to '
state += str(self.addr)
return '<%s at 0x%x fileno %s %s, %s>' % (self.__class__.__name__,
id(self), '?' if self.is_closed else self.socket_fd,
self.getAddress(), state)
......@@ -155,22 +161,13 @@ class SocketConnectorIPv4(SocketConnector):
" Wrapper for IPv4 sockets"
af_type = socket.AF_INET
def _accept(self):
return self.socket.accept()
def getAddress(self):
return self.socket.getsockname()
class SocketConnectorIPv6(SocketConnector):
" Wrapper for IPv6 sockets"
af_type = socket.AF_INET6
def _accept(self):
new_s, addr = self.socket.accept()
return new_s, addr[:2]
def getAddress(self):
return self.socket.getsockname()[:2]
@staticmethod
def _normAddress(addr):
return addr[:2]
registerConnectorHandler(SocketConnectorIPv4)
registerConnectorHandler(SocketConnectorIPv6)
......
......@@ -19,12 +19,8 @@ import sys
import traceback
from cStringIO import StringIO
from struct import Struct
try:
from .util import getAddressType
except ImportError:
pass
PROTOCOL_VERSION = 2
PROTOCOL_VERSION = 3
# Size restrictions.
MIN_PACKET_SIZE = 10
......@@ -449,65 +445,6 @@ class PEnum(PStructItem):
enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PAddressIPGeneric(PStructItem):
def __init__(self, name, format):
PStructItem.__init__(self, name, format)
def encode(self, writer, address):
host, port = address
host = socket.inet_pton(self.af_type, host)
writer(self.pack(host, port))
def decode(self, reader):
data = reader(self.size)
address = self.unpack(data)
host, port = address
host = socket.inet_ntop(self.af_type, host)
return (host, port)
class PAddressIPv4(PAddressIPGeneric):
af_type = socket.AF_INET
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!4sH')
class PAddressIPv6(PAddressIPGeneric):
af_type = socket.AF_INET6
def __init__(self, name):
PAddressIPGeneric.__init__(self, name, '!16sH')
class PAddress(PStructItem):
"""
An host address (IPv4/IPv6)
"""
address_format_dict = {
socket.AF_INET: PAddressIPv4('ipv4'),
socket.AF_INET6: PAddressIPv6('ipv6'),
}
def __init__(self, name):
PStructItem.__init__(self, name, '!L')
def _encode(self, writer, address):
if address is None:
writer(self.pack(INVALID_ADDRESS_TYPE))
return
af_type = getAddressType(address)
writer(self.pack(af_type))
encoder = self.address_format_dict[af_type]
encoder.encode(writer, address)
def _decode(self, reader):
af_type = self.unpack(reader(self.size))[0]
if af_type == INVALID_ADDRESS_TYPE:
return None
decoder = self.address_format_dict[af_type]
host, port = decoder.decode(reader)
return (host, port)
class PString(PStructItem):
"""
A variable-length string
......@@ -523,6 +460,29 @@ class PString(PStructItem):
length = self.unpack(reader(self.size))[0]
return reader(length)
class PAddress(PString):
"""
An host address (IPv4/IPv6)
"""
def __init__(self, name):
PString.__init__(self, name)
self._port = Struct('!H')
def _encode(self, writer, address):
if address:
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else:
PString._encode(self, writer, '')
def _decode(self, reader):
host = PString._decode(self, reader)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
class PBoolean(PStructItem):
"""
A boolean value, encoded as a single byte
......
......@@ -23,11 +23,6 @@ from Queue import deque
from struct import pack, unpack
from time import gmtime
SOCKET_CONNECTORS_DICT = {
socket.AF_INET : 'SocketConnectorIPv4',
socket.AF_INET6: 'SocketConnectorIPv6',
}
TID_LOW_OVERFLOW = 2**32
TID_LOW_MAX = TID_LOW_OVERFLOW - 1
SECOND_PER_TID_LOW = 60.0 / TID_LOW_OVERFLOW
......@@ -125,25 +120,6 @@ def makeChecksum(s):
return sha1(s).digest()
def getAddressType(address):
"Return the type (IPv4 or IPv6) of an ip"
(host, port) = address
for af_type in SOCKET_CONNECTORS_DICT:
try :
socket.inet_pton(af_type, host)
except:
continue
else:
break
else:
raise ValueError("Unknown type of host", host)
return af_type
def getConnectorFromAddress(address):
address_type = getAddressType(address)
return SOCKET_CONNECTORS_DICT[address_type]
def parseNodeAddress(address, port_opt=None):
if address[:1] == '[':
(host, port) = address[1:].split(']')
......@@ -164,24 +140,12 @@ def parseNodeAddress(address, port_opt=None):
def parseMasterList(masters, except_node=None):
assert masters, 'At least one master must be defined'
# load master node list
socket_connector = None
master_node_list = []
for node in masters.split(' '):
if not node:
continue
for node in masters.split():
address = parseNodeAddress(node)
if (address != except_node):
if address != except_node:
master_node_list.append(address)
socket_connector_temp = getConnectorFromAddress(address)
if socket_connector is None:
socket_connector = socket_connector_temp
elif socket_connector != socket_connector_temp:
raise TypeError("Wrong connector type : you're trying to use "
"ipv6 and ipv4 simultaneously")
return master_node_list, socket_connector
return master_node_list
class ReadBuffer(object):
......
......@@ -18,7 +18,6 @@ import sys, weakref
from time import time
from neo.lib import logging
from neo.lib.connector import getConnectorHandler
from neo.lib.debug import register as registerLiveDebugger
from neo.lib.protocol import uuid_str, UUID_NAMESPACES, ZERO_TID
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
......@@ -59,9 +58,7 @@ class Application(object):
self.autostart = config.getAutostart()
self.storage_readiness = set()
master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
for master_address in config.getMasters():
self.nm.createMaster(address=master_address)
logging.debug('IP address is %s, port is %d', *self.server)
......@@ -102,7 +99,7 @@ class Application(object):
raise ValueError("upstream cluster name must be"
" different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters())
config.getUpstreamMasters())
self.administration_handler = administration.AdministrationHandler(
self)
......@@ -141,8 +138,7 @@ class Application(object):
def _run(self):
"""Make sure that the status is sane and start a loop."""
# Make a listening port.
self.listening_conn = ListeningConnection(self.em, None,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, None, self.server)
# Start a normal operation.
while self.cluster_state != ClusterStates.STOPPING:
......@@ -196,8 +192,7 @@ class Application(object):
ClientConnection(self.em, client_handler,
# XXX: Ugly, but the whole election code will be
# replaced soon
node=getByAddress(addr),
connector=self.connector_handler())
getByAddress(addr))
self.unconnected_master_node_set.clear()
self.em.poll(1)
except ElectionFailure, m:
......@@ -381,9 +376,7 @@ class Application(object):
# Reconnect to primary master node.
primary_handler = secondary.PrimaryHandler(self)
ClientConnection(self.em, primary_handler,
node=self.primary_master_node,
connector=self.connector_handler())
ClientConnection(self.em, primary_handler, self.primary_master_node)
# and another for the future incoming connections
self.listening_conn.setHandler(
......
......@@ -19,7 +19,6 @@ from bisect import bisect
from collections import defaultdict
from neo.lib import logging
from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler
from neo.lib.node import NodeManager
......@@ -67,11 +66,10 @@ class BackupApplication(object):
pt = None
def __init__(self, app, name, master_addresses, connector_name):
def __init__(self, app, name, master_addresses):
self.app = weakref.proxy(app)
self.name = name
self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address)
......@@ -107,7 +105,7 @@ class BackupApplication(object):
break
poll(1)
node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler)
bootstrap.getPrimaryConnection()
try:
app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node
......
......@@ -14,11 +14,9 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from neo.lib.connector import getConnectorHandler
from neo.lib.connection import ClientConnection
from neo.lib.event import EventManager
from neo.lib.protocol import ClusterStates, NodeStates, ErrorCodes, Packets
from neo.lib.util import getConnectorFromAddress
from neo.lib.node import NodeManager
from .handler import CommandEventHandler
......@@ -31,8 +29,6 @@ class NeoCTL(object):
connected = False
def __init__(self, address):
connector_name = getConnectorFromAddress(address)
self.connector_handler = getConnectorHandler(connector_name)
self.nm = nm = NodeManager()
self.server = nm.createAdmin(address=address)
self.em = EventManager()
......@@ -47,7 +43,7 @@ class NeoCTL(object):
def __getConnection(self):
if not self.connected:
self.connection = ClientConnection(self.em, self.handler,
node=self.server, connector=self.connector_handler())
self.server)
while not self.connected:
self.em.poll(1)
if self.connection is None:
......
......@@ -24,7 +24,6 @@ from neo.lib.node import NodeManager
from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection
from neo.lib.exception import OperationFailure, PrimaryFailure
from neo.lib.connector import getConnectorHandler
from neo.lib.pt import PartitionTable
from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager
......@@ -54,9 +53,7 @@ class Application(object):
)
# load master nodes
master_addresses, connector_name = config.getMasters()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses :
for master_address in config.getMasters():
self.nm.createMaster(address=master_address)
# set the bind address
......@@ -177,8 +174,7 @@ class Application(object):
# Make a listening port
handler = identification.IdentificationHandler(self)
self.listening_conn = ListeningConnection(self.em, handler,
addr=self.server, connector=self.connector_handler())
self.listening_conn = ListeningConnection(self.em, handler, self.server)
self.server = self.listening_conn.getAddress()
# Connect to a primary master node, verify data, and
......@@ -234,7 +230,7 @@ class Application(object):
# search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, self.name,
NodeTypes.STORAGE, self.uuid, self.server)
data = bootstrap.getPrimaryConnection(self.connector_handler)
data = bootstrap.getPrimaryConnection()
(node, conn, uuid, num_partitions, num_replicas) = data
self.master_node = node
self.master_conn = conn
......
......@@ -46,7 +46,7 @@ class Checker(object):
conn.asClient()
else:
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
node)
conn.ask(Packets.RequestIdentification(
NodeTypes.STORAGE, uuid, app.server, name))
self.conn_dict[conn] = node.isIdentified()
......
......@@ -254,8 +254,7 @@ class Replicator(object):
self.fetchTransactions()
else:
assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
conn = ClientConnection(app.em, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name))
if previous_node is not None and previous_node.isConnected():
......
......@@ -30,7 +30,6 @@ from functools import wraps
from mock import Mock
from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import getAddressType
from time import time
from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess
......@@ -203,8 +202,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({
'getCluster': cluster,
'getBind': masters[0],
'getMasters': (masters, getAddressType((
self.local_ip, 0))),
'getMasters': masters,
'getReplicas': replicas,
'getPartitions': partitions,
'getUUID': uuid,
......@@ -226,8 +224,7 @@ class NeoUnitTestBase(NeoTestBase):
return Mock({
'getCluster': cluster,
'getBind': (masters[0], 10020 + index),
'getMasters': (masters, getAddressType((
self.local_ip, 0))),
'getMasters': masters,
'getDatabase': db,
'getUUID': uuid,
'getReset': False,
......@@ -554,29 +551,5 @@ class Patch(object):
self.__del__()
connector_cpt = 0
class DoNothingConnector(Mock):
def __init__(self, s=None):
logging.info("initializing connector")
global connector_cpt
self.desc = connector_cpt
connector_cpt += 1
self.packet_cpt = 0
Mock.__init__(self)
def getAddress(self):
return self.addr
def makeClientConnection(self, addr):
self.addr = addr
def makeListeningConnection(self, addr):
self.addr = addr