Commit ba573ab7 authored by Julien Muchembled's avatar Julien Muchembled

Make nodes ask registry for their country

To prepare for the removal of geoip2, we want nodes to ask the registry
for their country. geoip2 is kept in this update since nodes will still
need to figure out countries of other nodes which haven't updated yet.
Once all nodes will be updated to this version, geoip2 will be ready to
be deleted.

See merge request !32
parents 0f97c026 dd943d7c
...@@ -272,3 +272,9 @@ class Cache(object): ...@@ -272,3 +272,9 @@ class Cache(object):
if a != address: if a != address:
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address)) q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,)) q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip):
try:
return self._registry.getCountry(self._prefix, ip)
except socket.error, e:
logging.warning('Failed to get country (%s)', ip)
...@@ -67,7 +67,7 @@ def getConfig(): ...@@ -67,7 +67,7 @@ def getConfig():
"to access it.") "to access it.")
_('--country', metavar='CODE', _('--country', metavar='CODE',
help="Country code that is advertised to other nodes" help="Country code that is advertised to other nodes"
"(default: country is fetched from MaxMind database)") "(default: country is resolved by the registry)")
_ = parser.add_argument_group('routing').add_argument _ = parser.add_argument_group('routing').add_argument
_('-B', dest='babel_args', metavar='ARG', action='append', default=[], _('-B', dest='babel_args', metavar='ARG', action='append', default=[],
...@@ -296,8 +296,8 @@ def main(): ...@@ -296,8 +296,8 @@ def main():
control_socket = os.path.join(config.run, 'babeld.sock') control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client: if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket, tunnel_manager = tunnel.TunnelManager(control_socket,
cache, cert, config.openvpn_args, timeout, cache, cert, config.openvpn_args, timeout, config.client_count,
config.client_count, config.iface_list, config.country, address, ip_changed, config.iface_list, config.country, address, ip_changed,
remote_gateway, config.disable_proto, config.neighbour) remote_gateway, config.disable_proto, config.neighbour)
add_tunnels(tunnel_manager.new_iface_list) add_tunnels(tunnel_manager.new_iface_list)
else: else:
......
...@@ -39,7 +39,7 @@ def rpc(f): ...@@ -39,7 +39,7 @@ def rpc(f):
args, varargs, varkw, defaults = inspect.getargspec(f) args, varargs, varkw, defaults = inspect.getargspec(f)
assert not (varargs or varkw), f assert not (varargs or varkw), f
if not defaults: if not defaults:
defaults = () defaults = ()
i = len(args) - len(defaults) i = len(args) - len(defaults)
f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i] f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i]
+ map("%s=%r".__mod__, zip(args[i:], defaults)))) + map("%s=%r".__mod__, zip(args[i:], defaults))))
...@@ -59,6 +59,9 @@ class RegistryServer(object): ...@@ -59,6 +59,9 @@ class RegistryServer(object):
peers = 0, () peers = 0, ()
cert_duration = 365 * 86400 cert_duration = 365 * 86400
def _geoiplookup(self, ip):
raise HTTPError(httplib.BAD_REQUEST)
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.lock = threading.Lock() self.lock = threading.Lock()
...@@ -103,6 +106,20 @@ class RegistryServer(object): ...@@ -103,6 +106,20 @@ class RegistryServer(object):
self.ctl = ctl.Babel(os.path.join(config.run, 'babeld.sock'), self.ctl = ctl.Babel(os.path.join(config.run, 'babeld.sock'),
weakref.proxy(self), self.network) weakref.proxy(self), self.network)
db = os.getenv('GEOIP2_MMDB')
if db:
from geoip2 import database, errors
country = database.Reader(db).country
def geoiplookup(ip):
try:
return country(ip).country.iso_code.encode()
except errors.AddressNotFoundError:
return
self._geoiplookup = geoiplookup
elif self.config.same_country:
sys.exit("Can not respect 'same_country' network configuration"
" (GEOIP2_MMDB not set)")
self.onTimeout() self.onTimeout()
if self.prefix: if self.prefix:
with self.db: with self.db:
...@@ -504,6 +521,10 @@ class RegistryServer(object): ...@@ -504,6 +521,10 @@ class RegistryServer(object):
logging.info("Timeout while querying address for %s/%s", logging.info("Timeout while querying address for %s/%s",
int(peer, 2), len(peer)) int(peer, 2), len(peer))
@rpc
def getCountry(self, cn, address):
return self._geoiplookup(address)
@rpc @rpc
def getBootstrapPeer(self, cn): def getBootstrapPeer(self, cn):
with self.peers_lock: with self.peers_lock:
......
...@@ -213,17 +213,22 @@ class BaseTunnelManager(object): ...@@ -213,17 +213,22 @@ class BaseTunnelManager(object):
address_dict = defaultdict(list) address_dict = defaultdict(list)
for family, address in address: for family, address in address:
address_dict[family] += address address_dict[family] += address
if any(address_dict.itervalues()):
del cache.my_address # Cache may contain our country, we want to use it if possible to
else: # prevent interaction with registry
address = cache.my_address cache_address = cache.my_address
if address: if cache_address:
for address in utils.parse_address(address): cache_dict = defaultdict(list)
try: for address in utils.parse_address(cache_address):
proto = proto_dict[address[2]] try:
except KeyError: proto = proto_dict[address[2]]
continue except KeyError:
address_dict[proto[0]].append(address) continue
cache_dict[proto[0]].append(address)
if {proto: cache_dict[proto][:3] for proto in cache_dict
} == address_dict:
address_dict = cache_dict
db = os.getenv('GEOIP2_MMDB') db = os.getenv('GEOIP2_MMDB')
if db: if db:
from geoip2 import database, errors from geoip2 import database, errors
...@@ -234,6 +239,7 @@ class BaseTunnelManager(object): ...@@ -234,6 +239,7 @@ class BaseTunnelManager(object):
except errors.AddressNotFoundError: except errors.AddressNotFoundError:
return return
self._geoiplookup = geoiplookup self._geoiplookup = geoiplookup
if cache.same_country:
self._country = {} self._country = {}
address_dict = {family: self._updateCountry(address) address_dict = {family: self._updateCountry(address)
...@@ -244,6 +250,7 @@ class BaseTunnelManager(object): ...@@ -244,6 +250,7 @@ class BaseTunnelManager(object):
self._address = {family: utils.dump_address(address) self._address = {family: utils.dump_address(address)
for family, address in address_dict.iteritems() for family, address in address_dict.iteritems()
if address} if address}
cache.my_address = ';'.join(self._address.itervalues())
self.sock = socket.socket(socket.AF_INET6, self.sock = socket.socket(socket.AF_INET6,
socket.SOCK_DGRAM | socket.SOCK_CLOEXEC) socket.SOCK_DGRAM | socket.SOCK_CLOEXEC)
...@@ -664,7 +671,7 @@ class BaseTunnelManager(object): ...@@ -664,7 +671,7 @@ class BaseTunnelManager(object):
for a in address: for a in address:
family, ip = resolve(*a[:3]) family, ip = resolve(*a[:3])
for ip in ip: for ip in ip:
country = a[3] if len(a) > 3 else self._geoiplookup(ip) country = a[3] if len(a) > 3 else self.cache.getCountry(ip)
if country: if country:
if self._country.get(family) != country: if self._country.get(family) != country:
self._country[family] = country self._country[family] = country
...@@ -680,10 +687,10 @@ class TunnelManager(BaseTunnelManager): ...@@ -680,10 +687,10 @@ class TunnelManager(BaseTunnelManager):
'client_count', 'max_clients', 'same_country', 'tunnel_refresh')) 'client_count', 'max_clients', 'same_country', 'tunnel_refresh'))
def __init__(self, control_socket, cache, cert, openvpn_args, def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, country, address, ip_changed, timeout, client_count, iface_list, conf_country, address,
remote_gateway, disable_proto, neighbour_list=()): ip_changed, remote_gateway, disable_proto, neighbour_list=()):
super(TunnelManager, self).__init__(control_socket, super(TunnelManager, self).__init__(control_socket,
cache, cert, country, address) cache, cert, conf_country, address)
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
self.timeout = timeout self.timeout = timeout
self._read_sock, self.write_sock = socket.socketpair( self._read_sock, self.write_sock = socket.socketpair(
...@@ -1029,7 +1036,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -1029,7 +1036,7 @@ class TunnelManager(BaseTunnelManager):
if self._ip_changed: if self._ip_changed:
family, address = self._ip_changed(ip) family, address = self._ip_changed(ip)
if address: if address:
if self._geoiplookup or self._conf_country: if self.cache.same_country:
address = self._updateCountry(address) address = self._updateCountry(address)
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(self._address.itervalues()) self.cache.my_address = ';'.join(self._address.itervalues())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment