Commit e80ca878 authored by Thomas Gambier's avatar Thomas Gambier 🚴🏼

Port re6stnet to python3

See merge request !46
parents f2fd7247 227d63d4
......@@ -5,3 +5,11 @@
/build/
/dist/
/re6stnet.egg-info/
*.log
*.pid
*.db
*.state
*.crt
*.pem
demo/mbox
*.sock
......@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
Requirements
============
- Python 2.7
- Python 3.11
- OpenSSL binary and development libraries
- OpenVPN 2.4.*
- Babel_ (with Nexedi patches)
......
Demo
====
Usage
-----
To run the demo, make sure all the dependencies are installed
and run ``./demo 8000`` (or any port).
Troubleshooting
---------------
If the demo crashes and fails to clean up its resources properly,
run the following commands::
for b in $(sudo ip l | grep -Po 'NETNS\w\w[\d\-a-f]+'); do sudo ip l del $b; done
pkill screen
killall python
killall python3
find . -name '*.crt' -delete; find . -name '*.db' -delete; find . -name '*.log' -delete
.. warning::
This will kill all Python processes. These commands assume you're running
the demo on a dedicated machine with nothing else on it.
This diff is collapsed.
# -*- coding: utf-8 -*-
# Copyright 2010, 2011 INRIA
# Copyright 2011 Martín Ferrari <martin.ferrari@gmail.com>
#
# This file is contains patches to Nemu.
#
# Nemu is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License version 2, as published by the Free
# Software Foundation.
#
# Nemu is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>.
import re
import os
from new import function
from nemu.iproute import backticks, get_if_data, route, \
get_addr_data, get_all_route_data, interface
from nemu.interface import Switch, Interface
def _get_all_route_data():
ipdata = backticks([IP_PATH, "-o", "route", "list"]) # "table", "all"
ipdata += backticks([IP_PATH, "-o", "-f", "inet6", "route", "list"])
ifdata = get_if_data()[1]
ret = []
for line in ipdata.split("\n"):
if line == "":
continue
# PATCH: parse 'from'
# PATCH: 'dev' is missing on 'unreachable' ipv4 routes
match = re.match('(?:(unicast|local|broadcast|multicast|throw|'
r'unreachable|prohibit|blackhole|nat) )?(\S+)(?: from (\S+))?'
r'(?: via (\S+))?(?: dev (\S+))?.*(?: metric (\d+))?', line)
if not match:
raise RuntimeError("Invalid output from `ip route': `%s'" % line)
tipe = match.group(1) or "unicast"
prefix = match.group(2)
#src = match.group(3)
nexthop = match.group(4)
interface = ifdata[match.group(5) or "lo"]
metric = match.group(6)
if prefix == "default" or re.search(r'/0$', prefix):
prefix = None
prefix_len = 0
else:
match = re.match(r'([0-9a-f:.]+)(?:/(\d+))?$', prefix)
prefix = match.group(1)
prefix_len = int(match.group(2) or 32)
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
metric))
return ret
get_all_route_data.func_code = _get_all_route_data.func_code
interface__init__ = interface.__init__
def __init__(self, *args, **kw):
interface__init__(self, *args, **kw)
if self.name:
self.name = self.name.split('@',1)[0]
interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.func_code,
get_addr_data.func_globals)
def _get_addr_data():
byidx, bynam = get_addr_data.orig()
return byidx, {name.split('@',1)[0]: a for name, a in bynam.iteritems()}
get_addr_data.func_code = _get_addr_data.func_code
@staticmethod
def _gen_if_name():
n = Interface._gen_next_id()
# Max 15 chars
# XXX: We truncate pid to not exceed IFNAMSIZ on systems with 32-bits pids
# but we should find something better to avoid possible collision.
return "NETNSif-%.4x%.3x" % (os.getpid() % 0xffff, n)
Interface._gen_if_name = _gen_if_name
@staticmethod
def _gen_br_name():
n = Switch._gen_next_id()
# XXX: same as for _gen_if_name
return "NETNSbr-%.4x%.3x" % (os.getpid() % 0xffff, n)
Switch._gen_br_name = _gen_br_name
#!/usr/bin/env python
#!/usr/bin/env python3
def __file__():
import argparse, os, sys
sys.dont_write_bytecode = True
......@@ -30,4 +30,5 @@ def __file__():
return os.path.join(sys.path[0], sys.argv[1])
__file__ = __file__()
execfile(__file__)
with open(__file__) as f:
exec(compile(f.read(), __file__, 'exec'))
......@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
else:
i = 0 if hmac[0] else 1
if hmac[i] != sign or hmac[i+1] != accept:
print 'HMAC config wrong for in %s' % args
print('HMAC config wrong for in %s' % args)
rc = False
if rc:
print('All nodes use Babel with the correct HMAC configuration')
......
......@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from re6st import utils, x509
from OpenSSL import crypto
with open("/etc/re6stnet/ca.crt") as f:
with open("/etc/re6stnet/ca.crt", "rb") as f:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
network = x509.networkFromCa(ca)
......
import json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
import base64, json, logging, os, sqlite3, socket, subprocess, sys, time, zlib
from itertools import chain
from .registry import RegistryClient
from . import utils, version, x509
class Cache(object):
class Cache:
def __init__(self, db_path, registry, cert, db_size=200):
def __init__(self, db_path: str, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
self._decrypt = cert.decrypt
......@@ -50,7 +50,7 @@ class Cache(object):
self.warnProtocol()
logging.info("Cache initialized.")
def _open(self, path):
def _open(self, path: str) -> sqlite3.Connection:
db = sqlite3.connect(path, isolation_level=None)
db.text_factory = str
db.execute("PRAGMA synchronous = OFF")
......@@ -64,9 +64,8 @@ class Cache(object):
return db
@staticmethod
def _selectConfig(execute): # BBB: blob
return ((k, str(v) if type(v) is buffer else v)
for k, v in execute("SELECT * FROM config"))
def _selectConfig(execute):
return execute("SELECT * FROM config")
def _loadConfig(self, config):
cls = self.__class__
......@@ -89,24 +88,23 @@ class Cache(object):
logging.info("Getting new network parameters from registry...")
try:
# TODO: When possible, the registry should be queried via the re6st.
x = json.loads(zlib.decompress(
self._registry.getNetworkConfig(self._prefix)))
base64 = x.pop('', ())
network_config = self._registry.getNetworkConfig(self._prefix)
logging.debug('getNetworkConfig result: %r', network_config)
x = json.loads(zlib.decompress(network_config))
base64_list = x.pop('', ())
config = {}
for k, v in x.iteritems():
for k, v in x.items():
k = str(k)
if k.startswith('babel_hmac'):
if v:
v = self._decrypt(v.decode('base64'))
elif k in base64:
v = v.decode('base64')
elif type(v) is unicode:
v = str(v)
v = self._decrypt(base64.b64decode(v))
elif k in base64_list:
v = base64.b64decode(v)
elif isinstance(v, (list, dict)):
k += ':json'
v = json.dumps(v)
config[k] = v
except socket.error, e:
except socket.error as e:
logging.warning(e)
return
except Exception:
......@@ -130,16 +128,12 @@ class Cache(object):
remove.append(k)
db.execute("DELETE FROM config WHERE name in ('%s')"
% "','".join(remove))
# BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6
db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
((k, buffer(v) if k in base64 or
k.startswith('babel_hmac') else v)
for k, v in config.iteritems()))
self._loadConfig(config.iteritems())
config.items())
self._loadConfig(config.items())
return [k[:-5] if k.endswith(':json') else k
for k in chain(remove, (k
for k, v in config.iteritems()
for k, v in config.items()
if k not in old or old[k] != v))]
def warnProtocol(self):
......@@ -147,7 +141,7 @@ class Cache(object):
logging.warning("There's a new version of re6stnet:"
" you should update.")
def getDh(self, path):
def getDh(self, path: str):
# We'd like to do a full check here but
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
......@@ -179,11 +173,11 @@ class Cache(object):
logging.trace("- %s: %s%s", prefix, address,
' (blacklisted)' if _try else '')
def cacheMinimize(self, size):
def cacheMinimize(self, size: int):
with self._db:
self._cacheMinimize(size)
def _cacheMinimize(self, size):
def _cacheMinimize(self, size: int):
a = self._db.execute(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
(size,)).fetchall()
......@@ -192,26 +186,26 @@ class Cache(object):
q("DELETE FROM peer WHERE prefix IN (?)", a)
q("DELETE FROM volatile.stat WHERE peer IN (?)", a)
def connecting(self, prefix, connecting):
def connecting(self, prefix: str, connecting: bool):
self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
(connecting, prefix))
def resetConnecting(self):
self._db.execute("UPDATE volatile.stat SET try=0")
def getAddress(self, prefix):
def getAddress(self, prefix: str) -> bool:
r = self._db.execute("SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0",
(prefix,)).fetchone()
return r and r[0]
@property
def my_address(self):
def my_address(self) -> str:
for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
return x
@my_address.setter
def my_address(self, value):
def my_address(self, value: str):
if value:
with self._db as db:
db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
......@@ -229,18 +223,20 @@ class Cache(object):
# IOW, one should probably always put our own address there.
_get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
" WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address"
def getPeerList(self, failed=False, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)"):
def getPeerCount(self, failed=False, __sql=_get_peer_sql % "COUNT(*)") \
-> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self):
def getBootstrapPeer(self) -> tuple[str, str]:
logging.info('Getting Boot peer...')
try:
bootpeer = self._registry.getBootstrapPeer(self._prefix)
prefix, address = self._decrypt(bootpeer).split()
except (socket.error, subprocess.CalledProcessError, ValueError), e:
prefix, address = self._decrypt(bootpeer).decode().split()
except (socket.error, subprocess.CalledProcessError, ValueError) as e:
logging.warning('Failed to bootstrap (%s)',
e if bootpeer else 'no peer returned')
else:
......@@ -249,7 +245,7 @@ class Cache(object):
return prefix, address
logging.warning('Buggy registry sent us our own address')
def addPeer(self, prefix, address, set_preferred=False):
def addPeer(self, prefix: str, address: str, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address)
with self._db:
q = self._db.execute
......@@ -273,8 +269,8 @@ class Cache(object):
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip):
def getCountry(self, ip: str) -> str:
try:
return self._registry.getCountry(self._prefix, ip)
except socket.error, e:
return self._registry.getCountry(self._prefix, ip).decode()
except socket.error as e:
logging.warning('Failed to get country (%s)', ip)
#!/usr/bin/python2
#!/usr/bin/env python3
import argparse, atexit, binascii, errno, hashlib
import os, subprocess, sqlite3, sys, time
from OpenSSL import crypto
......@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, x509
def create(path, text=None, mode=0666):
def create(path, text=None, mode=0o666):
fd = os.open(path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, mode)
try:
os.write(fd, text)
finally:
os.close(fd)
def loadCert(pem):
def loadCert(pem: bytes):
return crypto.load_certificate(crypto.FILETYPE_PEM, pem)
def main():
......@@ -68,17 +68,18 @@ def main():
fingerprint = binascii.a2b_hex(fingerprint)
if hashlib.new(alg).digest_size != len(fingerprint):
raise ValueError("wrong size")
except StandardError, e:
except Exception as e:
parser.error("invalid fingerprint: %s" % e)
if x509.fingerprint(ca, alg).digest() != fingerprint:
sys.exit("CA fingerprint doesn't match")
else:
print "WARNING: it is strongly recommended to use --fingerprint option."
print("WARNING: it is strongly recommended to use --fingerprint option.")
network = x509.networkFromCa(ca)
if config.is_needed:
route, err = subprocess.Popen(('ip', '-6', '-o', 'route', 'get',
with subprocess.Popen(('ip', '-6', '-o', 'route', 'get',
utils.ipFromBin(network)),
stdout=subprocess.PIPE).communicate()
stdout=subprocess.PIPE) as proc:
route, err = proc.communicate()
sys.exit(err or route and
utils.binFromIp(route.split()[8]).startswith(network))
......@@ -89,19 +90,20 @@ def main():
reserved = 'CN', 'serial'
req = crypto.X509Req()
try:
with open(cert_path) as f:
with open(cert_path, "rb") as f:
cert = loadCert(f.read())
components = dict(cert.get_subject().get_components())
components = \
{k.decode(): v for k, v in cert.get_subject().get_components()}
for k in reserved:
components.pop(k, None)
except IOError, e:
except IOError as e:
if e.errno != errno.ENOENT:
raise
components = {}
if config.req:
components.update(config.req)
subj = req.get_subject()
for k, v in components.iteritems():
for k, v in components.items():
if k in reserved:
sys.exit(k + " field is reserved.")
if v:
......@@ -116,35 +118,35 @@ def main():
token = ''
elif not token:
if not config.email:
config.email = raw_input('Please enter your email address: ')
config.email = input('Please enter your email address: ')
s.requestToken(config.email)
token_advice = "Use --token to retry without asking a new token\n"
while not token:
token = raw_input('Please enter your token: ')
token = input('Please enter your token: ')
try:
with open(key_path) as f:
pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
key = None
print "Reusing existing key."
except IOError, e:
print("Reusing existing key.")
except IOError as e:
if e.errno != errno.ENOENT:
raise
bits = ca.get_pubkey().bits()
print "Generating %s-bit key ..." % bits
print("Generating %s-bit key ..." % bits)
pkey = crypto.PKey()
pkey.generate_key(crypto.TYPE_RSA, bits)
key = crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
create(key_path, key, 0600)
create(key_path, key, 0o600)
req.set_pubkey(pkey)
req.sign(pkey, 'sha512')
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req)
req = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req).decode()
# First make sure we can open certificate file for writing,
# to avoid using our token for nothing.
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0666)
print "Requesting certificate ..."
cert_fd = os.open(cert_path, os.O_CREAT | os.O_WRONLY, 0o666)
print("Requesting certificate ...")
if config.location:
cert = s.requestCertificate(token, req, location=config.location)
else:
......@@ -173,7 +175,7 @@ def main():
key_path))
if not os.path.lexists(conf_path):
create(conf_path, """\
create(conf_path, ("""\
registry %s
ca %s
cert %s
......@@ -187,14 +189,14 @@ key %s
#O--verb
#O3
""" % (config.registry, ca_path, cert_path, key_path,
('country ' + config.location.split(',', 1)[0]) \
if config.location else ''))
print "Sample configuration file created."
('country ' + config.location.split(',', 1)[0])
if config.location else '')).encode())
print("Sample configuration file created.")
cn = x509.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn)
print("Your subnet: %s/%u (CN=%s)"
% (utils.ipFromBin(subnet), len(subnet), cn))
if __name__ == "__main__":
main()
#!/usr/bin/python2
#!/usr/bin/env python3
import atexit, errno, logging, os, shutil, signal
import socket, struct, subprocess, sys
from collections import deque
......@@ -246,7 +246,7 @@ def main():
try:
from re6st.upnpigd import Forwarder
forwarder = Forwarder('re6stnet openvpn server')
except Exception, e:
except Exception as e:
if ipv4:
raise
logging.info("%s: assume we are not NATed", e)
......@@ -266,19 +266,13 @@ def main():
def call(cmd):
logging.debug('%r', cmd)
p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout, stderr = p.communicate()
if p.returncode:
raise EnvironmentError("%r failed with error %u\n%s"
% (' '.join(cmd), p.returncode, stderr))
return stdout
def ip4(object, *args):
return subprocess.run(cmd, capture_output=True, check=True).stdout
def ip4(object: str, *args):
args = ['ip', '-4', object, 'add'] + list(args)
call(args)
args[3] = 'del'
cleanup.append(lambda: subprocess.call(args))
def ip(object, *args):
def ip(object: str, *args):
args = ['ip', '-6', object, 'add'] + list(args)
call(args)
args[3] = 'del'
......@@ -299,7 +293,7 @@ def main():
timeout = 4 * cache.hello
cleanup = [lambda: cache.cacheMinimize(config.client_count),
lambda: shutil.rmtree(config.run, True)]
utils.makedirs(config.run, 0700)
utils.makedirs(config.run, 0o700)
control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket,
......@@ -362,7 +356,7 @@ def main():
if not dh:
dh = os.path.join(config.state, "dh.pem")
cache.getDh(dh)
for iface, (port, proto) in server_tunnels.iteritems():
for iface, (port, proto) in server_tunnels.items():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
utils.setCloexec(r)
cleanup.append(plib.server(iface, config.max_clients,
......@@ -442,7 +436,7 @@ def main():
except:
pass
exit.release()
except ReexecException, e:
except ReexecException as e:
logging.info(e)
except Exception:
utils.log_exception()
......@@ -455,7 +449,7 @@ def main():
if __name__ == "__main__":
try:
main()
except SystemExit, e:
except SystemExit as e:
if type(e.code) is str:
if hasattr(logging, 'trace'): # utils.setupLog called
logging.critical(e.code)
......
#!/usr/bin/python2
import httplib, logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer
from urlparse import parse_qsl
#!/usr/bin/env python3
import http.client, logging, os, socket, sys
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
from urllib.parse import parse_qsl
if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version
......@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
path = self.path
query = {}
else:
query = dict(parse_qsl(query, keep_blank_values=1,
strict_parsing=1))
query = dict(parse_qsl(query, keep_blank_values=True,
strict_parsing=True))
_, path = path.split('/')
if not _:
return self.server.handle_request(self, path, query)
except Exception:
logging.info(self.requestline, exc_info=1)
self.send_error(httplib.BAD_REQUEST)
logging.info(self.requestline, exc_info=True)
self.send_error(http.client.BAD_REQUEST)
def log_error(*args):
pass
......
......@@ -5,7 +5,7 @@ from . import utils
uint16 = struct.Struct("!H")
header = struct.Struct("!HI")
class Struct(object):
class Struct:
def __init__(self, format, *args):
if args:
......@@ -29,39 +29,39 @@ class Struct(object):
self.encode = encode
self.decode = decode
class Array(object):
class Array:
def __init__(self, item):
self._item = item
def encode(self, buffer, value):
def encode(self, buffer: bytes, value: list):
buffer += uint16.pack(len(value))
encode = self._item.encode
for value in value:
encode(buffer, value)
def decode(self, buffer, offset=0):
def decode(self, buffer: bytes, offset=0) -> tuple[int, list]:
r = []
o = offset + 2
decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)):
for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o)
r.append(x)
return o, r
class String(object):
class String:
@staticmethod
def encode(buffer, value):
buffer += value + "\0"
def encode(buffer: bytes, value: str):
buffer += value.encode("utf-8") + b'\0'
@staticmethod
def decode(buffer, offset=0):
i = buffer.index("\0", offset)
return i + 1, buffer[offset:i]
def decode(buffer: bytes, offset=0) -> tuple[int, str]:
i = buffer.index(0, offset)
return i + 1, buffer[offset:i].decode("utf-8")
class Buffer(object):
class Buffer:
def __init__(self):
self._buf = bytearray()
......@@ -104,21 +104,6 @@ class Buffer(object):
self._seek(r)
return value
try: # BBB: Python < 2.7.4 (http://bugs.python.org/issue10212)
uint16.unpack_from(bytearray(uint16.size))
except TypeError:
def unpack_from(self, struct):
r = self._r
x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x])
self._seek(x)
return value
def decode(self, decode):
r = self._r
size, value = decode(buffer(self._buf)[r:])
self._seek(r + size)
return value
# writing
def send(self, socket, *args):
......@@ -129,7 +114,7 @@ class Buffer(object):
struct.pack_into(self._buf, offset, *args)
class Packet(object):
class Packet:
response_dict = {}
......@@ -149,7 +134,7 @@ class Packet(object):
logging.trace('send %s%r', self.__class__.__name__,
(self.id,) + self.args)
offset = len(buffer)
buffer += '\0' * header.size
buffer += bytes(header.size)
r = self.request
if isinstance(r, Struct):
r.encode(buffer, self.args)
......@@ -182,11 +167,11 @@ class ConnectionClosed(BabelException):
return "connection to babeld closed (%s)" % self.args
class Babel(object):
class Babel:
_decode = None
def __init__(self, socket_path, handler, network):
def __init__(self, socket_path: str, handler, network: str):
self.socket_path = socket_path
self.handler = handler
self.network = network
......@@ -206,11 +191,11 @@ class Babel(object):
def select(*args):
try:
s.connect(self.socket_path)
except socket.error, e:
except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e
s.send("\1")
s.setblocking(0)
s.send(b'\1')
s.setblocking(False)
del self.select
self.socket = s
return self.select(*args)
......@@ -267,15 +252,18 @@ class Babel(object):
unidentified = set(n)
self.neighbours = neighbours = {}
a = len(self.network)
logging.info("Routes: %r", routes)
for route in routes:
assert route.flags & 1, route # installed
if route.prefix.startswith('\0\0\0\0\0\0\0\0\0\0\xff\xff'):
if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
logging.warning("Ignoring IPv4 route: %r", route)
continue
assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex
neigh_routes = n[address]
ip = utils.binFromRawIp(route.prefix)
if ip[:a] == self.network:
logging.debug("Route is on the network: %r", route)
prefix = ip[a:route.plen]
if prefix and not route.refmetric:
neighbours[prefix] = neigh_routes
......@@ -290,7 +278,9 @@ class Babel(object):
socket.inet_ntop(socket.AF_INET6, route.prefix),
route.plen)
else:
logging.debug("Route is not on the network: %r", route)
prefix = None
logging.debug("Adding route %r to %r", route, neigh_routes)
neigh_routes[1][prefix] = route
self.locked.clear()
if unidentified:
......@@ -310,11 +300,11 @@ class Babel(object):
pass
class iterRoutes(object):
class iterRoutes:
_waiting = True
def __new__(cls, control_socket, network):
def __new__(cls, control_socket: str, network: str):
self = object.__new__(cls)
c = Babel(control_socket, self, network)
c.request_dump()
......@@ -323,7 +313,7 @@ class iterRoutes(object):
c.select(*args)
utils.select(*args)
return (prefix
for neigh_routes in c.neighbours.itervalues()
for neigh_routes in c.neighbours.values()
for prefix in neigh_routes[1]
if prefix)
......
import errno, os, socket, stat, threading
class Socket(object):
class Socket:
def __init__(self, socket):
def __init__(self, socket: socket.socket):
# In case that the default timeout is not None.
socket.settimeout(None)
self._socket = socket
self._buf = ''
self._buf = b''
def close(self):
self._socket.close()
def write(self, data):
def write(self, data: bytes):
self._socket.send(data)
def readline(self):
def readline(self) -> bytes:
recv = self._socket.recv
data = self._buf
while True:
i = 1 + data.find('\n')
i = 1 + data.find(b'\n')
if i:
self._buf = data[i:]
return data[:i]
d = recv(4096)
data += d
if not d:
self._buf = ''
self._buf = b''
return data
def flush(self):
......@@ -37,14 +37,14 @@ class Socket(object):
try:
self._socket.recv(0)
return True
except socket.error, (err, _):
if err != errno.EAGAIN:
except socket.error as e:
if e.errno != errno.EAGAIN:
raise
self._socket.setblocking(1)
return False
class Console(object):
class Console:
def __init__(self, path, pdb):
self.path = path
......@@ -52,7 +52,7 @@ class Console(object):
socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
try:
self._removeSocket()
except OSError, e:
except OSError as e:
if e.errno != errno.ENOENT:
raise
s.bind(path)
......
......@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
freeifaddrs.restype = None
freeifaddrs.argtypes = [POINTER(struct_ifaddrs)]
class unpacker(object):
class unpacker:
def __init__(self, buf):
self._buf = buf
......@@ -55,7 +55,7 @@ class unpacker(object):
self._offset += s.size
return result
class PimDm(object):
class PimDm:
def __init__(self):
s_netlink = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)
......
#!/usr/bin/python2 -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -14,4 +14,5 @@ if script_type == 'up':
if script_type == 'route-up':
import time
os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])))
int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip']))
.encode())
#!/usr/bin/python2 -S
#!/usr/bin/env -S python3 -S
import os, sys
script_type = os.environ['script_type']
......@@ -7,10 +7,11 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip))))
int(os.environ['tls_serial_0']), external_ip)))
.encode("utf-8"))
if script_type == 'client-connect':
if os.read(fd, 1) == '\0':
if os.read(fd, 1) == b'\0':
sys.exit(1)
# Send client its external ip address
with open(sys.argv[2], 'w') as f:
......
import binascii
import logging, errno, os
from typing import Optional
from . import utils
here = os.path.realpath(os.path.dirname(__file__))
ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log = None
ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw):
def openvpn(iface: str, encrypt, *args, **kw) -> utils.Popen:
args = ['openvpn',
'--dev-type', 'tap',
'--dev', iface,
......@@ -19,13 +21,16 @@ def openvpn(iface, encrypt, *args, **kw):
if ovpn_log:
args += '--log-append', os.path.join(ovpn_log, '%s.log' % iface),
if not encrypt:
# TODO: --ncp-disable was deprecated in OpenVPN 2.5 and removed in 2.6
# and is no longer necessary in those versions.
args += '--cipher', 'none', '--ncp-disable'
logging.debug('%r', args)
return utils.Popen(args, **kw)
ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450}
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
def server(iface: str, max_clients: int, dh_path: str, fd: int,
port: int, proto: str, encrypt: bool, *args, **kw) -> utils.Popen:
if proto == 'udp':
proto = 'udp4'
client_script = '%s %s' % (ovpn_server, fd)
......@@ -43,10 +48,11 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
'--max-clients', str(max_clients),
'--port', str(port),
'--proto', proto,
*args, **kw)
*args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw):
def client(iface: str, address_list: list[tuple[str, int, str]],
encrypt: bool, *args, **kw) -> utils.Popen:
remote = ['--nobind', '--client']
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu = set()
......@@ -62,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
return openvpn(iface, encrypt, *remote, **kw)
def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
control_socket, default, hmac, *args, **kw):
def router(ip: tuple[str, int], ip4, rt6: tuple[str, bool, bool],
hello_interval: int, log_path: str, state_path: str, pidfile: str,
control_socket: str, default: str,
hmac: tuple[bytes | None, bytes | None], *args, **kw) -> utils.Popen:
network, gateway, has_ipv6_subtrees = rt6
network_mask = int(network[network.index('/')+1:])
ip, n = ip
......@@ -80,9 +88,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign:
def key(cmd, id, value):
def key(cmd: list[str], id: str, value: bytes):
cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, value.encode('hex')))
(id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign)
default += ' key sign'
if hmac_accept is not None:
......@@ -132,7 +140,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists
try:
os.remove(pidfile)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOENT:
raise
logging.info('%r', cmd)
......
This diff is collapsed.
from pathlib2 import Path
from pathlib import Path
DEMO_PATH = Path(__file__).resolve().parent.parent.parent / "demo"
......@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
DH_FILE = DEMO_PATH / "dh2048.pem"
class DummyNode(object):
class DummyNode:
"""fake node to reuse Re6stRegistry
error: node.Popen has destory method which not in subprocess.Popen
......@@ -29,19 +29,20 @@ class TestRegistryClientInteract(unittest.TestCase):
@classmethod
def setUpClass(cls):
re6st_wrap.initial()
# if running in net ns, set lo up
subprocess.check_call(("ip", "link", "set", "lo", "up"))
def setUp(self):
re6st_wrap.initial()
self.port = 18080
self.url = "http://localhost:{}/".format(self.port)
# not important, used in network_config check
self.max_clients = 10
def tearDown(self):
self.server.proc.terminate()
with self.server.proc as p:
p.terminate()
def test_1_main(self):
""" a client interact a server, no re6stnet node test basic function"""
......@@ -60,7 +61,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db
db = sqlite3.connect(str(self.server.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(email,)).fetchone()
......@@ -70,7 +71,7 @@ class TestRegistryClientInteract(unittest.TestCase):
self.fail("Request token failed, no token in database")
# token: tuple[unicode,]
token = str(token[0])
self.assertEqual(client.isToken(token), "1")
self.assertEqual(client.isToken(token), b"1")
# request ca
ca = client.getCa()
......@@ -78,7 +79,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# request a cert and get cn
key, csr = tools.generate_csr()
cert = client.requestCertificate(token, csr)
self.assertEqual(client.isToken(token), '', "token should be deleted")
self.assertEqual(client.isToken(token), b'', "token should be deleted")
# creat x509.cert object
def write_to_temp(text):
......@@ -97,7 +98,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# verfiy cn and prefix
prefix = client.cert.prefix
cn = client.getNodePrefix(email)
cn = client.getNodePrefix(email).decode()
self.assertEqual(tools.prefix2cn(prefix), cn)
# simulate the process in cache
......@@ -108,7 +109,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# no re6stnet, empty result
bootpeer = client.getBootstrapPeer(prefix)
self.assertEqual(bootpeer, "")
self.assertEqual(bootpeer, b"")
# server should not die
self.assertIsNone(self.server.proc.poll())
......
......@@ -3,14 +3,9 @@ import logging
import nemu
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from subprocess import DEVNULL, PIPE
from pathlib import Path
from re6st.tests import DEMO_PATH
fix_file = DEMO_PATH / "fixnemu.py"
# execfile(str(fix_file)) Removed in python3
exec(open(str(fix_file)).read())
IPTABLES = 'iptables-nft'
class ConnectableError(Exception):
......@@ -50,7 +45,7 @@ class Node(nemu.Node):
if_s.add_v4_address(ip, prefix_len=prefix_len)
return if_s
class NetManager(object):
class NetManager:
"""contain all the nemu object created, so they can live more time"""
def __init__(self):
self.object = []
......@@ -60,9 +55,10 @@ class NetManager(object):
Raise:
AssertionError
"""
for reg, nodes in self.registries.iteritems():
for reg, nodes in self.registries.items():
for node in nodes:
app0 = node.Popen(["ping", "-c", "1", reg.ip], stdout=PIPE)
with node.Popen(["ping", "-c", "1", reg.ip],
stdout=DEVNULL) as app0:
ret = app0.wait()
if ret:
raise ConnectableError(
......
......@@ -6,13 +6,15 @@ import ipaddress
import json
import logging
import re
import shlex
import shutil
import sqlite3
import sys
import tempfile
import time
import weakref
from subprocess import PIPE
from pathlib2 import Path
from pathlib import Path
from re6st.tests import tools
from re6st.tests import DEMO_PATH
......@@ -20,13 +22,15 @@ from re6st.tests import DEMO_PATH
WORK_DIR = Path(__file__).parent / "temp_net_test"
DH_FILE = DEMO_PATH / "dh2048.pem"
RE6STNET = "python -m re6st.cli.node"
RE6ST_REGISTRY = "python -m re6st.cli.registry"
RE6ST_CONF = "python -m re6st.cli.conf"
PYTHON = shlex.quote(sys.executable)
RE6STNET = PYTHON + " -m re6st.cli.node"
RE6ST_REGISTRY = PYTHON + " -m re6st.cli.registry"
RE6ST_CONF = PYTHON + " -m re6st.cli.conf"
def initial():
"""create the workplace"""
if not WORK_DIR.exists():
if WORK_DIR.exists():
shutil.rmtree(str(WORK_DIR))
WORK_DIR.mkdir()
def ip_to_serial(ip6):
......@@ -36,7 +40,7 @@ def ip_to_serial(ip6):
return int(ip6, 16)
class Re6stRegistry(object):
class Re6stRegistry:
"""class run a re6st-registry service on a namespace"""
registry_seq = 0
......@@ -72,7 +76,7 @@ class Re6stRegistry(object):
self.run()
# wait the servcice started
p = self.node.Popen(['python', '-c', """if 1:
p = self.node.Popen([sys.executable, '-c', """if 1:
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
......@@ -115,7 +119,7 @@ class Re6stRegistry(object):
'--client-count', (self.client_number+1)//2, '--port', self.port]
#PY3: convert PosixPath to str, can be remove in Python 3
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6ST_REGISTRY.split()
......@@ -131,15 +135,19 @@ class Re6stRegistry(object):
if e.errno != errno.ENOENT:
raise
def __del__(self):
def terminate(self):
try:
logging.debug("teminate process %s", self.proc.pid)
self.proc.destroy()
with self.proc as p:
p.destroy()
except:
pass
def __del__(self):
self.terminate()
class Re6stNode(object):
class Re6stNode:
"""class run a re6stnet service on a namespace"""
node_seq = 0
......@@ -210,7 +218,7 @@ class Re6stNode(object):
# read token
db = sqlite3.connect(str(self.registry.db), isolation_level=None)
token = None
for _ in xrange(100):
for _ in range(100):
time.sleep(.1)
token = db.execute("SELECT token FROM token WHERE email=?",
(self.email,)).fetchone()
......@@ -223,7 +231,8 @@ class Re6stNode(object):
out, _ = p.communicate(str(token[0]))
# logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+', out).group(0)
self.ip6 = re.search('(?<=subnet: )[0-9:a-z]+',
out.decode("utf-8")).group(0)
data = {'ip6': self.ip6, 'hash': self.registry.ident}
with open(str(self.data_file), 'w') as f:
json.dump(data, f)
......@@ -236,7 +245,7 @@ class Re6stNode(object):
'--key', self.key, '-v4', '--registry', self.registry.url,
'--console', self.console]
#PY3: same as for Re6stRegistry.run
cmd = map(str, cmd)
cmd = list(map(str, cmd))
cmd[:0] = RE6STNET.split()
cmd += args
......@@ -260,7 +269,8 @@ class Re6stNode(object):
def stop(self):
"""stop running re6stnet process"""
logging.debug("%s teminate process %s", self.name, self.proc.pid)
self.proc.destroy()
with self.proc as p:
p.destroy()
def __del__(self):
"""teminate process and rm temp dir"""
......
"""contain ping-test for re6set net"""
import os
import sys
import unittest
import time
import psutil
import logging
import random
from pathlib2 import Path
from pathlib import Path
import network_build
import re6st_wrap
from . import network_build, re6st_wrap
PING_PATH = str(Path(__file__).parent.resolve() / "ping.py")
def deploy_re6st(nm, recreate=False):
net = nm.registries
nodes = []
registries = []
re6st_wrap.Re6stRegistry.registry_seq = 0
re6st_wrap.Re6stNode.node_seq = 0
for registry in net:
reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::", len(net[registry]),
recreate=recreate)
reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name)
registries.append(reg)
reg_node.run("--gateway", "--disable-proto", "none", "--ip", registry.ip)
nodes.append(reg_node)
for m in net[registry]:
node = re6st_wrap.Re6stNode(m, reg)
node.run("-i" + m.iface.name)
nodes.append(node)
return nodes, registries
def wait_stable(nodes, timeout=240):
"""try use ping6 from each node to the other until ping success to all the
other nodes
......@@ -47,12 +28,13 @@ def wait_stable(nodes, timeout=240):
for node in nodes:
sub_ips = set(ips) - {node.ip6}
node.ping_proc = node.node.Popen(
["python", PING_PATH, '--retry', '-a'] + list(sub_ips))
[sys.executable, PING_PATH, '--retry', '-a'] + list(sub_ips),
env=os.environ)
# check all the node network can ping each other, in order reverse
unfinished = list(nodes)
while unfinished:
for i in xrange(len(unfinished)-1, -1, -1):
for i in range(len(unfinished)-1, -1, -1):
node = unfinished[i]
if node.ping_proc.poll() is not None:
logging.debug("%s 's network is stable", node.name)
......@@ -75,8 +57,42 @@ class TestNet(unittest.TestCase):
def setUpClass(cls):
"""create work dir"""
logging.basicConfig(level=logging.INFO)
def setUp(self):
re6st_wrap.initial()
def deploy_re6st(self, nm, recreate=False):
net = nm.registries
nodes = []
registries = []
re6st_wrap.Re6stRegistry.registry_seq = 0
re6st_wrap.Re6stNode.node_seq = 0
for registry in net:
reg = re6st_wrap.Re6stRegistry(registry, "2001:db8:42::",
len(net[registry]),
recreate=recreate)
reg_node = re6st_wrap.Re6stNode(registry, reg, name=reg.name)
registries.append(reg)
reg_node.run("--gateway", "--disable-proto", "none",
"--ip", registry.ip)
nodes.append(reg_node)
for m in net[registry]:
node = re6st_wrap.Re6stNode(m, reg)
node.run("-i" + m.iface.name)
nodes.append(node)
def clean_re6st():
for node in nodes:
node.node.destroy()
node.stop()
for reg in registries:
reg.terminate()
self.addCleanup(clean_re6st)
return nodes, registries
@classmethod
def tearDownClass(cls):
"""watch any process leaked after tests"""
......@@ -94,7 +110,7 @@ class TestNet(unittest.TestCase):
"""create a network in a net segment, test the connectivity by ping
"""
nm = network_build.net_route()
nodes, _ = deploy_re6st(nm)
nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 40)
time.sleep(10)
......@@ -107,7 +123,7 @@ class TestNet(unittest.TestCase):
then test if network recover, this test seems always failed
"""
nm = network_build.net_demo()
nodes, _ = deploy_re6st(nm)
nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 100)
......@@ -126,7 +142,7 @@ class TestNet(unittest.TestCase):
then test if network recover,
"""
nm = network_build.net_route()
nodes, _ = deploy_re6st(nm)
nodes, _ = self.deploy_re6st(nm)
wait_stable(nodes, 40)
......
#!/usr/bin/python2
#!/usr/bin/env python3
""" unit test for re6st-conf
"""
......@@ -6,7 +6,7 @@ import os
import sys
import unittest
from shutil import rmtree
from StringIO import StringIO
from io import StringIO
from mock import patch
from OpenSSL import crypto
......@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey
cls.pkey, cls.cert = create_ca_file(os.devnull, os.devnull)
cls.fingerprint = "".join( cls.cert.digest("sha1").split(":"))
cls.fingerprint = "".join(cls.cert.digest("sha1").decode().split(":"))
# client.getCa should return a string form cert
cls.cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cls.cert)
......@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
# go back to original dir
os.chdir(self.origin_dir)
@patch("__builtin__.raw_input")
@patch("builtins.input")
def test_basic(self, mock_raw_input):
""" go through all the step
getCa, requestToken, requestCertificate
......
......@@ -3,7 +3,7 @@ import os
import random
import string
import json
import httplib
import http.client
import base64
import unittest
import hmac
......@@ -11,18 +11,21 @@ import hashlib
import time
import tempfile
from argparse import Namespace
from sqlite3 import Cursor
from OpenSSL import crypto
from mock import Mock, patch
from pathlib2 import Path
from pathlib import Path
from re6st import registry
from re6st import registry, x509
from re6st.tests.tools import *
from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
def load_config(filename="registry.json"):
def load_config(filename: str="registry.json") -> Namespace:
with open(filename) as f:
config = json.load(f)
config["dh"] = DEMO_PATH / "dh2048.pem"
......@@ -36,23 +39,25 @@ def load_config(filename="registry.json"):
return Namespace(**config)
def get_cert(cur, prefix):
def get_cert(cur: Cursor, prefix: str):
res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None):
def insert_cert(cur: Cursor, ca: x509.Cert, prefix: str,
not_after=None, email=None):
key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
insert_cert.serial += 1
return key, cert
insert_cert.serial = 0
def delete_cert(cur, prefix):
def delete_cert(cur: Cursor, prefix: str):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
......@@ -68,6 +73,7 @@ class TestRegistryServer(unittest.TestCase):
@classmethod
def tearDownClass(cls):
cls.server.close()
# remove database
for file in [cls.config.db, cls.config.ca, cls.config.key]:
try:
......@@ -80,14 +86,23 @@ class TestRegistryServer(unittest.TestCase):
+ "@mail.com"
def test_recv(self):
recv = self.server.sock.recv = Mock()
recv.side_effect = [
side_effect = iter([
"0001001001001a_msg",
"0001001001002\0001dqdq",
"0001001001001\000a_msg",
"0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null
]
])
class SocketProxy:
def __init__(self, wrappee):
self.wrappee = wrappee
self.recv = lambda _: next(side_effect)
def __getattr__(self, attr):
return getattr(self.wrappee, attr)
self.server.sock = SocketProxy(self.server.sock)
try:
res1 = self.server.recv(4)
......@@ -115,7 +130,7 @@ class TestRegistryServer(unittest.TestCase):
now = int(time.time()) - self.config.grace_period + 20
# makeup data
insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1)
insert_cert(cur, self.server.cert, prefix, now - 1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
......@@ -143,16 +158,16 @@ class TestRegistryServer(unittest.TestCase):
prefix = "0000000011111111"
method = "func"
protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2}
params = {"cn": prefix, "a": 1, "b": 2}
func.getcallargs.return_value = params
del func._private
func.return_value = result = "this_is_a_result"
key = "this_is_a_key"
func.return_value = result = b"this_is_a_result"
key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)]
request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())}
hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest())}
self.server.handle_request(request, method, params)
......@@ -162,11 +177,12 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params)
# http response check
request.send_response.assert_called_once_with(httplib.OK)
request.send_response.assert_called_once_with(http.client.OK)
request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call(
registry.HMAC_HEADER,
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest()))
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest())
.decode("ascii"))
request.wfile.write.assert_called_once_with(result)
# remove the create session \n
......@@ -176,7 +192,7 @@ class TestRegistryServer(unittest.TestCase):
def test_handle_request_private(self, func):
"""case request with _private attr"""
method = "func"
params = {"a" : 1, "b" : 2}
params = {"a": 1, "b": 2}
func.getcallargs.return_value = params
func.return_value = None
request_good = Mock()
......@@ -189,8 +205,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT)
request_bad.send_error.assert_called_once_with(http.client.FORBIDDEN)
request_good.send_response.assert_called_once_with(http.client.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self):
......@@ -213,7 +229,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol)
# decrypt
length = len(res)/2
length = len(res) // 2
key, sign = res[:length], res[length:]
key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key,
......@@ -282,7 +298,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less = 0
for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial):
if s and int(s) <= serial:
nb_less += 1
self.assertEqual(nb_less, serial)
......@@ -378,7 +394,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs = get_hmac()
key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, ''])
self.assertEqual(hmacs, [None, key_1, b''])
# step 2
self.server.updateHMAC()
......@@ -397,12 +413,11 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(get_hmac(), [None, key_2, key_1])
#setp 5
# step 5
self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self):
# prefix in short format
prefix = "0000000101"
......@@ -426,20 +441,38 @@ class TestRegistryServer(unittest.TestCase):
('0000000000000001', '2 0/16 6/16')
]
recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1
select.side_effect = side_effct
res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
self.assertEqual(res, expect_res)
class CustomDecoder(json.JSONDecoder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.parse_array = self.JSONArray
self.scan_once = json.scanner.py_make_scanner(self)
def JSONArray(self, *args, **kw):
values, end = json.decoder.JSONArray(*args, **kw)
return set(values), end
res = json.loads(res, cls=CustomDecoder)
self.assertEqual(res, {
"": {"36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"},
"0/16": {"6/16", "7/16"},
"1/16": {"6/16", "0/16"},
"36893488147419103232/80": {"0/16", "7/16"},
"3/16": {"0/16", "7/16"},
"4/16": {"0/16"},
"7/16": {"6/16", "4/16"},
})
if __name__ == "__main__":
......
......@@ -2,7 +2,7 @@ import sys
import os
import unittest
import hmac
import httplib
import http.client
import base64
import hashlib
from mock import Mock, patch
......@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection)
self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self):
prefix = "0000000011111111"
protocol = "7"
body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK)
response = fakeResponse(body, http.client.OK)
self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol)
......@@ -52,14 +52,15 @@ class TestRegistryClient(unittest.TestCase):
self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock()
key = "this_is_a_key"
key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest()
h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
# response part
body = None
response = fakeResponse(body, httplib.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest())
body = b'this is a body'
response = fakeResponse(body, http.client.NO_CONTENT)
response.msg = dict(Re6stHMAC=base64.b64encode(
hmac.HMAC(key, body, hashlib.sha1).digest()))
self.client._conn.getresponse.return_value = response
res = self.client.getNetworkConfig(cn)
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """
# c = chr(1)
# c = b"\x01"
# msg = "address"
# peer = x509.Peer("000001")
# self.tunnel._connecting = {peer}
......@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none"""
c = chr(1)
c = b"\x01"
self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c)
......@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ;
it will truncate address which has more than 3 element
"""
c = chr(1)
c = b"\x01"
peer = x509.Peer("000001")
peer.protocol = 1
self.tunnel._peers.append(peer)
......@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none
2 case, one modify the version, one not
"""
c = chr(0)
c = b"\x00"
peer = x509.Peer("000001")
version1 = "00003"
version2 = "00007"
self.tunnel._version = version3 = "00005"
version1 = b"00003"
version2 = b"00007"
self.tunnel._version = version3 = b"00005"
self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer)
......
#!/usr/bin/python2
#!/usr/bin/env python3
import os
import sys
import unittest
......
......@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return
crypto.X509Cert in pem format
"""
if type(ca) is str:
if type(ca) is bytes:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str:
if type(ca_key) is bytes:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
......@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0)
if not_after:
cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)))
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject()
......@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f:
with open(pkey_file, 'wb') as f:
f.write(pkey)
with open(cert_file, 'w') as f:
with open(cert_file, 'wb') as f:
f.write(cert)
return pkey, cert
......@@ -84,26 +84,23 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert.set_pubkey(key)
cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file:
with open(pkey_file, 'wb') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file:
with open(cert_file, 'wb') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert
def prefix2cn(prefix):
def prefix2cn(prefix: str) -> str:
return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial):
def serial2prefix(serial: int) -> str:
return bin(serial)[2:].rjust(16, '0')
# pkey: private key
def decrypt(pkey, incontent):
with open("node.key", 'w') as f:
def decrypt(pkey: bytes, incontent: bytes) -> bytes:
with open("node.key", 'wb') as f:
f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split()
p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
outcontent, err = p.communicate(incontent)
return outcontent
return subprocess.run(args, input=incontent, stdout=subprocess.PIPE).stdout
This diff is collapsed.
......@@ -7,7 +7,7 @@ class UPnPException(Exception):
pass
class Forwarder(object):
class Forwarder:
"""
External port is chosen randomly between 32768 & 49151 included.
"""
......@@ -17,7 +17,7 @@ class Forwarder(object):
_lcg_n = 0
@classmethod
def _getExternalPort(cls):
def _getExternalPort(cls) -> int:
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
......@@ -35,12 +35,12 @@ class Forwarder(object):
self._u.discoverdelay = 200
self._rules = []
def __getattr__(self, name):
def __getattr__(self, name: str):
wrapped = getattr(self._u, name)
def wrapper(*args, **kw):
try:
return wrapped(*args, **kw)
except Exception, e:
except Exception as e:
raise UPnPException(str(e))
return wraps(wrapped)(wrapper)
......@@ -68,14 +68,14 @@ class Forwarder(object):
else:
try:
return self._refresh()
except UPnPException, e:
logging.debug("UPnP failure", exc_info=1)
except UPnPException as e:
logging.debug("UPnP failure", exc_info=True)
self.clear()
try:
self.discover()
self.selectigd()
return self._refresh()
except UPnPException, e:
except UPnPException as e:
self.next_refresh = self._next_retry = time.time() + 60
logging.info(str(e))
self.clear()
......@@ -109,7 +109,7 @@ class Forwarder(object):
try:
self.addportmapping(port, *args)
break
except UPnPException, e:
except UPnPException as e:
if str(e) != 'ConflictInMappingEntry':
raise
port = None
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback
from collections.abc import Iterator, Mapping
# PY3: It will be even better to use Popen(pass_fds=...),
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest())
HMAC_LEN = len(hashlib.sha1(b'').digest())
class ReexecException(Exception):
pass
......@@ -37,15 +32,15 @@ class FileHandler(logging.FileHandler):
finally:
self.lock.release()
# In the rare case _reopen is set just before the lock was released
if self._reopen and self.lock.acquire(0):
if self._reopen and self.lock.acquire(False):
self.release()
def async_reopen(self, *_):
self._reopen = True
if self.lock.acquire(0):
if self.lock.acquire(False):
self.release()
def setupLog(log_level, filename=None, **kw):
def setupLog(log_level: int, filename: str | None=None, **kw):
if log_level and filename:
makedirs(os.path.dirname(filename))
handler = FileHandler(filename)
......@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt""", **kw)
class exit(object):
class exit:
status = None
......@@ -150,7 +145,7 @@ class exit(object):
def handler(*args):
if self.status is None:
self.status = status
if self.acquire(0):
if self.acquire(False):
self.release()
for sig in sigs:
signal.signal(sig, handler)
......@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args'])
try:
super(Popen, self).__init__(*args, **kw)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOMEM:
raise
self.returncode = -1
......@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
self.terminate()
t = threading.Timer(5, self.kill)
t.start()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel()
r = self.wait()
r = os.waitid(os.P_PID, self.pid, os.WNOWAIT)
t.cancel()
self.poll()
return r
......@@ -189,7 +184,7 @@ def setCloexec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def select(R, W, T):
def select(R: Mapping, W: Mapping, T):
try:
r, w, _ = _select.select(R, W, (),
max(0, min(T)[0] - time.time()) if T else None)
......@@ -209,19 +204,19 @@ def select(R, W, T):
def makedirs(*args):
try:
os.makedirs(*args)
except OSError, e:
except OSError as e:
if e.errno != errno.EEXIST:
raise
def binFromIp(ip):
def binFromIp(ip: str) -> str:
return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip))
def binFromRawIp(ip):
def binFromRawIp(ip: bytes) -> str:
ip1, ip2 = struct.unpack('>QQ', ip)
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
def ipFromBin(ip, suffix=''):
def ipFromBin(ip: str, suffix='') -> str:
suffix_len = 128 - len(ip)
if suffix_len > 0:
ip += suffix.rjust(suffix_len, '0')
......@@ -230,30 +225,32 @@ def ipFromBin(ip, suffix=''):
return socket.inet_ntop(socket.AF_INET6,
struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2)))
def dump_address(address):
def dump_address(address: str) -> str:
return ';'.join(map(','.join, address))
# Yield ip, port, protocol, and country if it is in the address
def parse_address(address_list):
def parse_address(address_list: str) -> Iterator[tuple[str, str, str, str]]:
for address in address_list.split(';'):
try:
a = address.split(',')
int(a[1]) # Check if port is an int
yield tuple(a[:4])
except ValueError, e:
except ValueError as e:
logging.warning("Failed to parse node address %r (%s)",
address, e)
def binFromSubnet(subnet):
def binFromSubnet(subnet: str) -> str:
p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret():
def _newHmacSecret():
from random import getrandbits as g
pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
newHmacSecret = _newHmacSecret() # https://github.com/python/mypy/issues/1174
### Integer serialization
# - supports values from 0 to 0x202020202020201f
......@@ -261,21 +258,21 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def packInteger(i):
for n in xrange(8):
def packInteger(i: int) -> bytes:
for n in range(8):
x = 32 << 8 * n
if i < x:
return struct.pack("!Q", i + n * x)[7-n:]
i -= x
raise OverflowError
def unpackInteger(x):
n = ord(x[0]) >> 5
def unpackInteger(x: bytes) -> tuple[int, int] | None:
n = x[0] >> 5
try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1])
i, = struct.unpack("!Q", b'\0' * (7 - n) + x[:n+1])
except struct.error:
return
return sum((32 << 8 * i for i in xrange(n)),
return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1
###
......
......@@ -40,4 +40,4 @@ protocol = 8
min_protocol = 1
if __name__ == "__main__":
print version
print(version)
# -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from typing import Callable, Any
from OpenSSL import crypto
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509 import load_pem_x509_certificate
from . import utils
from .version import protocol
def newHmacSecret():
def newHmacSecret() -> bytes:
return utils.newHmacSecret(int(time.time() * 1000000))
def networkFromCa(ca):
def networkFromCa(ca: crypto.X509) -> str:
# TODO: will be ca.serial_number after migration to cryptography
return bin(ca.get_serial_number())[3:]
def subnetFromCert(cert):
def subnetFromCert(cert: crypto.X509) -> str:
return cert.get_subject().CN
def notBefore(cert):
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ'))
def notBefore(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),
'%Y%m%d%H%M%SZ'))
def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
def notAfter(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),
'%Y%m%d%H%M%SZ'))
def openssl(*args):
def openssl(*args: str, fds=[]) -> utils.Popen:
return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data):
def encrypt(cert: bytes, data: bytes) -> bytes:
r, w = os.pipe()
try:
threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r)
'-inkey', '/proc/self/fd/%u' % r, fds=[r])
out, err = p.communicate(data)
finally:
os.close(r)
......@@ -39,10 +49,12 @@ def encrypt(cert, data):
raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out
def fingerprint(cert, alg='sha1'):
def fingerprint(cert: crypto.X509, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew, force=False):
def maybe_renew(path: str, cert: crypto.X509, info: str,
renew: Callable[[], bytes],
force=False) -> tuple[crypto.X509, int]:
from .registry import RENEW_PERIOD
while True:
if force:
......@@ -62,7 +74,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info = 1
break
new_path = path + '.new'
with open(new_path, 'w') as f:
with open(new_path, 'wb') as f:
f.write(pem)
try:
s = os.stat(path)
......@@ -84,39 +96,44 @@ class NewSessionError(Exception):
pass
class Cert(object):
class Cert:
def __init__(self, ca, key, cert=None):
def __init__(self, ca: str, key: str, cert: str | None=None):
self.ca_path = ca
self.cert_path = cert
self.key_path = key
with open(ca) as f:
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
with open(key) as f:
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read())
# TODO: finish migration from old OpenSSL module to cryptography
with open(ca, "rb") as f:
ca_pem = f.read()
self.ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca_pem)
self.ca_crypto = load_pem_x509_certificate(ca_pem)
with open(key, "rb") as f:
key_pem = f.read()
self.key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem)
self.key_crypto = load_pem_private_key(key_pem, password=None)
if cert:
with open(cert) as f:
self.cert = self.loadVerify(f.read())
self.cert = self.loadVerify(f.read().encode())
@property
def prefix(self):
def prefix(self) -> str:
return utils.binFromSubnet(subnetFromCert(self.cert))
@property
def network(self):
def network(self) -> str:
return networkFromCa(self.ca)
@property
def subject_serial(self):
def subject_serial(self) -> int:
return int(self.cert.get_subject().serialNumber)
@property
def openvpn_args(self):
def openvpn_args(self) -> tuple[str, ...]:
return ('--ca', self.ca_path,
'--cert', self.cert_path,
'--key', self.key_path)
def maybeRenew(self, registry, crl):
def maybeRenew(self, registry, crl) -> int:
self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl)
......@@ -143,21 +160,31 @@ class Cert(object):
"error running openssl, assuming cert is invalid")
# BBB: With old versions of openssl, detailed
# error is printed to standard output.
for err in err, out:
for x in err.splitlines():
for stream in err, out:
for x in stream.decode(errors='replace').splitlines():
if x.startswith('error '):
x, msg = x.split(':', 1)
_, code, _, depth, _ = x.split(None, 4)
raise VerifyError(int(code), int(depth), msg.strip())
return r
def verify(self, sign, data):
crypto.verify(self.ca, sign, data, 'sha512')
def sign(self, data):
return crypto.sign(self.key, data, 'sha512')
def decrypt(self, data):
def verify(self, sign: bytes, data: bytes):
pub_key = self.ca_crypto.public_key()
pub_key.verify(
sign,
data,
padding.PKCS1v15(),
hashes.SHA512()
)
def sign(self, data: bytes) -> bytes:
return self.key_crypto.sign(
data,
padding.PKCS1v15(),
hashes.SHA512()
)
def decrypt(self, data: bytes) -> bytes:
p = openssl('rsautl', '-decrypt', '-inkey', self.key_path)
out, err = p.communicate(data)
if p.returncode:
......@@ -166,7 +193,7 @@ class Cert(object):
def verifyVersion(self, version):
try:
n = 1 + (ord(version[0]) >> 5)
n = 1 + (version[0] >> 5)
self.verify(version[n:], version[:n])
except (IndexError, crypto.Error):
raise VerifyError(None, None, 'invalid network version')
......@@ -175,7 +202,7 @@ class Cert(object):
PACKED_PROTOCOL = utils.packInteger(protocol)
class Peer(object):
class Peer:
"""
UDP: A ─────────────────────────────────────────────> B
......@@ -206,9 +233,10 @@ class Peer(object):
_key = newHmacSecret()
serial = None
stop_date = float('inf')
version = ''
version = b''
cert: crypto.X509
def __init__(self, prefix):
def __init__(self, prefix: str):
self.prefix = prefix
@property
......@@ -224,35 +252,35 @@ class Peer(object):
def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert):
def hello0(self, cert: crypto.X509) -> bytes:
if self._hello < time.time():
try:
# Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol.
msg = ('\0\0\0\1'
msg = (b'\0\0\0\1'
+ PACKED_PROTOCOL
+ fingerprint(self.cert).digest())
except AttributeError:
msg = '\0\0\0\0'
msg = b'\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
def hello0Sent(self):
self._hello = time.time() + 60
def hello(self, cert, protocol):
def hello(self, cert: Cert, protocol: int) -> bytes:
key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key)
self._i = self._j = 2
self._last = 0
self.protocol = protocol
return ''.join(('\0\0\0\2', PACKED_PROTOCOL if protocol else '',
return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h)))
def _hmac(self, msg):
def _hmac(self, msg: bytes) -> bytes:
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key, protocol):
def newSession(self, key: bytes, protocol: int):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
......@@ -260,12 +288,13 @@ class Peer(object):
self._last = None
self.protocol = protocol
def verify(self, sign, data):
def verify(self, sign: bytes, data: bytes):
crypto.verify(self.cert, sign, data, 'sha512')
seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack):
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) \
-> tuple[int, bytes, int | None] | bytes:
seqno, = _unpack(msg[:4])
if seqno <= 2:
msg = msg[4:]
......@@ -281,8 +310,10 @@ class Peer(object):
self._i = seqno
return msg[4:i]
def encode(self, msg, _pack=seqno_struct.pack):
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1
if type(msg) is str:
msg = msg.encode()
msg = _pack(self._j) + msg
return msg + self._hmac(msg)
......
......@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
from distutils import log
version = {"__file__": "re6st/version.py"}
execfile(version["__file__"], version)
with open(version["__file__"]) as f:
code = compile(f.read(), version["__file__"], 'exec')
exec(code, version)
def copy_file(self, infile, outfile, *args, **kw):
if infile == version["__file__"]:
if not self.dry_run:
log.info("generating %s -> %s", infile, outfile)
with open(outfile, "wb") as f:
for x in sorted(version.iteritems()):
with open(outfile, "w") as f:
for x in sorted(version.items()):
if not x[0].startswith("_"):
f.write("%s = %r\n" % x)
return outfile, 1
elif isinstance(self, build_py) and \
os.stat(infile).st_mode & stat.S_IEXEC:
if os.path.isdir(infile) and os.path.isdir(outfile):
return (outfile, 0)
return outfile, 0
# Adjust interpreter of OpenVPN hooks.
with open(infile) as src:
first_line = src.readline()
......@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
patched += src.read()
dst = os.open(outfile, os.O_CREAT | os.O_WRONLY | os.O_TRUNC)
try:
os.write(dst, patched)
os.write(dst, patched.encode())
finally:
os.close(dst)
return outfile, 1
......@@ -51,7 +53,8 @@ Environment :: Console
License :: OSI Approved :: GNU General Public License (GPL)
Natural Language :: English
Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Topic :: Internet
Topic :: System :: Networking
"""
......@@ -73,6 +76,7 @@ setup(
license = 'GPL 2+',
platforms = ["any"],
classifiers=classifiers.splitlines(),
python_requires = '>=3.11',
long_description = ".. contents::\n\n" + open('README.rst').read()
+ "\n" + open('CHANGES.rst').read() + git_rev,
packages = find_packages(),
......@@ -95,7 +99,7 @@ setup(
extras_require = {
'geoip': ['geoip2'],
'multicast': ['PyYAML'],
'test': ['mock', 'pathlib2', 'nemu', 'python-unshare', 'python-passfd', 'multiping']
'test': ['mock', 'nemu3', 'unshare', 'multiping']
},
#dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment