Commit ea5e6290 authored by Tom Niget's avatar Tom Niget

py2to3: continue fixing py2 to py3 issues

parent 27d78ea4
This diff is collapsed.
......@@ -18,10 +18,10 @@
import re
import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface
from types import FunctionType
def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
......@@ -65,7 +65,7 @@ def __init__(self, *args, **kw):
self.name = self.name.split('@',1)[0]
interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.__code__,
get_addr_data.orig = FunctionType(get_addr_data.__code__,
get_addr_data.__globals__)
def _get_addr_data():
byidx, bynam = get_addr_data.orig()
......
......@@ -64,7 +64,7 @@ class Ping(Thread):
os.utime(csv_path, (time.time(), time.time()))
for add in no_responses:
print(('No response from %s with seq no %d' % (add, seq)))
print('No response from %s with seq no %d' % (add, seq))
parser = argparse.ArgumentParser()
parser.add_argument('n', help = 'my machine name (m1,m2...)')
......
......@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__()
exec(compile(open(__file__, "rb").read(), __file__, 'exec'))
with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
......@@ -39,5 +39,5 @@ def checkHMAC(db, machines):
if rc:
print('All nodes use Babel with the correct HMAC configuration')
else:
print(('Expected config: %s' % dict(list(zip(BABEL_HMAC, hmac)))))
print('Expected config: %s' % dict(zip(BABEL_HMAC, hmac)))
return rc
......@@ -5,7 +5,7 @@ from . import utils, version, x509
class Cache(object):
def __init__(self, db_path, registry, cert, db_size=200):
def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
self._decrypt = cert.decrypt
......@@ -89,8 +89,10 @@ class Cache(object):
logging.info("Getting new network parameters from registry...")
try:
# TODO: When possible, the registry should be queried via the re6st.
network_config = self._registry.getNetworkConfig(self._prefix)
logging.debug('getNetworkConfig result: %r', network_config)
x = json.loads(zlib.decompress(
self._registry.getNetworkConfig(self._prefix)))
network_config))
base64_list = x.pop('', ())
config = {}
for k, v in x.items():
......@@ -134,7 +136,7 @@ class Cache(object):
((k, memoryview(v) if k in base64_list or
k.startswith('babel_hmac') else v)
for k, v in config.items()))
self._loadConfig(iter(config.items()))
self._loadConfig(config.items())
return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k
for k, v in config.items()
......@@ -229,10 +231,9 @@ class Cache(object):
" WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"):
#return self._db.execute(__sql, (self._prefix, failed))
r = self._db.execute(__sql, (self._prefix, failed))
return r
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self):
......
#!/usr/bin/python2
#!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto
......@@ -13,7 +13,7 @@ def create(path, text=None, mode=0o666):
finally:
os.close(fd)
def loadCert(pem):
def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main():
......@@ -91,8 +91,7 @@ def main():
try:
with open(cert_path) as f:
cert = loadCert(f.read())
components = dict(cert.get_subject().get_components())
components = {k.decode(): v for k, v in components.items()}
components = {k.decode(): v for k, v in cert.get_subject().get_components()}
for k in reserved:
components.pop(k, None)
except IOError as e:
......@@ -140,7 +139,7 @@ def main():
req.set_pubkey(pkey)
req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req)
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing,
# to avoid using our token for nothing.
......@@ -165,13 +164,13 @@ def main():
cert = loadCert(cert)
not_after = x509.notAfter(cert)
print(("Setup complete. Certificate is valid until %s UTC"
print("Setup complete. Certificate is valid until %s UTC"
" and will be automatically renewed after %s UTC.\n"
"Do not forget to backup to your private key (%s) or"
" you will lose your assigned subnet." % (
time.asctime(time.gmtime(not_after)),
time.asctime(time.gmtime(not_after - registry.RENEW_PERIOD)),
key_path)))
key_path))
if not os.path.lexists(conf_path):
create(conf_path, ("""\
......@@ -188,13 +187,13 @@ key %s
#O--verb
#O3
""" % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \
('country ' + config.location.split(',', 1)[0])
if config.location else '')).encode())
print("Sample configuration file created.")
cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn)
print("Your subnet: %s/%u (CN=%s)" \
print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__":
......
#!/usr/bin/python2
#!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys
from collections import deque
......@@ -256,10 +256,10 @@ def main():
forwarder.addRule(port, proto)
address.append(forwarder.checkExternalIp())
elif 'any' not in ipv4:
address += list(map(ip_changed, ipv4))
address += map(ip_changed, ipv4)
ipv4_any = ()
if ipv6:
address += list(map(ip_changed, ipv6))
address += map(ip_changed, ipv6)
ipv6_any = ()
else:
ip_changed = remote_gateway = None
......
#!/usr/bin/python2
#!/usr/bin/env python3
import http.client, logging, os, socket, sys
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
......@@ -29,13 +29,13 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path
query = {}
else:
query = dict(parse_qsl(query, keep_blank_values=1,
strict_parsing=1))
query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=True))
_, path = path.split('/')
if not _:
return self.server.handle_request(self, path, query)
except Exception:
logging.info(self.requestline, exc_info=1)
logging.info(self.requestline, exc_info=True)
self.send_error(http.client.BAD_REQUEST)
def log_error(*args):
......
......@@ -34,13 +34,13 @@ class Array(object):
def __init__(self, item):
self._item = item
def encode(self, buffer, value):
def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value))
encode = self._item.encode
for value in value:
encode(buffer, value)
def decode(self, buffer, offset=0):
def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = []
o = offset + 2
decode = self._item.decode
......@@ -52,13 +52,13 @@ class Array(object):
class String(object):
@staticmethod
def encode(buffer, value):
buffer += value + b'\x00'
def encode(buffer: bytes, value: str):
buffer += value.encode("utf-8") + b'\x00'
@staticmethod
def decode(buffer, offset=0):
def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index(0, offset)
return i + 1, buffer[offset:i]
return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object):
......@@ -69,7 +69,7 @@ class Buffer(object):
def __iadd__(self, value):
self._buf.extend(value)
self._buf += value
return self
def __len__(self):
......@@ -195,7 +195,7 @@ class Babel(object):
logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e
s.send(b'\x01')
s.setblocking(0)
s.setblocking(False)
del self.select
self.socket = s
return self.select(*args)
......
......@@ -38,8 +38,7 @@ class Socket(object):
self._socket.recv(0)
return True
except socket.error as e:
(err, _) = e
if err != errno.EAGAIN:
if e.errno != errno.EAGAIN:
raise
self._socket.setblocking(1)
return False
......
#!/usr/bin/python -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -13,7 +13,5 @@ if script_type == 'up':
if script_type == 'route-up':
import time
with open('/opt/openvpn_route_up.log', 'w+') as f:
f.write(repr(sys.argv))
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])).encode())
#!/usr/bin/python -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -7,7 +7,7 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip))))
int(os.environ['tls_serial_0']), external_ip))).encode("utf-8"))
if script_type == 'client-connect':
if os.read(fd, 1) == b'\x00':
......
import binascii
import logging, errno, os
from typing import Optional
from . import utils
here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None
ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw):
args = ['openvpn',
......@@ -80,9 +82,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign:
def key(cmd, id, value):
def key(cmd, id: str, value):
cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex')))
(id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign)
default += ' key sign'
if hmac_accept is not None:
......
......@@ -91,7 +91,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL",
"value")
self.prefix = self.getConfig("prefix", None)
self.version = str(self.getConfig("version", b'\x00')) # BBB: blob
self.version = self.getConfig("version", b'\x00')
utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL",
......@@ -189,15 +189,15 @@ class RegistryServer(object):
self.sendto(self.prefix, 0)
# The following entry lists values that are base64-encoded.
kw[''] = 'version',
kw['version'] = base64.b64encode(self.version)
kw['version'] = base64.b64encode(self.version).decode()
self.network_config = kw
def increaseVersion(self):
x = utils.packInteger(1 + utils.unpackInteger(self.version)[0:1])
x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
self.version = x + self.cert.sign(x)
def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))
def sendto(self, prefix: str, code: int):
self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT))
def recv(self, code):
try:
......@@ -314,9 +314,11 @@ class RegistryServer(object):
except HTTPError as e:
return request.send_error(*e.args)
except:
logging.warning(request.requestline, exc_info=1)
logging.warning(request.requestline, exc_info=True)
return request.send_error(http.client.INTERNAL_SERVER_ERROR)
if result:
if type(result) is str:
result = result.encode("utf-8")
request.send_response(http.client.OK)
request.send_header("Content-Length", str(len(result)))
else:
......@@ -432,9 +434,9 @@ class RegistryServer(object):
prev_prefix = None
max_len = 128,
while True:
max_len = next(q("SELECT max(length(prefix)) FROM cert"
max_len = q("SELECT max(length(prefix)) FROM cert"
" WHERE cert is null AND length(prefix) < ?",
max_len))
max_len).fetchone()
if not max_len[0]:
break
for prefix, in q("SELECT prefix FROM cert"
......@@ -593,8 +595,8 @@ class RegistryServer(object):
hmac = [self.getConfig(k, None) for k in BABEL_HMAC]
for i, v in enumerate(v for v in hmac if v is not None):
config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \
v and base64.b64encode(x509.encrypt(cert, v))
return zlib.compress(json.dumps(config))
v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer):
self.sendto(peer, 1)
......@@ -800,7 +802,7 @@ class RegistryClient(object):
_hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert=None, auto_close=True):
def __init__(self, url, cert: x509.Cert=None, auto_close=True):
self.cert = cert
self.auto_close = auto_close
url_parsed = urlparse(url)
......@@ -812,12 +814,12 @@ class RegistryClient(object):
def __getattr__(self, name):
getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw):
def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw)
query = '/' + name
if kw:
if any(type(v) is not str for v in kw.values()):
raise TypeError
raise TypeError(kw)
query += '?' + urlencode(kw)
url = self._path + query
client_prefix = kw.get('cn')
......@@ -862,7 +864,7 @@ class RegistryClient(object):
except HTTPError:
raise
except Exception:
logging.info(url, exc_info=1)
logging.info(url, exc_info=True)
else:
logging.info('%s\nUnexpected response %s %s',
url, response.status, response.reason)
......
from pathlib2 import Path
from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
......@@ -60,7 +60,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone()
......
......@@ -4,7 +4,7 @@ import nemu
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from pathlib import Path
from re6st.tests import DEMO_PATH
......@@ -60,7 +60,7 @@ class NetManager(object):
Raise:
AssertionError
"""
for reg, nodes in self.registries.iteritems():
for reg, nodes in self.registries.items():
for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
ret = app0.wait()
......
......@@ -6,13 +6,15 @@ import ipaddress
import json
import logging
import re
import shlex
import shutil
import sqlite3
import sys
import tempfile
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from pathlib import Path
from re6st.tests import tools
from re6st.tests import DEMO_PATH
......@@ -20,9 +22,10 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node"
RE6ST_REGISTRY = "python -m re6st.cli.registry"
RE6ST_CONF = "python -m re6st.cli.conf"
PYTHON = shlex.quote(sys.executable)
RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial():
"""create the workplace"""
......@@ -72,7 +75,7 @@ class Re6stRegistry(object):
self.run()
# wait the servcice started
p = self.node.Popen(['python', '-c', """if 1:
p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
......@@ -115,7 +118,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split()
......@@ -210,7 +213,7 @@ class Re6stNode(object):
# read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone()
......@@ -223,7 +226,7 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0)
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f:
json.dump(data, f)
......@@ -236,7 +239,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console]
#PY3: same as for Re6stRegistry.run
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split()
cmd += args
......
"""contain ping-test for re6set net"""
import os
import sys
import unittest
import time
import psutil
import logging
import random
from pathlib2 import Path
from pathlib import Path
import network_build
import re6st_wrap
......@@ -47,12 +48,12 @@ def wait_stable(nodes, timeout=240):
for node in nodes:
sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips))
[sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips), env=os.environ)
# check all the node network can ping each other, in order reverse
unfinished = list(nodes)
while unfinished:
for i in xrange(len(unfinished)-1, -1, -1):
for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i]
if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name)
......
#!/usr/bin/python2
#!/usr/bin/env python3
""" unit test for re6st-conf
"""
......@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":"))
cls.fingerprint = "".join( cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
......
......@@ -13,12 +13,13 @@ import tempfile
from argparse import Namespace
from OpenSSL import crypto
from mock import Mock, patch
from pathlib2 import Path
from pathlib import Path
from re6st import registry
from re6st.tests.tools import *
from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
......@@ -49,6 +50,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial += 1
return key, cert
insert_cert.serial = 0
......@@ -77,17 +79,26 @@ class TestRegistryServer(unittest.TestCase):
def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
+ "@mail.com"
+ "@mail.com"
def test_recv(self):
recv = self.server.sock.recv = Mock()
recv.side_effect = [
side_effect = iter([
"0001001001001a_msg",
"0001001001002\0001dqdq",
"0001001001001\000a_msg",
"0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null
]
])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try:
res1 = self.server.recv(4)
......@@ -115,7 +126,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20
# makeup data
insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1)
insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
......@@ -143,7 +154,7 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111"
method = "func"
protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2}
params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params
del func._private
func.return_value = result = b"this_is_a_result"
......@@ -176,12 +187,12 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func):
"""case request with _private attr"""
method = "func"
params = {"a" : 1, "b" : 2}
params = {"a": 1, "b": 2}
func.getcallargs.return_value = params
func.return_value = None
request_good = Mock()
request_good.client_address = self.config.authorized_origin
request_good.headers = {'X-Forwarded-For':self.config.authorized_origin[0]}
request_good.headers = {'X-Forwarded-For': self.config.authorized_origin[0]}
request_bad = Mock()
request_bad.client_address = ["wrong_address"]
......@@ -282,7 +293,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0
for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial):
if s and int(s) <= serial:
nb_less += 1
self.assertEqual(nb_less, serial)
......@@ -378,7 +389,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac()
key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, ''])
self.assertEqual(hmacs, [None, key_1, b''])
# step 2
self.server.updateHMAC()
......@@ -402,7 +413,6 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self):
# prefix in short format
prefix = "0000000101"
......@@ -426,19 +436,33 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16')
]
recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1
select.side_effect = side_effct
res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
class CustomDecoder(json.JSONDecoder):
def __init__(self, **kwargs):
json.JSONDecoder.__init__(self, **kwargs)
self.parse_array = self.JSONArray
self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, s_and_end, scan_once, **kwargs):
values, end = json.decoder.JSONArray(s_and_end, scan_once, **kwargs)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
expect_res = {"36893488147419103232/80": {"0/16", "7/16"},
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"}, "4/16": {"0/16"},
"3/16": {"0/16", "7/16"}, "0/16": {"6/16", "7/16"}, "1/16": {"6/16", "0/16"},
"7/16": {"6/16", "4/16"}}
self.assertEqual(res, expect_res)
......
......@@ -52,9 +52,9 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock()
key = "this_is_a_key"
key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest()
h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
# response part
body = b'this is a body'
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """
# c = chr(1)
# c = b"\x01"
# msg = "address"
# peer = x509.Peer("000001")
# self.tunnel._connecting = {peer}
......@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none"""
c = chr(1)
c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c)
......@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ;
it will truncate address which has more than 3 element
"""
c = chr(1)
c = b"\x01"
peer = x509.Peer("000001")
peer.protocol = 1
self.tunnel._peers.append(peer)
......@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none
2 case, one modify the version, one not
"""
c = chr(0)
c = b"\x00"
peer = x509.Peer("000001")
version1 = "00003"
version2 = "00007"
self.tunnel._version = version3 = "00005"
version1 = b"00003"
version2 = b"00007"
self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......
......@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return
crypto.X509Cert in pem format
"""
if type(ca) is str:
if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str:
if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
......@@ -56,10 +56,10 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f:
f.write(pkey.decode())
with open(cert_file, 'w') as f:
f.write(cert.decode())
with open(pkey_file, 'wb') as f:
f.write(pkey)
with open(cert_file, 'wb') as f:
f.write(cert)
return pkey, cert
......@@ -84,9 +84,9 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key)
cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file:
with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file:
with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert
......
......@@ -249,7 +249,7 @@ class BaseTunnelManager(object):
self._address = {family: utils.dump_address(address)
for family, address in address_dict.items()
if address}
cache.my_address = ';'.join(iter(self._address.values()))
cache.my_address = ';'.join(self._address.values())
self.sock = socket.socket(socket.AF_INET6,
socket.SOCK_DGRAM | socket.SOCK_CLOEXEC)
......@@ -329,7 +329,7 @@ class BaseTunnelManager(object):
def _getPeer(self, prefix):
return self._peers[bisect(self._peers, prefix) - 1]
def sendto(self, prefix, msg):
def sendto(self, prefix: str, msg):
to = utils.ipFromBin(self._network + prefix), PORT
peer = self._getPeer(prefix)
if peer.prefix != prefix:
......@@ -344,6 +344,8 @@ class BaseTunnelManager(object):
peer.hello0Sent()
def _sendto(self, to, msg, peer=None):
if type(msg) is str:
msg = msg.encode()
try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error as e:
......@@ -360,6 +362,7 @@ class BaseTunnelManager(object):
if address[0] == '::1':
try:
prefix, msg = msg.split(b'\0', 1)
prefix = prefix.decode()
int(prefix, 2)
except ValueError:
return
......@@ -371,7 +374,7 @@ class BaseTunnelManager(object):
if msg:
self._sendto(to, '%s\0%c%s' % (prefix, code, msg))
else:
self.sendto(prefix, chr(code | 0x80) + msg[1:])
self.sendto(prefix, bytes([code | 0x80]) + msg[1:])
return
try:
sender = utils.binFromIp(address[0])
......@@ -384,7 +387,7 @@ class BaseTunnelManager(object):
msg = peer.decode(msg)
if type(msg) is tuple:
seqno, msg, protocol = msg
def handleHello(peer, seqno, msg, retry):
def handleHello(peer, seqno, msg: bytes, retry):
if seqno == 2:
i = len(msg) // 2
h = msg[:i]
......@@ -394,7 +397,7 @@ class BaseTunnelManager(object):
except (AttributeError, crypto.Error, x509.NewSessionError,
subprocess.CalledProcessError):
logging.debug('ignored new session key from %r',
address, exc_info=1)
address, exc_info=True)
return
peer.version = self._version \
if self._sendto(to, b'\0' + self._version, peer) else b''
......@@ -469,8 +472,8 @@ class BaseTunnelManager(object):
# Don't send country to old nodes
if self._getPeer(peer).protocol < 7:
return ';'.join(','.join(a.split(',')[:3]) for a in
';'.join(iter(self._address.values())).split(';'))
return ';'.join(iter(self._address.values()))
';'.join(self._address.values()).split(';'))
return ';'.join(self._address.values())
elif not code: # network version
if peer:
try:
......@@ -555,7 +558,7 @@ class BaseTunnelManager(object):
if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(iter(serials.values()))
or not all(crl.isdisjoint(serials.values())
for serials in self._served.values())):
# Wait at least 1 second to broadcast new version to neighbours.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
......@@ -782,7 +785,7 @@ class TunnelManager(BaseTunnelManager):
def _cleanDeads(self):
disconnected = False
for prefix in list(self._connection_dict.keys()):
for prefix in list(self._connection_dict):
status = self._connection_dict[prefix].refresh()
if status:
disconnected |= status > 0
......@@ -989,7 +992,7 @@ class TunnelManager(BaseTunnelManager):
break
def killAll(self):
for prefix in list(self._connection_dict.keys()):
for prefix in list(self._connection_dict):
self._kill(prefix)
def handleClientEvent(self):
......@@ -1012,7 +1015,7 @@ class TunnelManager(BaseTunnelManager):
if self.cache.same_country:
address = self._updateCountry(address)
self._address[family] = utils.dump_address(address)
self.cache.my_address = ';'.join(iter(self._address.values()))
self.cache.my_address = ';'.join(self._address.values())
def broadcastNewVersion(self):
self._babel_dump_new_version()
......
......@@ -69,7 +69,7 @@ class Forwarder(object):
try:
return self._refresh()
except UPnPException as e:
logging.debug("UPnP failure", exc_info=1)
logging.debug("UPnP failure", exc_info=True)
self.clear()
try:
self.discover()
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback
# PY3: It will be even better to use Popen(pass_fds=...),
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
#socket.SOCK_CLOEXEC = 0x80000
from typing import Optional
HMAC_LEN = len(hashlib.sha1(b'').digest())
......@@ -37,12 +32,12 @@ class FileHandler(logging.FileHandler):
finally:
self.lock.release()
# In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0):
if self._reopen and self.lock.acquire(False):
self.release()
def async_reopen(self, *_):
self._reopen = True
if self.lock.acquire(0):
if self.lock.acquire(False):
self.release()
def setupLog(log_level, filename=None, **kw):
......@@ -150,7 +145,7 @@ class exit(object):
def handler(*args):
if self.status is None:
self.status = status
if self.acquire(0):
if self.acquire(False):
self.release()
for sig in sigs:
signal.signal(sig, handler)
......@@ -179,11 +174,9 @@ class Popen(subprocess.Popen):
self.terminate()
t = threading.Timer(5, self.kill)
t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel()
#r = self.wait()
r = self.waitid(WNOWAIT) # PY3
r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
t.cancel()
self.poll() # PY3
self.poll()
return r
......@@ -263,7 +256,7 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def packInteger(i):
def packInteger(i: int) -> bytes:
for n in range(8):
x = 32 << 8 * n
if i < x:
......@@ -271,7 +264,7 @@ def packInteger(i):
i -= x
raise OverflowError
def unpackInteger(x):
def unpackInteger(x: bytes) -> Optional[tuple[int, int]]:
n = x[0] >> 5
try:
i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
......
......@@ -52,7 +52,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew:
return cert, next_renew
try:
pem = renew()
pem: bytes = renew()
if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert):
exc_info = 0
......@@ -62,7 +62,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1
break
new_path = path + '.new'
with open(new_path, 'w') as f:
with open(new_path, 'wb') as f:
f.write(pem)
try:
s = os.stat(path)
......@@ -90,9 +90,9 @@ class Cert(object):
self.ca_path = ca
self.cert_path = cert
self.key_path = key
with open(ca) as f:
with open(ca, "rb") as f:
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(key) as f:
with open(key, "rb") as f:
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
if cert:
with open(cert) as f:
......@@ -143,22 +143,21 @@ class Cert(object):
"error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed
# error is printed to standard output.
out, err = out.decode(), err.decode()
for err in err, out:
for x in err.splitlines():
for stream in err, out:
for x in stream.decode(errors='replace').splitlines():
if x.startswith('error '):
x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip())
return r
def verify(self, sign, data):
def verify(self, sign: bytes, data):
crypto.verify(self.ca, sign, data, 'sha512')
def sign(self, data):
def sign(self, data) -> bytes:
return crypto.sign(self.key, data, 'sha512')
def decrypt(self, data):
def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data)
if p.returncode:
......@@ -209,7 +208,7 @@ class Peer(object):
stop_date = float('inf')
version = b''
def __init__(self, prefix):
def __init__(self, prefix: str):
self.prefix = prefix
@property
......@@ -253,7 +252,7 @@ class Peer(object):
def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol):
def newSession(self, key: bytes, protocol):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
......@@ -266,7 +265,7 @@ class Peer(object):
seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack):
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str:
seqno, = _unpack(msg[:4])
if seqno <= 2:
msg = msg[4:]
......@@ -280,10 +279,12 @@ class Peer(object):
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None
self._i = seqno
return msg[4:i]
return msg[4:i].decode()
def encode(self, msg, _pack=seqno_struct.pack):
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg
return msg + self._hmac(msg)
......
......@@ -15,7 +15,7 @@ def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]:
if not self.dry_run:
log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f:
with open(outfile, "w", encoding="utf-8") as f:
for x in sorted(version.items()):
if not x[0].startswith("_"):
f.write("%s = %r\n" % x)
......@@ -23,7 +23,7 @@ def copy_file(self, infile, outfile, *args, **kw):
elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0)
return outfile, 0
# Adjust interpreter of OpenVPN hooks.
with open(infile) as src:
first_line = src.readline()
......@@ -33,11 +33,8 @@ def copy_file(self, infile, outfile, *args, **kw):
executable = self.distribution.command_obj['build'].executable
patched = "#!%s%s\n" % (executable, m.group(1) or '')
patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
try:
os.write(dst, patched)
finally:
os.close(dst)
with open(outfile, "w") as dst:
dst.write(patched)
return outfile, 1
cls, = self.__class__.__bases__
return cls.copy_file(self, infile, outfile, *args, **kw)
......@@ -97,7 +94,7 @@ setup(
extras_require = {
'geoip': ['geoip2'],
'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping']
'test': ['mock', 'nemu3', 'unshare', 'multiping']
},
#dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment