Commit b2040ea0 authored by Julien Muchembled's avatar Julien Muchembled

Make --client & --client-count=0 modes process UDP/326 messages

These modes are partly unified with the normal one by splitting TunnelManager.
parent 9717eb0e
...@@ -158,48 +158,141 @@ class TunnelKiller(object): ...@@ -158,48 +158,141 @@ class TunnelKiller(object):
locked = unlocking = lambda _: None locked = unlocking = lambda _: None
class TunnelManager(object): class BaseTunnelManager(object):
def __init__(self, peer_db, cert, address=()):
self.cert = cert
self._network = cert.network
self._prefix = cert.prefix
self.peer_db = peer_db
self._connecting = set()
self._connection_dict = {}
self._served = set()
address_dict = defaultdict(list)
for family, address in address:
address_dict[family] += address
self._address = dict((family, utils.dump_address(address))
for family, address in address_dict.iteritems()
if address)
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# See also http://stackoverflow.com/questions/597225/
# about binding and anycast.
self.sock.bind(('::', PORT))
def select(self, r, w, t):
r[self.sock] = self.handlePeerEvent
def sendto(self, peer, msg):
ip = utils.ipFromBin(self._network + peer)
try:
return self.sock.sendto(msg, (ip, PORT))
except socket.error, e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s/%s (%s)',
int(peer, 2), len(peer), e)
def _sendto(self, to, msg):
try:
return self.sock.sendto(msg, to[:2])
except socket.error, e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e)
def handlePeerEvent(self):
msg, address = self.sock.recvfrom(1<<16)
if address[0] == '::1':
sender = None
else:
try:
sender = utils.binFromIp(address[0])
except socket.error, e:
# inet_pton does not parse '<ipv6>%<iface>'
logging.warning('ignored message from %r (%s)', address, e)
return
if not sender.startswith(self._network):
return
if not msg:
return
code = ord(msg[0])
if code == 1: # answer
# Old versions may send additional and obsolete addresses.
# Ignore them, as well as truncated lines.
try:
prefix, address = msg[1:msg.index('\n')].split()
int(prefix, 2)
except ValueError:
pass
else:
if prefix != self._prefix:
self.peer_db.addPeer(prefix, address)
try:
self._connecting.remove(prefix)
except KeyError:
pass
else:
self._makeTunnel(prefix, address)
elif code == 2: # request
if self._address:
self._sendto(address, '\1%s %s\n' % (self._prefix,
';'.join(self._address.itervalues())))
#else: # I don't know my IP yet!
elif code == 3:
if len(msg) == 1:
self._sendto(address, '\3' + version.version)
elif code in (4, 5): # kill
prefix = msg[1:]
if sender and sender.startswith(prefix, len(self._network)):
try:
tunnel_killer = self._killing[prefix]
except KeyError:
if code == 4 and prefix in self._served: # request
self._killing[prefix] = TunnelKiller(prefix, self)
else:
if code == 5 and tunnel_killer.state == 'locked': # response
self._kill(prefix)
elif code == 255:
# the registry wants to know the topology for debugging purpose
if not sender or sender[len(self._network):].startswith(
self.peer_db.registry_prefix):
msg = ['\xfe%s%u/%u\n%u\n' % (msg[1:],
int(self._prefix, 2), len(self._prefix),
len(self._connection_dict))]
msg.extend('%u/%u\n' % (int(x, 2), len(x))
for x in (self._connection_dict, self._served)
for x in x)
try:
self.sock.sendto(''.join(msg), address[:2])
except socket.error, e:
pass
class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, peer_db, cert, openvpn_args, timeout, def __init__(self, control_socket, peer_db, cert, openvpn_args, timeout,
refresh, client_count, iface_list, address, ip_changed, refresh, client_count, iface_list, address, ip_changed,
encrypt, remote_gateway, disable_proto, neighbour_list=()): encrypt, remote_gateway, disable_proto, neighbour_list=()):
self.cert = cert super(TunnelManager, self).__init__(peer_db, cert, address)
self._network = cert.network
self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network) self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network)
self.encrypt = encrypt self.encrypt = encrypt
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
self.peer_db = peer_db
self.timeout = timeout self.timeout = timeout
# Create and open read_only pipe to get server events # Create and open read_only pipe to get server events
r, self.write_pipe = os.pipe() r, self.write_pipe = os.pipe()
self._read_pipe = os.fdopen(r) self._read_pipe = os.fdopen(r)
self._connecting = set()
self._connection_dict = {}
self._disconnected = 0 self._disconnected = 0
self._distant_peers = [] self._distant_peers = []
self._iface_to_prefix = {} self._iface_to_prefix = {}
self._refresh_time = refresh self._refresh_time = refresh
self._iface_list = iface_list self._iface_list = iface_list
self._prefix = cert.prefix
address_dict = defaultdict(list)
for family, address in address:
address_dict[family] += address
self._address = dict((family, utils.dump_address(address))
for family, address in address_dict.iteritems()
if address)
self._ip_changed = ip_changed self._ip_changed = ip_changed
self._gateway_manager = MultiGatewayManager(remote_gateway) \ self._gateway_manager = MultiGatewayManager(remote_gateway) \
if remote_gateway else None if remote_gateway else None
self._disable_proto = disable_proto self._disable_proto = disable_proto
self._neighbour_set = set(map(utils.binFromSubnet, neighbour_list)) self._neighbour_set = set(map(utils.binFromSubnet, neighbour_list))
self._served = set()
self._killing = {} self._killing = {}
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# See also http://stackoverflow.com/questions/597225/
# about binding and anycast.
self.sock.bind(('::', PORT))
self._next_refresh = time.time() self._next_refresh = time.time()
self.resetTunnelRefresh() self.resetTunnelRefresh()
...@@ -511,86 +604,3 @@ class TunnelManager(object): ...@@ -511,86 +604,3 @@ class TunnelManager(object):
family, address = self._ip_changed(ip) family, address = self._ip_changed(ip)
if address: if address:
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
def sendto(self, peer, msg):
ip = utils.ipFromBin(self._network + peer)
try:
return self.sock.sendto(msg, (ip, PORT))
except socket.error, e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s/%s (%s)',
int(peer, 2), len(peer), e)
def _sendto(self, to, msg):
try:
return self.sock.sendto(msg, to[:2])
except socket.error, e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e)
def handlePeerEvent(self):
msg, address = self.sock.recvfrom(1<<16)
if address[0] == '::1':
sender = None
else:
try:
sender = utils.binFromIp(address[0])
except socket.error, e:
# inet_pton does not parse '<ipv6>%<iface>'
logging.warning('ignored message from %r (%s)', address, e)
return
if not sender.startswith(self._network):
return
if not msg:
return
code = ord(msg[0])
if code == 1: # answer
# Old versions may send additional and obsolete addresses.
# Ignore them, as well as truncated lines.
try:
prefix, address = msg[1:msg.index('\n')].split()
int(prefix, 2)
except ValueError:
pass
else:
if prefix != self._prefix:
self.peer_db.addPeer(prefix, address)
try:
self._connecting.remove(prefix)
except KeyError:
pass
else:
self._makeTunnel(prefix, address)
elif code == 2: # request
if self._address:
self._sendto(address, '\1%s %s\n' % (self._prefix,
';'.join(self._address.itervalues())))
#else: # I don't know my IP yet!
elif code == 3:
if len(msg) == 1:
self._sendto(address, '\3' + version.version)
elif code in (4, 5): # kill
prefix = msg[1:]
if sender and sender.startswith(prefix, len(self._network)):
try:
tunnel_killer = self._killing[prefix]
except KeyError:
if code == 4 and prefix in self._served: # request
self._killing[prefix] = TunnelKiller(prefix, self)
else:
if code == 5 and tunnel_killer.state == 'locked': # response
self._kill(prefix)
elif code == 255:
# the registry wants to know the topology for debugging purpose
if not sender or sender[len(self._network):].startswith(
self.peer_db.registry_prefix):
msg = ['\xfe%s%u/%u\n%u\n' % (msg[1:],
int(self._prefix, 2), len(self._prefix),
len(self._connection_dict))]
msg.extend('%u/%u\n' % (int(x, 2), len(x))
for x in (self._connection_dict, self._served)
for x in x)
try:
self.sock.sendto(''.join(msg), address[:2])
except socket.error, e:
pass
...@@ -256,20 +256,20 @@ def main(): ...@@ -256,20 +256,20 @@ def main():
# Init db and tunnels # Init db and tunnels
tunnel_interfaces = server_tunnels.keys() tunnel_interfaces = server_tunnels.keys()
timeout = 4 * config.hello timeout = 4 * config.hello
cleanup = [] peer_db = db.PeerDB(db_path, config.registry, cert)
cleanup = [lambda: peer_db.cacheMinimize(config.client_count)]
if config.client_count and not config.client: if config.client_count and not config.client:
peer_db = db.PeerDB(db_path, config.registry, cert)
cleanup.append(lambda: peer_db.cacheMinimize(config.client_count))
tunnel_manager = tunnel.TunnelManager(config.control_socket, tunnel_manager = tunnel.TunnelManager(config.control_socket,
peer_db, cert, config.openvpn_args, timeout, peer_db, cert, config.openvpn_args, timeout,
config.tunnel_refresh, config.client_count, config.iface_list, config.tunnel_refresh, config.client_count, config.iface_list,
address, ip_changed, config.encrypt, remote_gateway, address, ip_changed, config.encrypt, remote_gateway,
config.disable_proto, config.neighbour) config.disable_proto, config.neighbour)
cleanup.append(tunnel_manager.sock.close)
tunnel_interfaces += tunnel_manager.new_iface_list tunnel_interfaces += tunnel_manager.new_iface_list
write_pipe = tunnel_manager.write_pipe write_pipe = tunnel_manager.write_pipe
else: else:
tunnel_manager = write_pipe = None write_pipe = None
tunnel_manager = tunnel.BaseTunnelManager(peer_db, cert)
cleanup.append(tunnel_manager.sock.close)
try: try:
exit.acquire() exit.acquire()
...@@ -372,17 +372,18 @@ def main(): ...@@ -372,17 +372,18 @@ def main():
exit.acquire() exit.acquire()
for cmd in config.daemon or (): for cmd in config.daemon or ():
cleanup.insert(-1, utils.Popen(cmd, shell=True).stop) cleanup.insert(-1, utils.Popen(cmd, shell=True).stop)
try:
# main loop
select_list = [forwarder.select] if forwarder else []
if tunnel_manager:
select_list.append(tunnel_manager.select)
cleanup[-1:-1] = (tunnel_manager.delInterfaces, cleanup[-1:-1] = (tunnel_manager.delInterfaces,
tunnel_manager.killAll) tunnel_manager.killAll)
except AttributeError:
pass
# main loop
exit.release() exit.release()
def renew(): def renew():
raise ReexecException("Restart to renew certificate") raise ReexecException("Restart to renew certificate")
select_list.append(utils.select) select_list = [forwarder.select] if forwarder else []
select_list += tunnel_manager.select, utils.select
while True: while True:
args = {}, {}, [(next_renew, renew)] args = {}, {}, [(next_renew, renew)]
for s in select_list: for s in select_list:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment