Commit e9583e12 authored by Cédric Le Ninivin's avatar Cédric Le Ninivin Committed by Julien Muchembled

Do not delete a tunnel if there are still routes through it

Co-authored-by: Julien Muchembled's avatarJulien Muchembled <jm@nexedi.com>
parent 2f49dae1
...@@ -165,11 +165,15 @@ Dump = Packet(1, ...@@ -165,11 +165,15 @@ Dump = Packet(1,
Struct("B"), Struct("B"),
Struct(( Struct((
Array(Struct((Struct("I", "index", "index"), String), "interface", "index name")), Array(Struct((Struct("I", "index", "index"), String), "interface", "index name")),
Array(Struct("16sIHHHHHiHH", "neighbour", "address ifindex reach rxcost txcost rtt rttcost channel if_up")), Array(Struct("16sIHHHHHiHH", "neighbour", "address ifindex reach rxcost txcost rtt rttcost channel if_up cost_multiplier")),
Array(Struct("16sBH", "xroute", "prefix plen metric")), Array(Struct("16sBH", "xroute", "prefix plen metric")),
Array(Struct("16sBHHH8siiI16s16sB", "route", "prefix plen metric smoothed_metric refmetric id seqno age ifindex neigh_address nexthop flags")), Array(Struct("16sBHHH8siiI16s16sB", "route", "prefix plen metric smoothed_metric refmetric id seqno age ifindex neigh_address nexthop flags")),
), "dump", "interfaces neighbours xroutes routes")) ), "dump", "interfaces neighbours xroutes routes"))
SetCostMultiplier = Packet(2,
Struct("16sIH"),
Struct("B", "set_cost_multiplier", "flags"))
class Babel(object): class Babel(object):
...@@ -194,6 +198,7 @@ class Babel(object): ...@@ -194,6 +198,7 @@ class Babel(object):
return self.select(*args) return self.select(*args)
self.select = select self.select = select
self.request_dump = lambda: self.handle_dump((), (), (), ()) self.request_dump = lambda: self.handle_dump((), (), (), ())
self.locked = set()
def send(self, packet): def send(self, packet):
packet.write(self.write_buffer) packet.write(self.write_buffer)
...@@ -253,13 +258,20 @@ class Babel(object): ...@@ -253,13 +258,20 @@ class Babel(object):
else: else:
prefix = None prefix = None
neigh_routes[1][prefix] = route neigh_routes[1][prefix] = route
self.locked.clear()
if unidentified: if unidentified:
routes = {} routes = {}
for address in unidentified: for address in unidentified:
routes.update(n[address][1]) neigh, r = n[address]
if not neigh.cost_multiplier:
self.locked.add(address)
routes.update(r)
if routes: if routes:
neighbours[None] = None, routes neighbours[None] = None, routes
logging.trace("Routes via unidentified neighbours. %r", logging.trace("Routes via unidentified neighbours. %r",
neighbours) neighbours)
self.interfaces = dict((i.index, name) for i, name in interfaces) self.interfaces = dict((i.index, name) for i, name in interfaces)
self.handler.babel_dump() self.handler.babel_dump()
def handle_set_cost_multiplier(self, flags):
pass
...@@ -101,6 +101,59 @@ class Connection(object): ...@@ -101,6 +101,59 @@ class Connection(object):
self.open() self.open()
return True return True
class TunnelKiller(object):
state = None
def __init__(self, peer, tunnel_manager, client=False):
self.peer = peer
self.tm = weakref.proxy(tunnel_manager)
self.timeout = time.time() + 2 * tunnel_manager.timeout
self.client = client
self()
def __call__(self):
if self.state:
return getattr(self, self.state)()
tm_ctl = self.tm.ctl
try:
neigh = tm_ctl.neighbours[self.peer][0]
except KeyError:
return
self.state = 'softLocking'
tm_ctl.send(ctl.SetCostMultiplier(neigh.address, neigh.ifindex, 4096))
self.address = neigh.address
self.ifindex = neigh.ifindex
self.cost_multiplier = neigh.cost_multiplier
def softLocking(self):
tm = self.tm
if self.peer in tm.ctl.neighbours or None in tm.ctl.neighbours:
return
tm.ctl.send(ctl.SetCostMultiplier(self.address, self.ifindex, 0))
self.state = "hardLocking"
def hardLocking(self):
tm = self.tm
if (self.address, self.ifindex) in tm.ctl.locked:
self.state = 'locked'
self.timeout = time.time() + 2 * tm.timeout
tm.sendto(self.peer, ('\4' if self.client else '\5') + tm._prefix)
else:
self.timeout = 0
def unlock(self):
if self.state:
self.tm.ctl.send(ctl.SetCostMultiplier(self.address, self.ifindex,
self.cost_multiplier))
def abort(self):
if self.state != 'unlocking':
self.state = 'unlocking'
self.timeout = time.time() + 2 * self.tm.timeout
locked = unlocking = lambda _: None
class TunnelManager(object): class TunnelManager(object):
...@@ -137,6 +190,7 @@ class TunnelManager(object): ...@@ -137,6 +190,7 @@ class TunnelManager(object):
self._disable_proto = disable_proto self._disable_proto = disable_proto
self._neighbour_set = set(map(utils.binFromSubnet, neighbour_list)) self._neighbour_set = set(map(utils.binFromSubnet, neighbour_list))
self._served = set() self._served = set()
self._killing = {}
self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# See also http://stackoverflow.com/questions/597225/ # See also http://stackoverflow.com/questions/597225/
...@@ -198,6 +252,7 @@ class TunnelManager(object): ...@@ -198,6 +252,7 @@ class TunnelManager(object):
logging.debug('Checking tunnels...') logging.debug('Checking tunnels...')
self._cleanDeads() self._cleanDeads()
if self._next_tunnel_refresh < time.time() or \ if self._next_tunnel_refresh < time.time() or \
self._killing or \
self._makeNewTunnels(False): self._makeNewTunnels(False):
self._next_refresh = None self._next_refresh = None
self.ctl.request_dump() # calls babel_dump immediately at startup self.ctl.request_dump() # calls babel_dump immediately at startup
...@@ -205,7 +260,15 @@ class TunnelManager(object): ...@@ -205,7 +260,15 @@ class TunnelManager(object):
self._next_refresh = time.time() + 5 self._next_refresh = time.time() + 5
def babel_dump(self): def babel_dump(self):
remove = self._next_tunnel_refresh < time.time() t = time.time()
if self._killing:
for prefix, tunnel_killer in self._killing.items():
if tunnel_killer.timeout < t:
tunnel_killer.unlock()
del self._killing[prefix]
else:
tunnel_killer()
remove = self._next_tunnel_refresh < t
if remove: if remove:
self._removeSomeTunnels() self._removeSomeTunnels()
self.resetTunnelRefresh() self.resetTunnelRefresh()
...@@ -237,15 +300,25 @@ class TunnelManager(object): ...@@ -237,15 +300,25 @@ class TunnelManager(object):
def _removeSomeTunnels(self): def _removeSomeTunnels(self):
# Get the candidates to killing # Get the candidates to killing
count = len(self._connection_dict) - self._client_count + 1 peer_set = set(self._connection_dict)
peer_set.difference_update(self._killing)
count = len(peer_set) - self._client_count + 1
if count > 0: if count > 0:
for prefix in sorted(self._connection_dict, for prefix in sorted(peer_set, key=self._tunnelScore)[:count]:
key=self._tunnelScore)[:count]: self._killing[prefix] = TunnelKiller(prefix, self, True)
self._kill(prefix)
def _abortTunnelKiller(self, prefix):
tunnel_killer = self._killing.get(prefix)
if tunnel_killer:
if tunnel_killer.state:
tunnel_killer.abort()
else:
del self._killing[prefix]
def _kill(self, prefix): def _kill(self, prefix):
logging.info('Killing the connection with %u/%u...', logging.info('Killing the connection with %u/%u...',
int(prefix, 2), len(prefix)) int(prefix, 2), len(prefix))
self._abortTunnelKiller(prefix)
connection = self._connection_dict.pop(prefix) connection = self._connection_dict.pop(prefix)
self.freeInterface(connection.iface) self.freeInterface(connection.iface)
connection.close() connection.close()
...@@ -399,6 +472,7 @@ class TunnelManager(object): ...@@ -399,6 +472,7 @@ class TunnelManager(object):
self._served.remove(prefix) self._served.remove(prefix)
except KeyError: except KeyError:
return return
self._abortTunnelKiller(prefix)
if self._gateway_manager is not None: if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip) self._gateway_manager.remove(trusted_ip)
...@@ -472,6 +546,17 @@ class TunnelManager(object): ...@@ -472,6 +546,17 @@ class TunnelManager(object):
#else: # I don't know my IP yet! #else: # I don't know my IP yet!
elif code == 3: elif code == 3:
self._sendto(address, '\4' + version.version) self._sendto(address, '\4' + version.version)
elif code in (4, 5): # kill
prefix = msg[1:]
if sender and sender.startswith(prefix, len(self._network)):
try:
tunnel_killer = self._killing[prefix]
except KeyError:
if code == 4 and prefix in self._served: # request
self._killing[prefix] = TunnelKiller(prefix, self)
else:
if code == 5 and tunnel_killer.state == 'locked': # response
self._kill(prefix)
elif code == 255: elif code == 255:
# the registry wants to know the topology for debugging purpose # the registry wants to know the topology for debugging purpose
if not sender or sender[len(self._network):].startswith( if not sender or sender[len(self._network):].startswith(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment