Commit 1b204364 authored by Tom Niget's avatar Tom Niget

py2to3: continue fixing py2 to py3 issues

parent dcbda20b
This diff is collapsed.
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
import re import re
import os import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \ from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface from nemu.interface import Switch, Interface
from types import FunctionType
def _get_all_route_data(): def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all" ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
...@@ -65,7 +65,7 @@ def __init__(self, *args, **kw): ...@@ -65,7 +65,7 @@ def __init__(self, *args, **kw):
self.name = self.name.split('@',1)[0] self.name = self.name.split('@',1)[0]
interface.__init__ = __init__ interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.__code__, get_addr_data.orig = FunctionType(get_addr_data.__code__,
get_addr_data.__globals__) get_addr_data.__globals__)
def _get_addr_data(): def _get_addr_data():
byidx, bynam = get_addr_data.orig() byidx, bynam = get_addr_data.orig()
......
...@@ -64,7 +64,7 @@ class Ping(Thread): ...@@ -64,7 +64,7 @@ class Ping(Thread):
os.utime(csv_path, (time.time(), time.time())) os.utime(csv_path, (time.time(), time.time()))
for add in no_responses: for add in no_responses:
print(('No response from %s with seq no %d' % (add, seq))) print('No response from %s with seq no %d' % (add, seq))
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('n', help = 'my machine name (m1,m2...)') parser.add_argument('n', help = 'my machine name (m1,m2...)')
......
...@@ -30,4 +30,5 @@ def __file__(): ...@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1]) return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__() __file__ = __file__()
exec(compile(open(__file__, "rb").read(), __file__, 'exec')) with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
...@@ -39,5 +39,5 @@ def checkHMAC(db, machines): ...@@ -39,5 +39,5 @@ def checkHMAC(db, machines):
if rc: if rc:
print('All nodes use Babel with the correct HMAC configuration') print('All nodes use Babel with the correct HMAC configuration')
else: else:
print(('Expected config: %s' % dict(list(zip(BABEL_HMAC, hmac))))) print('Expected config: %s' % dict(zip(BABEL_HMAC, hmac)))
return rc return rc
...@@ -5,7 +5,7 @@ from . import utils, version, x509 ...@@ -5,7 +5,7 @@ from . import utils, version, x509
class Cache(object): class Cache(object):
def __init__(self, db_path, registry, cert, db_size=200): def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
self._decrypt = cert.decrypt self._decrypt = cert.decrypt
...@@ -89,8 +89,10 @@ class Cache(object): ...@@ -89,8 +89,10 @@ class Cache(object):
logging.info("Getting new network parameters from registry...") logging.info("Getting new network parameters from registry...")
try: try:
# TODO: When possible, the registry should be queried via the re6st. # TODO: When possible, the registry should be queried via the re6st.
network_config = self._registry.getNetworkConfig(self._prefix)
logging.debug('getNetworkConfig result: %r', network_config)
x = json.loads(zlib.decompress( x = json.loads(zlib.decompress(
self._registry.getNetworkConfig(self._prefix))) network_config))
base64_list = x.pop('', ()) base64_list = x.pop('', ())
config = {} config = {}
for k, v in x.items(): for k, v in x.items():
...@@ -134,7 +136,7 @@ class Cache(object): ...@@ -134,7 +136,7 @@ class Cache(object):
((k, memoryview(v) if k in base64_list or ((k, memoryview(v) if k in base64_list or
k.startswith('babel_hmac') else v) k.startswith('babel_hmac') else v)
for k, v in config.items())) for k, v in config.items()))
self._loadConfig(iter(config.items())) self._loadConfig(config.items())
return [k[:-5] if k.endswith(':json') else k return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k for k in chain(remove, (k
for k, v in config.items() for k, v in config.items()
...@@ -229,10 +231,9 @@ class Cache(object): ...@@ -229,10 +231,9 @@ class Cache(object):
" WHERE prefix=peer AND prefix!=? AND try=?" " WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address" def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"): + " ORDER BY RANDOM()"):
#return self._db.execute(__sql, (self._prefix, failed)) return self._db.execute(__sql, (self._prefix, failed))
r = self._db.execute(__sql, (self._prefix, failed))
return r def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
return self._db.execute(__sql, (self._prefix, failed)).next()[0] return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self): def getBootstrapPeer(self):
......
#!/usr/bin/python2 #!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto from OpenSSL import crypto
...@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666): ...@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666):
finally: finally:
os.close(fd) os.close(fd)
def loadCert(pem): def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem) return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main(): def main():
...@@ -91,8 +91,7 @@ def main(): ...@@ -91,8 +91,7 @@ def main():
try: try:
with open(cert_path) as f: with open(cert_path) as f:
cert = loadCert(f.read()) cert = loadCert(f.read())
components = dict(cert.get_subject().get_components()) components = {k.decode(): v for k, v in cert.get_subject().get_components()}
components = {k.decode(): v for k, v in components.items()}
for k in reserved: for k in reserved:
components.pop(k, None) components.pop(k, None)
except IOError as e: except IOError as e:
...@@ -140,7 +139,7 @@ def main(): ...@@ -140,7 +139,7 @@ def main():
req.set_pubkey(pkey) req.set_pubkey(pkey)
req.sign(pkey, 'sha512') req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing, # First make sure we can open certificate file for writing,
# to avoid using our token for nothing. # to avoid using our token for nothing.
...@@ -165,13 +164,13 @@ def main(): ...@@ -165,13 +164,13 @@ def main():
cert = loadCert(cert) cert = loadCert(cert)
not_after = x509.notAfter(cert) not_after = x509.notAfter(cert)
print(("Setup complete. Certificate is valid until %s UTC" print("Setup complete. Certificate is valid until %s UTC"
" and will be automatically renewed after %s UTC.\n" " and will be automatically renewed after %s UTC.\n"
"Do not forget to backup to your private key (%s) or" "Do not forget to backup to your private key (%s) or"
" you will lose your assigned subnet." % ( " you will lose your assigned subnet." % (
time.asctime(time.gmtime(not_after)), time.asctime(time.gmtime(not_after)),
time.asctime(time.gmtime(not_after - registry.RENEW_PERIOD)), time.asctime(time.gmtime(not_after - registry.RENEW_PERIOD)),
key_path))) key_path))
if not os.path.lexists(conf_path): if not os.path.lexists(conf_path):
create(conf_path, ("""\ create(conf_path, ("""\
...@@ -188,13 +187,13 @@ key %s ...@@ -188,13 +187,13 @@ key %s
#O--verb #O--verb
#O3 #O3
""" % (config.registry, ca_path, cert_path, key_path, """ % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \ ('country ' + config.location.split(',', 1)[0])
if config.location else '')).encode()) if config.location else '')).encode())
print("Sample configuration file created.") print("Sample configuration file created.")
cn = x509.subnetFromCert(cert) cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn) subnet = network + utils.binFromSubnet(cn)
print("Your subnet: %s/%u (CN=%s)" \ print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn)) % (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__": if __name__ == "__main__":
......
#!/usr/bin/python2 #!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys import socket, struct, subprocess, sys
from collections import deque from collections import deque
...@@ -256,10 +256,10 @@ def main(): ...@@ -256,10 +256,10 @@ def main():
forwarder.addRule(port, proto) forwarder.addRule(port, proto)
address.append(forwarder.checkExternalIp()) address.append(forwarder.checkExternalIp())
elif 'any' not in ipv4: elif 'any' not in ipv4:
address += list(map(ip_changed, ipv4)) address += map(ip_changed, ipv4)
ipv4_any = () ipv4_any = ()
if ipv6: if ipv6:
address += list(map(ip_changed, ipv6)) address += map(ip_changed, ipv6)
ipv6_any = () ipv6_any = ()
else: else:
ip_changed = remote_gateway = None ip_changed = remote_gateway = None
......
#!/usr/bin/python2 #!/usr/bin/env python3
import http.client, logging, os, socket, sys import http.client, logging, os, socket, sys
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer from socketserver import ThreadingTCPServer
...@@ -29,13 +29,13 @@ class RequestHandler(BaseHTTPRequestHandler): ...@@ -29,13 +29,13 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path path = self.path
query = {} query = {}
else: else:
query = dict(parse_qsl(query, keep_blank_values=1, query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=1)) strict_parsing=True))
_, path = path.split('/') _, path = path.split('/')
if not _: if not _:
return self.server.handle_request(self, path, query) return self.server.handle_request(self, path, query)
except Exception: except Exception:
logging.info(self.requestline, exc_info=1) logging.info(self.requestline, exc_info=True)
self.send_error(http.client.BAD_REQUEST) self.send_error(http.client.BAD_REQUEST)
def log_error(*args): def log_error(*args):
......
...@@ -34,13 +34,13 @@ class Array(object): ...@@ -34,13 +34,13 @@ class Array(object):
def __init__(self, item): def __init__(self, item):
self._item = item self._item = item
def encode(self, buffer, value): def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value)) buffer += uint16.pack(len(value))
encode = self._item.encode encode = self._item.encode
for value in value: for value in value:
encode(buffer, value) encode(buffer, value)
def decode(self, buffer, offset=0): def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = [] r = []
o = offset + 2 o = offset + 2
decode = self._item.decode decode = self._item.decode
...@@ -52,13 +52,13 @@ class Array(object): ...@@ -52,13 +52,13 @@ class Array(object):
class String(object): class String(object):
@staticmethod @staticmethod
def encode(buffer, value): def encode(buffer: bytes, value: str):
buffer += value + b'\x00' buffer += value.encode("utf-8") + b'\x00'
@staticmethod @staticmethod
def decode(buffer, offset=0): def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index(0, offset) i = buffer.index(0, offset)
return i + 1, buffer[offset:i] return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object): class Buffer(object):
...@@ -69,7 +69,7 @@ class Buffer(object): ...@@ -69,7 +69,7 @@ class Buffer(object):
def __iadd__(self, value): def __iadd__(self, value):
self._buf.extend(value) self._buf += value
return self return self
def __len__(self): def __len__(self):
...@@ -195,7 +195,7 @@ class Babel(object): ...@@ -195,7 +195,7 @@ class Babel(object):
logging.debug("Can't connect to %r (%r)", self.socket_path, e) logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e return e
s.send(b'\x01') s.send(b'\x01')
s.setblocking(0) s.setblocking(False)
del self.select del self.select
self.socket = s self.socket = s
return self.select(*args) return self.select(*args)
......
...@@ -38,8 +38,7 @@ class Socket(object): ...@@ -38,8 +38,7 @@ class Socket(object):
self._socket.recv(0) self._socket.recv(0)
return True return True
except socket.error as e: except socket.error as e:
(err, _) = e if e.errno != errno.EAGAIN:
if err != errno.EAGAIN:
raise raise
self._socket.setblocking(1) self._socket.setblocking(1)
return False return False
......
#!/usr/bin/python -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -13,7 +13,5 @@ if script_type == 'up': ...@@ -13,7 +13,5 @@ if script_type == 'up':
if script_type == 'route-up': if script_type == 'route-up':
import time import time
with open('/opt/openvpn_route_up.log', 'w+') as f:
f.write(repr(sys.argv))
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(), os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])).encode()) int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])).encode())
#!/usr/bin/python -S #!/usr/bin/env -S python3 -S
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
...@@ -7,7 +7,7 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6'] ...@@ -7,7 +7,7 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events # Write into pipe connect/disconnect events
fd = int(sys.argv[1]) fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'], os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip)))) int(os.environ['tls_serial_0']), external_ip))).encode("utf-8"))
if script_type == 'client-connect': if script_type == 'client-connect':
if os.read(fd, 1) == b'\x00': if os.read(fd, 1) == b'\x00':
......
import binascii
import logging, errno, os import logging, errno, os
from typing import Optional
from . import utils from . import utils
here = os.path.realpath(os.path.dirname(__file__)) here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server') ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client') ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw): def openvpn(iface, encrypt, *args, **kw):
args = ['openvpn', args = ['openvpn',
...@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny', '-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)] '-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign: if hmac_sign:
def key(cmd, id, value): def key(cmd, id: str, value):
cmd += '-C', ('key type blake2s128 id %s value %s' % cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex'))) (id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign) key(cmd, 'sign', hmac_sign)
default += ' key sign' default += ' key sign'
if hmac_accept is not None: if hmac_accept is not None:
......
...@@ -91,7 +91,7 @@ class RegistryServer(object): ...@@ -91,7 +91,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL", "name TEXT PRIMARY KEY NOT NULL",
"value") "value")
self.prefix = self.getConfig("prefix", None) self.prefix = self.getConfig("prefix", None)
self.version = str(self.getConfig("version", b'\x00')) # BBB: blob self.version = self.getConfig("version", b'\x00')
utils.sqliteCreateTable(self.db, "token", utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL", "token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL", "email TEXT NOT NULL",
...@@ -189,15 +189,15 @@ class RegistryServer(object): ...@@ -189,15 +189,15 @@ class RegistryServer(object):
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
# The following entry lists values that are base64-encoded. # The following entry lists values that are base64-encoded.
kw[''] = 'version', kw[''] = 'version',
kw['version'] = base64.b64encode(self.version) kw['version'] = base64.b64encode(self.version).decode()
self.network_config = kw self.network_config = kw
def increaseVersion(self): def increaseVersion(self):
x = utils.packInteger(1 + utils.unpackInteger(self.version)[0:1]) x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
self.version = x + self.cert.sign(x) self.version = x + self.cert.sign(x)
def sendto(self, prefix, code): def sendto(self, prefix: str, code: int):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT)) self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT))
def recv(self, code): def recv(self, code):
try: try:
...@@ -314,9 +314,11 @@ class RegistryServer(object): ...@@ -314,9 +314,11 @@ class RegistryServer(object):
except HTTPError as e: except HTTPError as e:
return request.send_error(*e.args) return request.send_error(*e.args)
except: except:
logging.warning(request.requestline, exc_info=1) logging.warning(request.requestline, exc_info=True)
return request.send_error(http.client.INTERNAL_SERVER_ERROR) return request.send_error(http.client.INTERNAL_SERVER_ERROR)
if result: if result:
if type(result) is str:
result = result.encode("utf-8")
request.send_response(http.client.OK) request.send_response(http.client.OK)
request.send_header("Content-Length", str(len(result))) request.send_header("Content-Length", str(len(result)))
else: else:
...@@ -432,9 +434,9 @@ class RegistryServer(object): ...@@ -432,9 +434,9 @@ class RegistryServer(object):
prev_prefix = None prev_prefix = None
max_len = 128, max_len = 128,
while True: while True:
max_len = next(q("SELECT max(length(prefix)) FROM cert" max_len = q("SELECT max(length(prefix)) FROM cert"
" WHERE cert is null AND length(prefix) < ?", " WHERE cert is null AND length(prefix) < ?",
max_len)) max_len).fetchone()
if not max_len[0]: if not max_len[0]:
break break
for prefix, in q("SELECT prefix FROM cert" for prefix, in q("SELECT prefix FROM cert"
...@@ -593,8 +595,8 @@ class RegistryServer(object): ...@@ -593,8 +595,8 @@ class RegistryServer(object):
hmac = [self.getConfig(k, None) for k in BABEL_HMAC] hmac = [self.getConfig(k, None) for k in BABEL_HMAC]
for i, v in enumerate(v for v in hmac if v is not None): for i, v in enumerate(v for v in hmac if v is not None):
config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \ config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \
v and base64.b64encode(x509.encrypt(cert, v)) v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config)) return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer): def _queryAddress(self, peer):
self.sendto(peer, 1) self.sendto(peer, 1)
...@@ -800,7 +802,7 @@ class RegistryClient(object): ...@@ -800,7 +802,7 @@ class RegistryClient(object):
_hmac = None _hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform()) user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert=None, auto_close=True): def __init__(self, url, cert: x509.Cert=None, auto_close=True):
self.cert = cert self.cert = cert
self.auto_close = auto_close self.auto_close = auto_close
url_parsed = urlparse(url) url_parsed = urlparse(url)
...@@ -812,12 +814,12 @@ class RegistryClient(object): ...@@ -812,12 +814,12 @@ class RegistryClient(object):
def __getattr__(self, name): def __getattr__(self, name):
getcallargs = getattr(RegistryServer, name).getcallargs getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw): def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw) kw = getcallargs(*args, **kw)
query = '/' + name query = '/' + name
if kw: if kw:
if any(type(v) is not str for v in kw.values()): if any(type(v) is not str for v in kw.values()):
raise TypeError raise TypeError(kw)
query += '?' + urlencode(kw) query += '?' + urlencode(kw)
url = self._path + query url = self._path + query
client_prefix = kw.get('cn') client_prefix = kw.get('cn')
...@@ -862,7 +864,7 @@ class RegistryClient(object): ...@@ -862,7 +864,7 @@ class RegistryClient(object):
except HTTPError: except HTTPError:
raise raise
except Exception: except Exception:
logging.info(url, exc_info=1) logging.info(url, exc_info=True)
else: else:
logging.info('%s\nUnexpected response %s %s', logging.info('%s\nUnexpected response %s %s',
url, response.status, response.reason) url, response.status, response.reason)
......
from pathlib2 import Path from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo" DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
...@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase): ...@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db # read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None) db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone() (email,)).fetchone()
......
...@@ -4,7 +4,7 @@ import nemu ...@@ -4,7 +4,7 @@ import nemu
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -60,7 +60,7 @@ class NetManager(object): ...@@ -60,7 +60,7 @@ class NetManager(object):
Raise: Raise:
AssertionError AssertionError
""" """
for reg, nodes in self.registries.iteritems(): for reg, nodes in self.registries.items():
for node in nodes: for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE) app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
ret = app0.wait() ret = app0.wait()
......
...@@ -6,13 +6,15 @@ import ipaddress ...@@ -6,13 +6,15 @@ import ipaddress
import json import json
import logging import logging
import re import re
import shlex
import shutil import shutil
import sqlite3 import sqlite3
import sys
import tempfile import tempfile
import time import time
import weakref import weakref
from subprocess import PIPE from subprocess import PIPE
from pathlib2 import Path from pathlib import Path
from re6st.tests import tools from re6st.tests import tools
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH ...@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test" WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem" DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node" PYTHON = shlex.quote(sys.executable)
RE6ST_REGISTRY = "python -m re6st.cli.registry" RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_CONF = "python -m re6st.cli.conf" RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial(): def initial():
"""create the workplace""" """create the workplace"""
...@@ -72,7 +75,7 @@ class Re6stRegistry(object): ...@@ -72,7 +75,7 @@ class Re6stRegistry(object):
self.run() self.run()
# wait the servcice started # wait the servcice started
p = self.node.Popen(['python', '-c', """if 1: p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True: while True:
...@@ -115,7 +118,7 @@ class Re6stRegistry(object): ...@@ -115,7 +118,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port] '--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3 #PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split() cmd[:0] = RE6ST_REGISTRY.split()
...@@ -210,7 +213,7 @@ class Re6stNode(object): ...@@ -210,7 +213,7 @@ class Re6stNode(object):
# read token # read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None) db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None token = None
for _ in xrange(100): for _ in range(100):
time.sleep(.1) time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?", token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone() (self.email,)).fetchone()
...@@ -223,7 +226,7 @@ class Re6stNode(object): ...@@ -223,7 +226,7 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0])) out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out)) # logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node # find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0) self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident} data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f: with open(str(self.data_file), 'w') as f:
json.dump(data, f) json.dump(data, f)
...@@ -236,7 +239,7 @@ class Re6stNode(object): ...@@ -236,7 +239,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url, '--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console] '--console', self.console]
#PY3: same as for Re6stRegistry.run #PY3: same as for Re6stRegistry.run
cmd = map(str, cmd) cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split() cmd[:0] = RE6STNET.split()
cmd += args cmd += args
......
"""contain ping-test for re6set net""" """contain ping-test for re6set net"""
import os import os
import sys
import unittest import unittest
import time import time
import psutil import psutil
import logging import logging
import random import random
from pathlib2 import Path from pathlib import Path
import network_build import network_build
import re6st_wrap import re6st_wrap
...@@ -47,12 +48,12 @@ def wait_stable(nodes, timeout=240): ...@@ -47,12 +48,12 @@ def wait_stable(nodes, timeout=240):
for node in nodes: for node in nodes:
sub_ips = set(ips) - {node.ip6} sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen( node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips)) [sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips), env=os.environ)
# check all the node network can ping each other, in order reverse # check all the node network can ping each other, in order reverse
unfinished = list(nodes) unfinished = list(nodes)
while unfinished: while unfinished:
for i in xrange(len(unfinished)-1, -1, -1): for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i] node = unfinished[i]
if node.ping_proc.poll() is not None: if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name) logging.debug("%s 's network is stable", node.name)
......
#!/usr/bin/python2 #!/usr/bin/env python3
""" unit test for re6st-conf """ unit test for re6st-conf
""" """
...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey # mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull) cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":")) cls.fingerprint = "".join( cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert # client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert) cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
......
...@@ -13,12 +13,13 @@ import tempfile ...@@ -13,12 +13,13 @@ import tempfile
from argparse import Namespace from argparse import Namespace
from OpenSSL import crypto from OpenSSL import crypto
from mock import Mock, patch from mock import Mock, patch
from pathlib2 import Path from pathlib import Path
from re6st import registry from re6st import registry
from re6st.tests.tools import * from re6st.tests.tools import *
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer # TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions # getIPV4Information, versions
...@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None): ...@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial += 1 insert_cert.serial += 1
return key, cert return key, cert
insert_cert.serial = 0 insert_cert.serial = 0
...@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase): ...@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase):
def setUp(self): def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \ self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
+ "@mail.com" + "@mail.com"
def test_recv(self): def test_recv(self):
recv = self.server.sock.recv = Mock() side_effect = iter([
recv.side_effect = [
"0001001001001a_msg", "0001001001001a_msg",
"0001001001002\0001dqdq", "0001001001002\0001dqdq",
"0001001001001\000a_msg", "0001001001001\000a_msg",
"0001001001001\000\4a_msg", "0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null "0000000000000\0" # ERROR, IndexError: msg is null
] ])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try: try:
res1 = self.server.recv(4) res1 = self.server.recv(4)
...@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20 now = int(time.time()) - self.config.grace_period + 20
# makeup data # makeup data
insert_cert(cur, self.server.cert, prefix_old, 1) insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1) insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2)) (token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)", cur.execute("INSERT INTO token VALUES (?,?,?,?)",
...@@ -143,7 +154,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -143,7 +154,7 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111" prefix = "0000000011111111"
method = "func" method = "func"
protocol = 7 protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2} params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
del func._private del func._private
func.return_value = result = b"this_is_a_result" func.return_value = result = b"this_is_a_result"
...@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase): ...@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func): def test_handle_request_private(self, func):
"""case request with _private attr""" """case request with _private attr"""
method = "func" method = "func"
params = {"a" : 1, "b" : 2} params = {"a": 1, "b": 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
func.return_value = None func.return_value = None
request_good = Mock() request_good = Mock()
request_good.client_address = self.config.authorized_origin request_good.client_address = self.config.authorized_origin
request_good.headers = {'X-Forwarded-For':self.config.authorized_origin[0]} request_good.headers = {'X-Forwarded-For': self.config.authorized_origin[0]}
request_bad = Mock() request_bad = Mock()
request_bad.client_address = ["wrong_address"] request_bad.client_address = ["wrong_address"]
...@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0 nb_less = 0
for cert in self.server.iterCert(): for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial): if s and int(s) <= serial:
nb_less += 1 nb_less += 1
self.assertEqual(nb_less, serial) self.assertEqual(nb_less, serial)
...@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac() hmacs = get_hmac()
key_1 = hmacs[1] key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, '']) self.assertEqual(hmacs, [None, key_1, b''])
# step 2 # step 2
self.server.updateHMAC() self.server.updateHMAC()
...@@ -402,7 +413,6 @@ class TestRegistryServer(unittest.TestCase): ...@@ -402,7 +413,6 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [key_2, None, None]) self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self): def test_getNodePrefix(self):
# prefix in short format # prefix in short format
prefix = "0000000101" prefix = "0000000101"
...@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase): ...@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16') ('0000000000000001', '2 0/16 6/16')
] ]
recv.side_effect = recv_case recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout): def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call # rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0 side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None] return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1 side_effct.i = len(recv_case) + 1
select.side_effect = side_effct select.side_effect = side_effct
res = self.server.topology() res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \ class CustomDecoder(json.JSONDecoder):
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \ def __init__(self, **kwargs):
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\ json.JSONDecoder.__init__(self, **kwargs)
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}''' self.parse_array = self.JSONArray
self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, s_and_end, scan_once, **kwargs):
values, end = json.decoder.JSONArray(s_and_end, scan_once, **kwargs)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
expect_res = {"36893488147419103232/80": {"0/16", "7/16"},
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"}, "4/16": {"0/16"},
"3/16": {"0/16", "7/16"}, "0/16": {"6/16", "7/16"}, "1/16": {"6/16", "0/16"},
"7/16": {"6/16", "4/16"}}
self.assertEqual(res, expect_res) self.assertEqual(res, expect_res)
......
...@@ -52,9 +52,9 @@ class TestRegistryClient(unittest.TestCase): ...@@ -52,9 +52,9 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb") self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock() self.client.cert = Mock()
key = "this_is_a_key" key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest() h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
# response part # response part
body = b'this is a body' body = b'this is a body'
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True) # @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel): # def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """ # """code is 1, peer and msg not none """
# c = chr(1) # c = b"\x01"
# msg = "address" # msg = "address"
# peer = x509.Peer("000001") # peer = x509.Peer("000001")
# self.tunnel._connecting = {peer} # self.tunnel._connecting = {peer}
...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self): def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none""" """code is 1, for address. And peer or msg are none"""
c = chr(1) c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"} self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c) res = self.tunnel._processPacket(c)
...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ; and each address join by ;
it will truncate address which has more than 3 element it will truncate address which has more than 3 element
""" """
c = chr(1) c = b"\x01"
peer = x509.Peer("000001") peer = x509.Peer("000001")
peer.protocol = 1 peer.protocol = 1
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase): ...@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none """code is 0, for network version, peer is not none
2 case, one modify the version, one not 2 case, one modify the version, one not
""" """
c = chr(0) c = b"\x00"
peer = x509.Peer("000001") peer = x509.Peer("000001")
version1 = "00003" version1 = b"00003"
version2 = "00007" version2 = b"00007"
self.tunnel._version = version3 = "00005" self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer) self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer) res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2 #!/usr/bin/env python3
import os import os
import sys import sys
import unittest import unittest
......
...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return return
crypto.X509Cert in pem format crypto.X509Cert in pem format
""" """
if type(ca) is str: if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca) ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str: if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key) ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
...@@ -56,10 +56,10 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -56,10 +56,10 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial): def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr() pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial) cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f: with open(pkey_file, 'wb') as f:
f.write(pkey.decode()) f.write(pkey)
with open(cert_file, 'w') as f: with open(cert_file, 'wb') as f:
f.write(cert.decode()) f.write(cert)
return pkey, cert return pkey, cert
...@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042): ...@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key) cert.set_pubkey(key)
cert.sign(key, "sha512") cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file: with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file: with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert return key, cert
......
...@@ -249,7 +249,7 @@ class BaseTunnelManager(object): ...@@ -249,7 +249,7 @@ class BaseTunnelManager(object):
self._address = {family: utils.dump_address(address) self._address = {family: utils.dump_address(address)
for family, address in address_dict.items() for family, address in address_dict.items()
if address} if address}
cache.my_address = ';'.join(iter(self._address.values())) cache.my_address = ';'.join(self._address.values())
self.sock = socket.socket(socket.AF_INET6, self.sock = socket.socket(socket.AF_INET6,
socket.SOCK_DGRAM | socket.SOCK_CLOEXEC) socket.SOCK_DGRAM | socket.SOCK_CLOEXEC)
...@@ -329,7 +329,7 @@ class BaseTunnelManager(object): ...@@ -329,7 +329,7 @@ class BaseTunnelManager(object):
def _getPeer(self, prefix): def _getPeer(self, prefix):
return self._peers[bisect(self._peers, prefix) - 1] return self._peers[bisect(self._peers, prefix) - 1]
def sendto(self, prefix, msg): def sendto(self, prefix: str, msg):
to = utils.ipFromBin(self._network + prefix), PORT to = utils.ipFromBin(self._network + prefix), PORT
peer = self._getPeer(prefix) peer = self._getPeer(prefix)
if peer.prefix != prefix: if peer.prefix != prefix:
...@@ -344,6 +344,8 @@ class BaseTunnelManager(object): ...@@ -344,6 +344,8 @@ class BaseTunnelManager(object):
peer.hello0Sent() peer.hello0Sent()
def _sendto(self, to, msg, peer=None): def _sendto(self, to, msg, peer=None):
if type(msg) is str:
msg = msg.encode()
try: try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to) r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error as e: except socket.error as e:
...@@ -360,6 +362,7 @@ class BaseTunnelManager(object): ...@@ -360,6 +362,7 @@ class BaseTunnelManager(object):
if address[0] == '::1': if address[0] == '::1':
try: try:
prefix, msg = msg.split(b'\0', 1) prefix, msg = msg.split(b'\0', 1)
prefix = prefix.decode()
int(prefix, 2) int(prefix, 2)
except ValueError: except ValueError:
return return
...@@ -371,7 +374,7 @@ class BaseTunnelManager(object): ...@@ -371,7 +374,7 @@ class BaseTunnelManager(object):
if msg: if msg:
self._sendto(to, '%s\0%c%s' % (prefix, code, msg)) self._sendto(to, '%s\0%c%s' % (prefix, code, msg))
else: else:
self.sendto(prefix, chr(code | 0x80) + msg[1:]) self.sendto(prefix, bytes([code | 0x80]) + msg[1:])
return return
try: try:
sender = utils.binFromIp(address[0]) sender = utils.binFromIp(address[0])
...@@ -384,7 +387,7 @@ class BaseTunnelManager(object): ...@@ -384,7 +387,7 @@ class BaseTunnelManager(object):
msg = peer.decode(msg) msg = peer.decode(msg)
if type(msg) is tuple: if type(msg) is tuple:
seqno, msg, protocol = msg seqno, msg, protocol = msg
def handleHello(peer, seqno, msg, retry): def handleHello(peer, seqno, msg: bytes, retry):
if seqno == 2: if seqno == 2:
i = len(msg) // 2 i = len(msg) // 2
h = msg[:i] h = msg[:i]
...@@ -394,7 +397,7 @@ class BaseTunnelManager(object): ...@@ -394,7 +397,7 @@ class BaseTunnelManager(object):
except (AttributeError, crypto.Error, x509.NewSessionError, except (AttributeError, crypto.Error, x509.NewSessionError,
subprocess.CalledProcessError): subprocess.CalledProcessError):
logging.debug('ignored new session key from %r', logging.debug('ignored new session key from %r',
address, exc_info=1) address, exc_info=True)
return return
peer.version = self._version \ peer.version = self._version \
if self._sendto(to, b'\0' + self._version, peer) else b'' if self._sendto(to, b'\0' + self._version, peer) else b''
...@@ -469,8 +472,8 @@ class BaseTunnelManager(object): ...@@ -469,8 +472,8 @@ class BaseTunnelManager(object):
# Don't send country to old nodes # Don't send country to old nodes
if self._getPeer(peer).protocol < 7: if self._getPeer(peer).protocol < 7:
return ';'.join(','.join(a.split(',')[:3]) for a in return ';'.join(','.join(a.split(',')[:3]) for a in
';'.join(iter(self._address.values())).split(';')) ';'.join(self._address.values()).split(';'))
return ';'.join(iter(self._address.values())) return ';'.join(self._address.values())
elif not code: # network version elif not code: # network version
if peer: if peer:
try: try:
...@@ -555,7 +558,7 @@ class BaseTunnelManager(object): ...@@ -555,7 +558,7 @@ class BaseTunnelManager(object):
if (not self.NEED_RESTART.isdisjoint(changed) if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting. # TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(iter(serials.values())) or not all(crl.isdisjoint(serials.values())
for serials in self._served.values())): for serials in self._served.values())):
# Wait at least 1 second to broadcast new version to neighbours. # Wait at least 1 second to broadcast new version to neighbours.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart, self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
...@@ -782,7 +785,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -782,7 +785,7 @@ class TunnelManager(BaseTunnelManager):
def _cleanDeads(self): def _cleanDeads(self):
disconnected = False disconnected = False
for prefix in list(self._connection_dict.keys()): for prefix in list(self._connection_dict):
status = self._connection_dict[prefix].refresh() status = self._connection_dict[prefix].refresh()
if status: if status:
disconnected |= status > 0 disconnected |= status > 0
...@@ -989,7 +992,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -989,7 +992,7 @@ class TunnelManager(BaseTunnelManager):
break break
def killAll(self): def killAll(self):
for prefix in list(self._connection_dict.keys()): for prefix in list(self._connection_dict):
self._kill(prefix) self._kill(prefix)
def handleClientEvent(self): def handleClientEvent(self):
...@@ -1012,7 +1015,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -1012,7 +1015,7 @@ class TunnelManager(BaseTunnelManager):
if self.cache.same_country: if self.cache.same_country:
address = self._updateCountry(address) address = self._updateCountry(address)
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(iter(self._address.values())) self.cache.my_address = ';'.join(self._address.values())
def broadcastNewVersion(self): def broadcastNewVersion(self):
self._babel_dump_new_version() self._babel_dump_new_version()
......
...@@ -69,7 +69,7 @@ class Forwarder(object): ...@@ -69,7 +69,7 @@ class Forwarder(object):
try: try:
return self._refresh() return self._refresh()
except UPnPException as e: except UPnPException as e:
logging.debug("UPnP failure", exc_info=1) logging.debug("UPnP failure", exc_info=True)
self.clear() self.clear()
try: try:
self.discover() self.discover()
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback import sys, textwrap, threading, time, traceback
from typing import Optional
# PY3: It will be even better to use Popen(pass_fds=...),
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
#socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1(b'').digest()) HMAC_LEN = len(hashlib.sha1(b'').digest())
...@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler): ...@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler):
finally: finally:
self.lock.release() self.lock.release()
# In the rare case _reopen is set just before the lock was released # In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0): if self._reopen and self.lock.acquire(False):
self.release() self.release()
def async_reopen(self, *_): def async_reopen(self, *_):
self._reopen = True self._reopen = True
if self.lock.acquire(0): if self.lock.acquire(False):
self.release() self.release()
def setupLog(log_level, filename=None, **kw): def setupLog(log_level, filename=None, **kw):
...@@ -150,7 +145,7 @@ class exit(object): ...@@ -150,7 +145,7 @@ class exit(object):
def handler(*args): def handler(*args):
if self.status is None: if self.status is None:
self.status = status self.status = status
if self.acquire(0): if self.acquire(False):
self.release() self.release()
for sig in sigs: for sig in sigs:
signal.signal(sig, handler) signal.signal(sig, handler)
...@@ -179,11 +174,9 @@ class Popen(subprocess.Popen): ...@@ -179,11 +174,9 @@ class Popen(subprocess.Popen):
self.terminate() self.terminate()
t = threading.Timer(5, self.kill) t = threading.Timer(5, self.kill)
t.start() t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel() r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
#r = self.wait()
r = self.waitid(WNOWAIT) # PY3
t.cancel() t.cancel()
self.poll() # PY3 self.poll()
return r return r
...@@ -263,7 +256,7 @@ newHmacSecret = newHmacSecret() ...@@ -263,7 +256,7 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value # - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes # - the 3 first bits code the number of bytes
def packInteger(i): def packInteger(i: int) -> bytes:
for n in range(8): for n in range(8):
x = 32 << 8 * n x = 32 << 8 * n
if i < x: if i < x:
...@@ -271,7 +264,7 @@ def packInteger(i): ...@@ -271,7 +264,7 @@ def packInteger(i):
i -= x i -= x
raise OverflowError raise OverflowError
def unpackInteger(x): def unpackInteger(x: bytes) -> Optional[tuple[int, int]]:
n = x[0] >> 5 n = x[0] >> 5
try: try:
i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1]) i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
......
...@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew: if time.time() < next_renew:
return cert, next_renew return cert, next_renew
try: try:
pem = renew() pem: bytes = renew()
if not pem or pem == crypto.dump_certificate( if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert): crypto.FILETYPE_PEM, cert):
exc_info = 0 exc_info = 0
...@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1 exc_info = 1
break break
new_path = path + '.new' new_path = path + '.new'
with open(new_path, 'w') as f: with open(new_path, 'wb') as f:
f.write(pem) f.write(pem)
try: try:
s = os.stat(path) s = os.stat(path)
...@@ -90,9 +90,9 @@ class Cert(object): ...@@ -90,9 +90,9 @@ class Cert(object):
self.ca_path = ca self.ca_path = ca
self.cert_path = cert self.cert_path = cert
self.key_path = key self.key_path = key
with open(ca) as f: with open(ca, "rb") as f:
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(key) as f: with open(key, "rb") as f:
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read()) self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
if cert: if cert:
with open(cert) as f: with open(cert) as f:
...@@ -143,22 +143,21 @@ class Cert(object): ...@@ -143,22 +143,21 @@ class Cert(object):
"error running openssl, assuming cert is invalid") "error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed # BBB: With old versions of openssl, detailed
# error is printed to standard output. # error is printed to standard output.
out, err = out.decode(), err.decode() for stream in err, out:
for err in err, out: for x in stream.decode(errors='replace').splitlines():
for x in err.splitlines():
if x.startswith('error '): if x.startswith('error '):
x, msg = x.split(':', 1) x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4) _, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip()) raise VerifyError(int(code), int(depth), msg.strip())
return r return r
def verify(self, sign, data): def verify(self, sign: bytes, data):
crypto.verify(self.ca, sign, data, 'sha512') crypto.verify(self.ca, sign, data, 'sha512')
def sign(self, data): def sign(self, data) -> bytes:
return crypto.sign(self.key, data, 'sha512') return crypto.sign(self.key, data, 'sha512')
def decrypt(self, data): def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path) p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data) out, err = p.communicate(data)
if p.returncode: if p.returncode:
...@@ -209,7 +208,7 @@ class Peer(object): ...@@ -209,7 +208,7 @@ class Peer(object):
stop_date = float('inf') stop_date = float('inf')
version = b'' version = b''
def __init__(self, prefix): def __init__(self, prefix: str):
self.prefix = prefix self.prefix = prefix
@property @property
...@@ -253,7 +252,7 @@ class Peer(object): ...@@ -253,7 +252,7 @@ class Peer(object):
def _hmac(self, msg): def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest() return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol): def newSession(self, key: bytes, protocol):
if key <= self._key: if key <= self._key:
raise NewSessionError(self._key, key) raise NewSessionError(self._key, key)
self._key = key self._key = key
...@@ -266,7 +265,7 @@ class Peer(object): ...@@ -266,7 +265,7 @@ class Peer(object):
seqno_struct = struct.Struct("!L") seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack): def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str:
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
msg = msg[4:] msg = msg[4:]
...@@ -280,10 +279,12 @@ class Peer(object): ...@@ -280,10 +279,12 @@ class Peer(object):
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
self._i = seqno self._i = seqno
return msg[4:i] return msg[4:i].decode()
def encode(self, msg, _pack=seqno_struct.pack): def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1 self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg msg = _pack(self._j) + msg
return msg + self._hmac(msg) return msg + self._hmac(msg)
......
...@@ -15,7 +15,7 @@ def copy_file(self, infile, outfile, *args, **kw): ...@@ -15,7 +15,7 @@ def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]: if infile == version["__file__"]:
if not self.dry_run: if not self.dry_run:
log.info("generating %s -> %s", infile, outfile) log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f: with open(outfile, "w", encoding="utf-8") as f:
for x in sorted(version.items()): for x in sorted(version.items()):
if not x[0].startswith("_"): if not x[0].startswith("_"):
f.write("%s = %r\n" % x) f.write("%s = %r\n" % x)
...@@ -23,7 +23,7 @@ def copy_file(self, infile, outfile, *args, **kw): ...@@ -23,7 +23,7 @@ def copy_file(self, infile, outfile, *args, **kw):
elif isinstance(self, build_py) and \ elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC: os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile): if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0) return outfile, 0
# Adjust interpreter of OpenVPN hooks. # Adjust interpreter of OpenVPN hooks.
with open(infile) as src: with open(infile) as src:
first_line = src.readline() first_line = src.readline()
...@@ -33,11 +33,8 @@ def copy_file(self, infile, outfile, *args, **kw): ...@@ -33,11 +33,8 @@ def copy_file(self, infile, outfile, *args, **kw):
executable = self.distribution.command_obj['build'].executable executable = self.distribution.command_obj['build'].executable
patched = "#!%s%s\n" % (executable, m.group(1) or '') patched = "#!%s%s\n" % (executable, m.group(1) or '')
patched += src.read() patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC) with open(outfile, "w") as dst:
try: dst.write(patched)
os.write(dst, patched)
finally:
os.close(dst)
return outfile, 1 return outfile, 1
cls, = self.__class__.__bases__ cls, = self.__class__.__bases__
return cls.copy_file(self, infile, outfile, *args, **kw) return cls.copy_file(self, infile, outfile, *args, **kw)
...@@ -97,7 +94,7 @@ setup( ...@@ -97,7 +94,7 @@ setup(
extras_require = { extras_require = {
'geoip': ['geoip2'], 'geoip': ['geoip2'],
'multicast': ['PyYAML'], 'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping'] 'test': ['mock', 'nemu3', 'unshare', 'multiping']
}, },
#dependency_links = [ #dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7", # "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment