Commit 0b77e96b authored by Joanne Hugé's avatar Joanne Hugé

Push country from openvpn server to openvpn client

parent 0f97c026
...@@ -249,6 +249,11 @@ class Cache(object): ...@@ -249,6 +249,11 @@ class Cache(object):
return prefix, address return prefix, address
logging.warning('Buggy registry sent us our own address') logging.warning('Buggy registry sent us our own address')
def removePeer(self, prefix):
with self._db as db:
db.execute("DELETE FROM peer WHERE prefix=?",
(prefix,))
def addPeer(self, prefix, address, set_preferred=False): def addPeer(self, prefix, address, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address) logging.debug('Adding peer %s: %s', prefix, address)
with self._db: with self._db:
......
...@@ -13,5 +13,7 @@ if script_type == 'up': ...@@ -13,5 +13,7 @@ if script_type == 'up':
if script_type == 'route-up': if script_type == 'route-up':
import time import time
country = os.environ['OPENVPN_country'] if 'OPENVPN_country' in os.environ else None
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(), os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip']))) int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'],
country)))
...@@ -10,8 +10,10 @@ os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'], ...@@ -10,8 +10,10 @@ os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip)))) int(os.environ['tls_serial_0']), external_ip))))
if script_type == 'client-connect': if script_type == 'client-connect':
if os.read(fd, 1) == '\0': country = os.read(fd, 1)
if country == '\0':
sys.exit(1) sys.exit(1)
# Send client its external ip address # Send client its external ip address
with open(sys.argv[2], 'w') as f: with open(sys.argv[2], 'w') as f:
f.write('push "setenv-safe external_ip %s"\n' % external_ip) f.write('push "setenv-safe external_ip %s"\n' % external_ip)
f.write('push "setenv-safe country %s"\n' % country)
...@@ -585,11 +585,11 @@ class BaseTunnelManager(object): ...@@ -585,11 +585,11 @@ class BaseTunnelManager(object):
logging.debug("%s%r", event, args) logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args) r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None: if r is not None:
sock.send(chr(r)) sock.send(r)
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl: if serial in self.cache.crl:
return False return '\0'
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
self._served[prefix][iface] = serial self._served[prefix][iface] = serial
if isinstance(self, TunnelManager): # XXX if isinstance(self, TunnelManager): # XXX
...@@ -598,7 +598,18 @@ class BaseTunnelManager(object): ...@@ -598,7 +598,18 @@ class BaseTunnelManager(object):
if prefix in self._connection_dict and self._prefix < prefix: if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix) self._kill(prefix)
self.cache.connecting(prefix, 0) self.cache.connecting(prefix, 0)
return True
family = socket.AF_INET
try:
socket.inet_pton(socket.AF_INET, trusted_ip)
except socket.error:
family = socket.AF_INET6
if self.cache.same_country:
country = self._country.get(family, self._conf_country)
if not country:
return '\0'
return country
return '\1'
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
...@@ -1015,7 +1026,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -1015,7 +1026,7 @@ class TunnelManager(BaseTunnelManager):
def handleClientEvent(self): def handleClientEvent(self):
msg = self._read_sock.recv(65536) msg = self._read_sock.recv(65536)
logging.debug("route_up%s", msg) logging.debug("route_up%s", msg)
common_name, time, serial, ip = eval(msg) common_name, time, serial, ip, country = eval(msg)
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
c = self._connection_dict.get(prefix) c = self._connection_dict.get(prefix)
if c and c.time < float(time): if c and c.time < float(time):
...@@ -1031,6 +1042,18 @@ class TunnelManager(BaseTunnelManager): ...@@ -1031,6 +1042,18 @@ class TunnelManager(BaseTunnelManager):
if address: if address:
if self._geoiplookup or self._conf_country: if self._geoiplookup or self._conf_country:
address = self._updateCountry(address) address = self._updateCountry(address)
if country:
my_country = self._country.get(family, self._conf_country)
same_country = self.cache.same_country
if (country != my_country
if my_country in same_country else
country in same_country):
logging.debug("Wrong country in cache, killing tunnel %s", common_name)
self._kill(prefix)
self.cache.removePeer(prefix)
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(self._address.itervalues()) self.cache.my_address = ';'.join(self._address.itervalues())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment