Commit 98c9f930 authored by Tom Niget's avatar Tom Niget Committed by Tom Niget

python3: add type annotations

parent 04154193
...@@ -5,7 +5,7 @@ if 're6st' not in sys.modules: ...@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from re6st import utils, x509 from re6st import utils, x509
from OpenSSL import crypto from OpenSSL import crypto
with open("/etc/re6stnet/ca.crt") as f: with open("/etc/re6stnet/ca.crt", "rb") as f:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
network = x509.networkFromCa(ca) network = x509.networkFromCa(ca)
......
...@@ -5,7 +5,7 @@ from . import utils, version, x509 ...@@ -5,7 +5,7 @@ from . import utils, version, x509
class Cache: class Cache:
def __init__(self, db_path, registry, cert: x509.Cert, db_size=200): def __init__(self, db_path: str, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
self._decrypt = cert.decrypt self._decrypt = cert.decrypt
...@@ -50,7 +50,7 @@ class Cache: ...@@ -50,7 +50,7 @@ class Cache:
self.warnProtocol() self.warnProtocol()
logging.info("Cache initialized.") logging.info("Cache initialized.")
def _open(self, path): def _open(self, path: str) -> sqlite3.Connection:
db = sqlite3.connect(path, isolation_level=None) db = sqlite3.connect(path, isolation_level=None)
db.text_factory = str db.text_factory = str
db.execute("PRAGMA synchronous = OFF") db.execute("PRAGMA synchronous = OFF")
...@@ -147,7 +147,7 @@ class Cache: ...@@ -147,7 +147,7 @@ class Cache:
logging.warning("There's a new version of re6stnet:" logging.warning("There's a new version of re6stnet:"
" you should update.") " you should update.")
def getDh(self, path): def getDh(self, path: str):
# We'd like to do a full check here but # We'd like to do a full check here but
# from OpenSSL import SSL # from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path) # SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
...@@ -179,11 +179,11 @@ class Cache: ...@@ -179,11 +179,11 @@ class Cache:
logging.trace("- %s: %s%s", prefix, address, logging.trace("- %s: %s%s", prefix, address,
' (blacklisted)' if _try else '') ' (blacklisted)' if _try else '')
def cacheMinimize(self, size): def cacheMinimize(self, size: int):
with self._db: with self._db:
self._cacheMinimize(size) self._cacheMinimize(size)
def _cacheMinimize(self, size): def _cacheMinimize(self, size: int):
a = self._db.execute( a = self._db.execute(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1", "SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
(size,)).fetchall() (size,)).fetchall()
...@@ -192,26 +192,26 @@ class Cache: ...@@ -192,26 +192,26 @@ class Cache:
q("DELETE FROM peer WHERE prefix IN (?)", a) q("DELETE FROM peer WHERE prefix IN (?)", a)
q("DELETE FROM volatile.stat WHERE peer IN (?)", a) q("DELETE FROM volatile.stat WHERE peer IN (?)", a)
def connecting(self, prefix, connecting): def connecting(self, prefix: str, connecting: bool):
self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?", self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
(connecting, prefix)) (connecting, prefix))
def resetConnecting(self): def resetConnecting(self):
self._db.execute("UPDATE volatile.stat SET try=0") self._db.execute("UPDATE volatile.stat SET try=0")
def getAddress(self, prefix): def getAddress(self, prefix: str) -> bool:
r = self._db.execute("SELECT address FROM peer, volatile.stat" r = self._db.execute("SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0", " WHERE prefix=? AND prefix=peer AND try=0",
(prefix,)).fetchone() (prefix,)).fetchone()
return r and r[0] return r and r[0]
@property @property
def my_address(self): def my_address(self) -> str:
for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"): for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
return x return x
@my_address.setter @my_address.setter
def my_address(self, value): def my_address(self, value: str):
if value: if value:
with self._db as db: with self._db as db:
db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)", db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
...@@ -229,14 +229,14 @@ class Cache: ...@@ -229,14 +229,14 @@ class Cache:
# IOW, one should probably always put our own address there. # IOW, one should probably always put our own address there.
_get_peer_sql = "SELECT %s FROM peer, volatile.stat" \ _get_peer_sql = "SELECT %s FROM peer, volatile.stat" \
" WHERE prefix=peer AND prefix!=? AND try=?" " WHERE prefix=peer AND prefix!=? AND try=?"
def getPeerList(self, failed=0, __sql=_get_peer_sql % "prefix, address" def getPeerList(self, failed=False, __sql=_get_peer_sql % "prefix, address"
+ " ORDER BY RANDOM()"): + " ORDER BY RANDOM()"):
return self._db.execute(__sql, (self._prefix, failed)) return self._db.execute(__sql, (self._prefix, failed))
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int: def getPeerCount(self, failed=False, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0] return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self): def getBootstrapPeer(self) -> tuple[str, str]:
logging.info('Getting Boot peer...') logging.info('Getting Boot peer...')
try: try:
bootpeer = self._registry.getBootstrapPeer(self._prefix) bootpeer = self._registry.getBootstrapPeer(self._prefix)
...@@ -250,7 +250,7 @@ class Cache: ...@@ -250,7 +250,7 @@ class Cache:
return prefix, address return prefix, address
logging.warning('Buggy registry sent us our own address') logging.warning('Buggy registry sent us our own address')
def addPeer(self, prefix, address, set_preferred=False): def addPeer(self, prefix: str, address: str, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address) logging.debug('Adding peer %s: %s', prefix, address)
with self._db: with self._db:
q = self._db.execute q = self._db.execute
...@@ -274,7 +274,7 @@ class Cache: ...@@ -274,7 +274,7 @@ class Cache:
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address)) q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,)) q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip): def getCountry(self, ip: str) -> str:
try: try:
return self._registry.getCountry(self._prefix, ip).decode() return self._registry.getCountry(self._prefix, ip).decode()
except socket.error as e: except socket.error as e:
......
...@@ -272,7 +272,7 @@ def main(): ...@@ -272,7 +272,7 @@ def main():
call(args) call(args)
args[3] = 'del' args[3] = 'del'
cleanup.append(lambda: subprocess.call(args)) cleanup.append(lambda: subprocess.call(args))
def ip(object, *args): def ip(object: str, *args):
args = ['ip', '-6', object, 'add'] + list(args) args = ['ip', '-6', object, 'add'] + list(args)
call(args) call(args)
args[3] = 'del' args[3] = 'del'
......
...@@ -171,7 +171,7 @@ class Babel: ...@@ -171,7 +171,7 @@ class Babel:
_decode = None _decode = None
def __init__(self, socket_path, handler, network): def __init__(self, socket_path: str, handler, network: str):
self.socket_path = socket_path self.socket_path = socket_path
self.handler = handler self.handler = handler
self.network = network self.network = network
...@@ -252,15 +252,18 @@ class Babel: ...@@ -252,15 +252,18 @@ class Babel:
unidentified = set(n) unidentified = set(n)
self.neighbours = neighbours = {} self.neighbours = neighbours = {}
a = len(self.network) a = len(self.network)
logging.info("Routes: %r", routes)
for route in routes: for route in routes:
assert route.flags & 1, route # installed assert route.flags & 1, route # installed
if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'): if route.prefix.startswith(b'\0\0\0\0\0\0\0\0\0\0\xff\xff'):
logging.warning("Ignoring IPv4 route: %r", route)
continue continue
assert route.neigh_address == route.nexthop, route assert route.neigh_address == route.nexthop, route
address = route.neigh_address, route.ifindex address = route.neigh_address, route.ifindex
neigh_routes = n[address] neigh_routes = n[address]
ip = utils.binFromRawIp(route.prefix) ip = utils.binFromRawIp(route.prefix)
if ip[:a] == self.network: if ip[:a] == self.network:
logging.debug("Route is on the network: %r", route)
prefix = ip[a:route.plen] prefix = ip[a:route.plen]
if prefix and not route.refmetric: if prefix and not route.refmetric:
neighbours[prefix] = neigh_routes neighbours[prefix] = neigh_routes
...@@ -275,7 +278,9 @@ class Babel: ...@@ -275,7 +278,9 @@ class Babel:
socket.inet_ntop(socket.AF_INET6, route.prefix), socket.inet_ntop(socket.AF_INET6, route.prefix),
route.plen) route.plen)
else: else:
logging.debug("Route is not on the network: %r", route)
prefix = None prefix = None
logging.debug("Adding route %r to %r", route, neigh_routes)
neigh_routes[1][prefix] = route neigh_routes[1][prefix] = route
self.locked.clear() self.locked.clear()
if unidentified: if unidentified:
...@@ -299,7 +304,7 @@ class iterRoutes: ...@@ -299,7 +304,7 @@ class iterRoutes:
_waiting = True _waiting = True
def __new__(cls, control_socket, network): def __new__(cls, control_socket: str, network: str):
self = object.__new__(cls) self = object.__new__(cls)
c = Babel(control_socket, self, network) c = Babel(control_socket, self, network)
c.request_dump() c.request_dump()
......
...@@ -3,30 +3,30 @@ import errno, os, socket, stat, threading ...@@ -3,30 +3,30 @@ import errno, os, socket, stat, threading
class Socket: class Socket:
def __init__(self, socket): def __init__(self, socket: socket.socket):
# In case that the default timeout is not None. # In case that the default timeout is not None.
socket.settimeout(None) socket.settimeout(None)
self._socket = socket self._socket = socket
self._buf = '' self._buf = b''
def close(self): def close(self):
self._socket.close() self._socket.close()
def write(self, data): def write(self, data: bytes):
self._socket.send(data) self._socket.send(data)
def readline(self): def readline(self) -> bytes:
recv = self._socket.recv recv = self._socket.recv
data = self._buf data = self._buf
while True: while True:
i = 1 + data.find('\n') i = 1 + data.find(b'\n')
if i: if i:
self._buf = data[i:] self._buf = data[i:]
return data[:i] return data[:i]
d = recv(4096) d = recv(4096)
data += d data += d
if not d: if not d:
self._buf = '' self._buf = b''
return data return data
def flush(self): def flush(self):
......
...@@ -8,7 +8,7 @@ ovpn_server = os.path.join(here, 'ovpn-server') ...@@ -8,7 +8,7 @@ ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client') ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log: Optional[str] = None ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw): def openvpn(iface: str, encrypt, *args, **kw) -> utils.Popen:
args = ['openvpn', args = ['openvpn',
'--dev-type', 'tap', '--dev-type', 'tap',
'--dev', iface, '--dev', iface,
...@@ -28,7 +28,7 @@ def openvpn(iface, encrypt, *args, **kw): ...@@ -28,7 +28,7 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450} ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450}
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw): def server(iface: str, max_clients: int, dh_path: str, fd: int, port: int, proto: str, encrypt: bool, *args, **kw) -> utils.Popen:
if proto == 'udp': if proto == 'udp':
proto = 'udp4' proto = 'udp4'
client_script = '%s %s' % (ovpn_server, fd) client_script = '%s %s' % (ovpn_server, fd)
...@@ -49,7 +49,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw): ...@@ -49,7 +49,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
*args, pass_fds=[fd], **kw) *args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw): def client(iface: str, address_list: list[tuple[str, int, str]], encrypt: bool, *args, **kw) -> utils.Popen:
remote = ['--nobind', '--client'] remote = ['--nobind', '--client']
# XXX: We'd like to pass <connection> sections at command-line. # XXX: We'd like to pass <connection> sections at command-line.
link_mtu = set() link_mtu = set()
...@@ -65,8 +65,8 @@ def client(iface, address_list, encrypt, *args, **kw): ...@@ -65,8 +65,8 @@ def client(iface, address_list, encrypt, *args, **kw):
return openvpn(iface, encrypt, *remote, **kw) return openvpn(iface, encrypt, *remote, **kw)
def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, def router(ip: tuple[str, int], ip4, rt6: tuple[str, bool, bool], hello_interval: int, log_path: str, state_path: str, pidfile: str,
control_socket, default, hmac, *args, **kw): control_socket: str, default: str, hmac: tuple[bytes | None, bytes | None], *args, **kw) -> utils.Popen:
network, gateway, has_ipv6_subtrees = rt6 network, gateway, has_ipv6_subtrees = rt6
network_mask = int(network[network.index('/')+1:]) network_mask = int(network[network.index('/')+1:])
ip, n = ip ip, n = ip
...@@ -83,7 +83,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile, ...@@ -83,7 +83,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny', '-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)] '-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign: if hmac_sign:
def key(cmd, id: str, value): def key(cmd: list[str], id: str, value: bytes):
cmd += '-C', ('key type blake2s128 id %s value %s' % cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, binascii.hexlify(value).decode())) (id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign) key(cmd, 'sign', hmac_sign)
......
...@@ -22,10 +22,13 @@ import base64, hmac, hashlib, http.client, inspect, json, logging ...@@ -22,10 +22,13 @@ import base64, hmac, hashlib, http.client, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3 import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref, zlib import string, sys, threading, time, weakref, zlib
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Iterator
from datetime import datetime from datetime import datetime
from http.server import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
from operator import itemgetter from operator import itemgetter
from typing import Tuple
from OpenSSL import crypto from OpenSSL import crypto
from urllib.parse import urlparse, unquote, urlencode from urllib.parse import urlparse, unquote, urlencode
from . import ctl, tunnel, utils, version, x509 from . import ctl, tunnel, utils, version, x509
...@@ -56,6 +59,8 @@ class RegistryServer: ...@@ -56,6 +59,8 @@ class RegistryServer:
peers = 0, () peers = 0, ()
cert_duration = 365 * 86400 cert_duration = 365 * 86400
sessions: dict[str, list[tuple[bytes, int]]]
def _geoiplookup(self, ip): def _geoiplookup(self, ip):
raise HTTPError(http.client.BAD_REQUEST) raise HTTPError(http.client.BAD_REQUEST)
...@@ -138,10 +143,10 @@ class RegistryServer: ...@@ -138,10 +143,10 @@ class RegistryServer:
if self.geoip_db: if self.geoip_db:
from geoip2 import database, errors from geoip2 import database, errors
country = database.Reader(self.geoip_db).country country = database.Reader(self.geoip_db).country
def geoiplookup(ip): def geoiplookup(ip: str) -> Tuple[str, str]:
try: try:
req = country(ip) req = country(ip)
return req.country.iso_code.encode(), req.continent.code.encode() return req.country.iso_code, req.continent.code
except (errors.AddressNotFoundError, ValueError): except (errors.AddressNotFoundError, ValueError):
return '*', '*' return '*', '*'
self._geoiplookup = geoiplookup self._geoiplookup = geoiplookup
...@@ -203,15 +208,18 @@ class RegistryServer: ...@@ -203,15 +208,18 @@ class RegistryServer:
def sendto(self, prefix: str, code: int): def sendto(self, prefix: str, code: int):
self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT)) self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT))
def recv(self, code): def recv(self, code: int) -> (str, str):
try: try:
prefix, msg = self.sock.recv(1<<16).split(b'\x00', 1) prefix, msg = self.sock.recv(1 << 16).split(b'\x00', 1)
int(prefix, 2) int(prefix, 2)
except ValueError: except ValueError:
pass pass
else: else:
if msg and msg[0:1] == code: if len(msg) >= 1 and msg[0] == code:
return prefix, msg[1:] return prefix.decode(), msg[1:].decode()
logging.error("Invalid message or unexpected code: %r", msg)
return None, None return None, None
def select(self, r, w, t): def select(self, r, w, t):
...@@ -237,7 +245,7 @@ class RegistryServer: ...@@ -237,7 +245,7 @@ class RegistryServer:
def babel_dump(self): def babel_dump(self):
self._wait_dump = False self._wait_dump = False
def iterCert(self): def iterCert(self) -> Iterator[Tuple[crypto.X509, str, str]]:
for prefix, email, cert in self.db.execute( for prefix, email, cert in self.db.execute(
"SELECT * FROM cert WHERE cert IS NOT NULL"): "SELECT * FROM cert WHERE cert IS NOT NULL"):
try: try:
...@@ -335,12 +343,12 @@ class RegistryServer: ...@@ -335,12 +343,12 @@ class RegistryServer:
if result: if result:
request.wfile.write(result) request.wfile.write(result)
def getPeerProtocol(self, cn): def getPeerProtocol(self, cn: str) -> int:
session, = self.sessions[cn] session, = self.sessions[cn]
return session[1] return session[1]
@rpc @rpc
def hello(self, client_prefix, protocol='1'): def hello(self, client_prefix: str, protocol='1') -> bytes:
with self.lock: with self.lock:
cert = self.getCert(client_prefix) cert = self.getCert(client_prefix)
key = utils.newHmacSecret() key = utils.newHmacSecret()
...@@ -350,7 +358,7 @@ class RegistryServer: ...@@ -350,7 +358,7 @@ class RegistryServer:
assert len(key) == len(sign) assert len(key) == len(sign)
return key + sign return key + sign
def getCert(self, client_prefix): def getCert(self, client_prefix: str) -> bytes:
assert self.lock.locked() assert self.lock.locked()
cert = self.db.execute("SELECT cert FROM cert" cert = self.db.execute("SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL", (client_prefix,)).fetchone() " WHERE prefix=? AND cert IS NOT NULL", (client_prefix,)).fetchone()
...@@ -358,19 +366,19 @@ class RegistryServer: ...@@ -358,19 +366,19 @@ class RegistryServer:
return cert[0] return cert[0]
@rpc_private @rpc_private
def isToken(self, token): def isToken(self, token: str):
with self.lock: with self.lock:
if self.db.execute("SELECT 1 FROM token WHERE token = ?", if self.db.execute("SELECT 1 FROM token WHERE token = ?",
(token,)).fetchone(): (token,)).fetchone():
return b"1" return b"1"
@rpc_private @rpc_private
def deleteToken(self, token): def deleteToken(self, token: str):
with self.lock: with self.lock:
self.db.execute("DELETE FROM token WHERE token = ?", (token,)) self.db.execute("DELETE FROM token WHERE token = ?", (token,))
@rpc_private @rpc_private
def addToken(self, email, token): def addToken(self, email: str, token: str | None) -> str:
prefix_len = self.config.prefix_length prefix_len = self.config.prefix_length
if not prefix_len: if not prefix_len:
raise HTTPError(http.client.FORBIDDEN) raise HTTPError(http.client.FORBIDDEN)
...@@ -498,7 +506,7 @@ class RegistryServer: ...@@ -498,7 +506,7 @@ class RegistryServer:
q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,)) q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
@rpc @rpc
def requestCertificate(self, token, req, location='', ip=''): def requestCertificate(self, token: str | None, req: bytes, location: str='', ip: str=''):
logging.debug("Requesting certificate with token %s", token) logging.debug("Requesting certificate with token %s", token)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req)
with self.lock: with self.lock:
...@@ -572,7 +580,7 @@ class RegistryServer: ...@@ -572,7 +580,7 @@ class RegistryServer:
return cert return cert
@rpc @rpc
def renewCertificate(self, cn): def renewCertificate(self, cn: str) -> bytes:
with self.lock: with self.lock:
with self.db as db: with self.db as db:
pem = self.getCert(cn) pem = self.getCert(cn)
...@@ -588,16 +596,16 @@ class RegistryServer: ...@@ -588,16 +596,16 @@ class RegistryServer:
cert.get_subject(), cert.get_pubkey(), not_after) cert.get_subject(), cert.get_pubkey(), not_after)
@rpc @rpc
def getCa(self): def getCa(self) -> bytes:
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca)
@rpc @rpc
def getDh(self, cn): def getDh(self, cn: str) -> bytes:
with open(self.config.dh) as f: with open(self.config.dh, "rb") as f:
return f.read() return f.read()
@rpc @rpc
def getNetworkConfig(self, cn): def getNetworkConfig(self, cn: str) -> bytes:
with self.lock: with self.lock:
cert = self.getCert(cn) cert = self.getCert(cn)
config = self.network_config.copy() config = self.network_config.copy()
...@@ -607,8 +615,8 @@ class RegistryServer: ...@@ -607,8 +615,8 @@ class RegistryServer:
v and base64.b64encode(x509.encrypt(cert, v)).decode() v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config).encode("utf-8")) return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer): def _queryAddress(self, peer: str) -> str:
logging.info("Querying address for %s/%s", int(peer, 2), len(peer)) logging.info("Querying address for %s/%s %r", int(peer, 2), len(peer), peer)
self.sendto(peer, 1) self.sendto(peer, 1)
s = self.sock, s = self.sock,
timeout = 3 timeout = 3
...@@ -616,7 +624,7 @@ class RegistryServer: ...@@ -616,7 +624,7 @@ class RegistryServer:
# Loop because there may be answers from previous requests. # Loop because there may be answers from previous requests.
while select.select(s, (), (), timeout)[0]: while select.select(s, (), (), timeout)[0]:
prefix, msg = self.recv(1) prefix, msg = self.recv(1)
logging.info("* received: %s - %s", prefix, msg) logging.info("* received: %r - %r", prefix, msg)
if prefix == peer: if prefix == peer:
return msg return msg
timeout = max(0, end - time.time()) timeout = max(0, end - time.time())
...@@ -624,12 +632,12 @@ class RegistryServer: ...@@ -624,12 +632,12 @@ class RegistryServer:
int(peer, 2), len(peer)) int(peer, 2), len(peer))
@rpc @rpc
def getCountry(self, cn, address): def getCountry(self, cn: str, address: str) -> str | None:
country = self._geoiplookup(address)[0] country = self._geoiplookup(address)[0]
return None if country == '*' else country.encode() return None if country == '*' else country
@rpc @rpc
def getBootstrapPeer(self, cn): def getBootstrapPeer(self, cn: str) -> bytes | None:
logging.info("Answering bootstrap peer for %s", cn) logging.info("Answering bootstrap peer for %s", cn)
with self.peers_lock: with self.peers_lock:
age, peers = self.peers age, peers = self.peers
...@@ -661,10 +669,10 @@ class RegistryServer: ...@@ -661,10 +669,10 @@ class RegistryServer:
cert = self.getCert(cn) cert = self.getCert(cn)
msg = "%s %s" % (peer, msg) msg = "%s %s" % (peer, msg)
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg) return x509.encrypt(cert, msg.encode())
@rpc_private @rpc_private
def revoke(self, cn_or_serial): def revoke(self, cn_or_serial: int | str):
with self.lock, self.db: with self.lock, self.db:
q = self.db.execute q = self.db.execute
try: try:
...@@ -685,12 +693,12 @@ class RegistryServer: ...@@ -685,12 +693,12 @@ class RegistryServer:
q("INSERT INTO crl VALUES (?,?)", (serial, not_after)) q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
self.updateNetworkConfig() self.updateNetworkConfig()
def newHMAC(self, i, key=None): def newHMAC(self, i: int, key: bytes=None):
if key is None: if key is None:
key = os.urandom(16) key = os.urandom(16)
self.setConfig(BABEL_HMAC[i], key) self.setConfig(BABEL_HMAC[i], key)
def delHMAC(self, i): def delHMAC(self, i: int):
self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],)) self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],))
@rpc_private @rpc_private
...@@ -717,7 +725,7 @@ class RegistryServer: ...@@ -717,7 +725,7 @@ class RegistryServer:
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
@rpc_private @rpc_private
def getNodePrefix(self, email): def getNodePrefix(self, email: str) -> str | None:
with self.lock, self.db: with self.lock, self.db:
try: try:
cert, = next(self.db.execute("SELECT cert FROM cert WHERE email = ?", cert, = next(self.db.execute("SELECT cert FROM cert WHERE email = ?",
...@@ -728,7 +736,7 @@ class RegistryServer: ...@@ -728,7 +736,7 @@ class RegistryServer:
return x509.subnetFromCert(certificate) return x509.subnetFromCert(certificate)
@rpc_private @rpc_private
def getIPv6Address(self, email): def getIPv6Address(self, email: str) -> str:
cn = self.getNodePrefix(email) cn = self.getNodePrefix(email)
if cn: if cn:
return utils.ipFromBin( return utils.ipFromBin(
...@@ -736,7 +744,7 @@ class RegistryServer: ...@@ -736,7 +744,7 @@ class RegistryServer:
+ utils.binFromSubnet(cn)) + utils.binFromSubnet(cn))
@rpc_private @rpc_private
def getIPv4Information(self, email): def getIPv4Information(self, email: str) -> str | None:
peer = self.getNodePrefix(email) peer = self.getNodePrefix(email)
if peer: if peer:
peer = utils.binFromSubnet(peer) peer = utils.binFromSubnet(peer)
...@@ -752,10 +760,10 @@ class RegistryServer: ...@@ -752,10 +760,10 @@ class RegistryServer:
with self.lock: with self.lock:
msg = self._queryAddress(peer).decode() msg = self._queryAddress(peer).decode()
if msg: if msg:
return msg.split(',')[0].encode() return msg.split(',')[0]
@rpc_private @rpc_private
def versions(self): def versions(self) -> str:
with self.peers_lock: with self.peers_lock:
self.request_dump() self.request_dump()
peers = {prefix peers = {prefix
...@@ -817,11 +825,16 @@ class RegistryServer: ...@@ -817,11 +825,16 @@ class RegistryServer:
class RegistryClient: class RegistryClient:
"""
Client for the re6st registry.
Method calls are forwarded to the registry server. String results are always returned as bytes.
"""
_hmac = None _hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform()) user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert: x509.Cert=None, auto_close=True): def __init__(self, url: str, cert: x509.Cert=None, auto_close=True):
self.cert = cert self.cert = cert
self.auto_close = auto_close self.auto_close = auto_close
url_parsed = urlparse(url) url_parsed = urlparse(url)
...@@ -831,7 +844,7 @@ class RegistryClient: ...@@ -831,7 +844,7 @@ class RegistryClient:
)[scheme](unquote(host), timeout=60) )[scheme](unquote(host), timeout=60)
self._path = path.rstrip('/') self._path = path.rstrip('/')
def __getattr__(self, name): def __getattr__(self, name: str):
getcallargs = getattr(RegistryServer, name).getcallargs getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw) -> bytes: def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw) kw = getcallargs(*args, **kw)
......
...@@ -11,11 +11,13 @@ import hashlib ...@@ -11,11 +11,13 @@ import hashlib
import time import time
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from sqlite3 import Cursor
from OpenSSL import crypto from OpenSSL import crypto
from mock import Mock, patch from mock import Mock, patch
from pathlib import Path from pathlib import Path
from re6st import registry from re6st import registry, x509
from re6st.tests.tools import * from re6st.tests.tools import *
from re6st.tests import DEMO_PATH from re6st.tests import DEMO_PATH
...@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH ...@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer # TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions # getIPV4Information, versions
def load_config(filename="registry.json"): def load_config(filename: str="registry.json") -> Namespace:
with open(filename) as f: with open(filename) as f:
config = json.load(f) config = json.load(f)
config["dh"] = DEMO_PATH / "dh2048.pem" config["dh"] = DEMO_PATH / "dh2048.pem"
...@@ -37,13 +39,13 @@ def load_config(filename="registry.json"): ...@@ -37,13 +39,13 @@ def load_config(filename="registry.json"):
return Namespace(**config) return Namespace(**config)
def get_cert(cur, prefix): def get_cert(cur: Cursor, prefix: str):
res = cur.execute( res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone() "SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0] return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None): def insert_cert(cur: Cursor, ca: x509.Cert, prefix: str, not_after=None, email=None):
key, csr = generate_csr() key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after) cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert)) cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
...@@ -54,7 +56,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None): ...@@ -54,7 +56,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial = 0 insert_cert.serial = 0
def delete_cert(cur, prefix): def delete_cert(cur: Cursor, prefix: str):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,)) cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
......
...@@ -92,18 +92,15 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042): ...@@ -92,18 +92,15 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
return key, cert return key, cert
def prefix2cn(prefix): def prefix2cn(prefix: str) -> str:
return "%u/%u" % (int(prefix, 2), len(prefix)) return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial): def serial2prefix(serial: int) -> str:
return bin(serial)[2:].rjust(16, '0') return bin(serial)[2:].rjust(16, '0')
# pkey: private key # pkey: private key
def decrypt(pkey, incontent): def decrypt(pkey: bytes, incontent: bytes) -> bytes:
with open("node.key", 'w') as f: with open("node.key", 'wb') as f:
f.write(pkey.decode()) f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split() args = "openssl rsautl -decrypt -inkey node.key".split()
with subprocess.Popen( return subprocess.run(args, input=incontent, stdout=subprocess.PIPE).stdout
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
outcontent, err = p.communicate(incontent)
return outcontent
...@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket ...@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import subprocess, struct, sys, time, weakref import subprocess, struct, sys, time, weakref
from collections import defaultdict, deque from collections import defaultdict, deque
from bisect import bisect, insort from bisect import bisect, insort
from collections.abc import Iterator, Sequence
from typing import Callable, TYPE_CHECKING
from OpenSSL import crypto from OpenSSL import crypto
from . import ctl, plib, utils, version, x509 from . import ctl, plib, utils, version, x509
if TYPE_CHECKING:
from . import cache
PORT = 326 PORT = 326
...@@ -21,7 +26,7 @@ proto_dict = { ...@@ -21,7 +26,7 @@ proto_dict = {
proto_dict['tcp'] = proto_dict['tcp4'] proto_dict['tcp'] = proto_dict['tcp4']
proto_dict['udp'] = proto_dict['udp4'] proto_dict['udp'] = proto_dict['udp4']
def resolve(ip, port, proto): def resolve(ip, port, proto: str) -> tuple[socket.AddressFamily | None, Iterator[str]]:
try: try:
family, proto = proto_dict[proto] family, proto = proto_dict[proto]
except KeyError: except KeyError:
...@@ -31,16 +36,16 @@ def resolve(ip, port, proto): ...@@ -31,16 +36,16 @@ def resolve(ip, port, proto):
class MultiGatewayManager(dict): class MultiGatewayManager(dict):
def __init__(self, gateway): def __init__(self, gateway: Callable[[str], str]):
self._gw = gateway self._gw = gateway
def _route(self, cmd, dest, gw): def _route(self, cmd: str, dest: str, gw: str):
if gw: if gw:
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd) logging.trace('%r', cmd)
subprocess.check_call(cmd) subprocess.check_call(cmd)
def add(self, dest, route): def add(self, dest: str, route: bool):
try: try:
self[dest][1] += 1 self[dest][1] += 1
except KeyError: except KeyError:
...@@ -48,7 +53,7 @@ class MultiGatewayManager(dict): ...@@ -48,7 +53,7 @@ class MultiGatewayManager(dict):
self[dest] = [gw, 0] self[dest] = [gw, 0]
self._route('add', dest, gw) self._route('add', dest, gw)
def remove(self, dest): def remove(self, dest: str):
gw, count = self[dest] gw, count = self[dest]
if count: if count:
self[dest][1] = count - 1 self[dest][1] = count - 1
...@@ -65,7 +70,7 @@ class Connection: ...@@ -65,7 +70,7 @@ class Connection:
serial = None serial = None
time = float('inf') time = float('inf')
def __init__(self, tunnel_manager, address_list, iface, prefix): def __init__(self, tunnel_manager: "TunnelManager", address_list, iface, prefix):
self.tunnel_manager = tunnel_manager self.tunnel_manager = tunnel_manager
self.address_list = address_list self.address_list = address_list
self.iface = iface self.iface = iface
...@@ -109,7 +114,7 @@ class Connection: ...@@ -109,7 +114,7 @@ class Connection:
if i: if i:
cache.addPeer(self._prefix, ','.join(self.address_list[i]), True) cache.addPeer(self._prefix, ','.join(self.address_list[i]), True)
else: else:
cache.connecting(self._prefix, 0) cache.connecting(self._prefix, False)
def close(self): def close(self):
try: try:
...@@ -198,7 +203,7 @@ class BaseTunnelManager: ...@@ -198,7 +203,7 @@ class BaseTunnelManager:
_geoiplookup = None _geoiplookup = None
_forward = None _forward = None
def __init__(self, control_socket, cache, cert, conf_country, address=()): def __init__(self, control_socket, cache: "cache.Cache", cert: x509.Cert, conf_country, address=()):
self.cert = cert self.cert = cert
self._network = cert.network self._network = cert.network
self._prefix = cert.prefix self._prefix = cert.prefix
...@@ -450,7 +455,7 @@ class BaseTunnelManager: ...@@ -450,7 +455,7 @@ class BaseTunnelManager:
self._sendto(to, msg[0:1] + answer.encode() if answer else b'', peer) self._sendto(to, msg[0:1] + answer.encode() if answer else b'', peer)
def _processPacket(self, msg, peer=None): def _processPacket(self, msg: bytes, peer: x509.Peer|str=None):
c = msg[0] c = msg[0]
msg = msg[1:] msg = msg[1:]
code = c & 0x7f code = c & 0x7f
...@@ -564,12 +569,12 @@ class BaseTunnelManager: ...@@ -564,12 +569,12 @@ class BaseTunnelManager:
self.selectTimeout(time.time() + 1 + self.cache.delay_restart, self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart) self._restart)
def handleServerEvent(self, sock): def handleServerEvent(self, sock: socket.socket):
event, args = eval(sock.recv(65536)) event, args = eval(sock.recv(65536))
logging.debug("%s%r", event, args) logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args) r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None: if r is not None:
sock.send(chr(r)) sock.send(bytes([r]))
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl: if serial in self.cache.crl:
...@@ -581,7 +586,7 @@ class BaseTunnelManager: ...@@ -581,7 +586,7 @@ class BaseTunnelManager:
self._gateway_manager.add(trusted_ip, False) self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix: if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix) self._kill(prefix)
self.cache.connecting(prefix, 0) self.cache.connecting(prefix, False)
return True return True
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip): def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
...@@ -665,7 +670,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -665,7 +670,7 @@ class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, cache, cert, openvpn_args, def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, conf_country, address, timeout, client_count, iface_list, conf_country, address,
ip_changed, remote_gateway, disable_proto, neighbour_list=()): ip_changed, remote_gateway: Callable[[str], str], disable_proto: Sequence[str], neighbour_list=()):
super(TunnelManager, self).__init__(control_socket, super(TunnelManager, self).__init__(control_socket,
cache, cert, conf_country, address) cache, cert, conf_country, address)
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
...@@ -877,7 +882,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -877,7 +882,7 @@ class TunnelManager(BaseTunnelManager):
address_list.append((ip, x[1], x[2])) address_list.append((ip, x[1], x[2]))
continue continue
address_list.append(x[:3]) address_list.append(x[:3])
self.cache.connecting(prefix, 1) self.cache.connecting(prefix, True)
if not address_list: if not address_list:
return False return False
logging.info('Establishing a connection with %u/%u', logging.info('Establishing a connection with %u/%u',
......
...@@ -17,7 +17,7 @@ class Forwarder: ...@@ -17,7 +17,7 @@ class Forwarder:
_lcg_n = 0 _lcg_n = 0
@classmethod @classmethod
def _getExternalPort(cls): def _getExternalPort(cls) -> int:
# Since _refresh() does not test all ports in a row, we prefer to # Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port. # return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without # A linear congruential generator should be random enough, without
...@@ -35,7 +35,7 @@ class Forwarder: ...@@ -35,7 +35,7 @@ class Forwarder:
self._u.discoverdelay = 200 self._u.discoverdelay = 200
self._rules = [] self._rules = []
def __getattr__(self, name): def __getattr__(self, name: str):
wrapped = getattr(self._u, name) wrapped = getattr(self._u, name)
def wrapper(*args, **kw): def wrapper(*args, **kw):
try: try:
......
...@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler): ...@@ -40,7 +40,7 @@ class FileHandler(logging.FileHandler):
if self.lock.acquire(False): if self.lock.acquire(False):
self.release() self.release()
def setupLog(log_level, filename=None, **kw): def setupLog(log_level: int, filename: str | None=None, **kw):
if log_level and filename: if log_level and filename:
makedirs(os.path.dirname(filename)) makedirs(os.path.dirname(filename))
handler = FileHandler(filename) handler = FileHandler(filename)
...@@ -184,7 +184,7 @@ def setCloexec(fd): ...@@ -184,7 +184,7 @@ def setCloexec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD) flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def select(R, W, T): def select(R: Mapping, W: Mapping, T):
try: try:
r, w, _ = _select.select(R, W, (), r, w, _ = _select.select(R, W, (),
max(0, min(T)[0] - time.time()) if T else None) max(0, min(T)[0] - time.time()) if T else None)
...@@ -208,15 +208,15 @@ def makedirs(*args): ...@@ -208,15 +208,15 @@ def makedirs(*args):
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
def binFromIp(ip): def binFromIp(ip: str) -> str:
return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip)) return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip))
def binFromRawIp(ip): def binFromRawIp(ip: bytes) -> str:
ip1, ip2 = struct.unpack('>QQ', ip) ip1, ip2 = struct.unpack('>QQ', ip)
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0') return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
def ipFromBin(ip, suffix=''): def ipFromBin(ip: str, suffix='') -> str:
suffix_len = 128 - len(ip) suffix_len = 128 - len(ip)
if suffix_len > 0: if suffix_len > 0:
ip += suffix.rjust(suffix_len, '0') ip += suffix.rjust(suffix_len, '0')
...@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''): ...@@ -225,11 +225,11 @@ def ipFromBin(ip, suffix=''):
return socket.inet_ntop(socket.AF_INET6, return socket.inet_ntop(socket.AF_INET6,
struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2))) struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2)))
def dump_address(address): def dump_address(address: str) -> str:
return ';'.join(map(','.join, address)) return ';'.join(map(','.join, address))
# Yield ip, port, protocol, and country if it is in the address # Yield ip, port, protocol, and country if it is in the address
def parse_address(address_list): def parse_address(address_list: str) -> Iterator[tuple[str, str, str, str]]:
for address in address_list.split(';'): for address in address_list.split(';'):
try: try:
a = address.split(',') a = address.split(',')
...@@ -239,16 +239,18 @@ def parse_address(address_list): ...@@ -239,16 +239,18 @@ def parse_address(address_list):
logging.warning("Failed to parse node address %r (%s)", logging.warning("Failed to parse node address %r (%s)",
address, e) address, e)
def binFromSubnet(subnet): def binFromSubnet(subnet: str) -> str:
p, l = subnet.split('/') p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0') return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret(): def _newHmacSecret():
from random import getrandbits as g from random import getrandbits as g
pack = struct.Struct(">QQI").pack pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN assert len(pack(0,0,0)) == HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32)) return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
newHmacSecret = _newHmacSecret() # https://github.com/python/mypy/issues/1174
### Integer serialization ### Integer serialization
# - supports values from 0 to 0x202020202020201f # - supports values from 0 to 0x202020202020201f
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from typing import Callable, Any
from OpenSSL import crypto from OpenSSL import crypto
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import padding
...@@ -9,29 +11,29 @@ from cryptography.x509 import load_pem_x509_certificate ...@@ -9,29 +11,29 @@ from cryptography.x509 import load_pem_x509_certificate
from . import utils from . import utils
from .version import protocol from .version import protocol
def newHmacSecret(): def newHmacSecret() -> bytes:
return utils.newHmacSecret(int(time.time() * 1000000)) return utils.newHmacSecret(int(time.time() * 1000000))
def networkFromCa(ca): def networkFromCa(ca: crypto.X509) -> str:
# TODO: will be ca.serial_number after migration to cryptography # TODO: will be ca.serial_number after migration to cryptography
return bin(ca.get_serial_number())[3:] return bin(ca.get_serial_number())[3:]
def subnetFromCert(cert): def subnetFromCert(cert: crypto.X509) -> str:
return cert.get_subject().CN return cert.get_subject().CN
def notBefore(cert): def notBefore(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert): def notAfter(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args, fds=[]): def openssl(*args: str, fds=[]) -> utils.Popen:
return utils.Popen(('openssl',) + args, return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, pass_fds=fds) stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data): def encrypt(cert: bytes, data: bytes) -> bytes:
r, w = os.pipe() r, w = os.pipe()
try: try:
threading.Thread(target=os.write, args=(w, cert)).start() threading.Thread(target=os.write, args=(w, cert)).start()
...@@ -45,10 +47,10 @@ def encrypt(cert, data): ...@@ -45,10 +47,10 @@ def encrypt(cert, data):
raise subprocess.CalledProcessError(p.returncode, 'openssl', err) raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out return out
def fingerprint(cert, alg='sha1'): def fingerprint(cert: crypto.X509, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)) return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew, force=False): def maybe_renew(path: str, cert: crypto.X509, info: str, renew: Callable[[], bytes], force=False) -> tuple[crypto.X509, int]:
from .registry import RENEW_PERIOD from .registry import RENEW_PERIOD
while True: while True:
if force: if force:
...@@ -58,7 +60,7 @@ def maybe_renew(path, cert, info, renew, force=False): ...@@ -58,7 +60,7 @@ def maybe_renew(path, cert, info, renew, force=False):
if time.time() < next_renew: if time.time() < next_renew:
return cert, next_renew return cert, next_renew
try: try:
pem: bytes = renew() pem = renew()
if not pem or pem == crypto.dump_certificate( if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert): crypto.FILETYPE_PEM, cert):
exc_info = 0 exc_info = 0
...@@ -92,7 +94,7 @@ class NewSessionError(Exception): ...@@ -92,7 +94,7 @@ class NewSessionError(Exception):
class Cert: class Cert:
def __init__(self, ca, key, cert=None): def __init__(self, ca: str, key: str, cert: str | None=None):
self.ca_path = ca self.ca_path = ca
self.cert_path = cert self.cert_path = cert
self.key_path = key self.key_path = key
...@@ -110,24 +112,24 @@ class Cert: ...@@ -110,24 +112,24 @@ class Cert:
self.cert = self.loadVerify(f.read().encode()) self.cert = self.loadVerify(f.read().encode())
@property @property
def prefix(self): def prefix(self) -> str:
return utils.binFromSubnet(subnetFromCert(self.cert)) return utils.binFromSubnet(subnetFromCert(self.cert))
@property @property
def network(self): def network(self) -> str:
return networkFromCa(self.ca) return networkFromCa(self.ca)
@property @property
def subject_serial(self): def subject_serial(self) -> int:
return int(self.cert.get_subject().serialNumber) return int(self.cert.get_subject().serialNumber)
@property @property
def openvpn_args(self): def openvpn_args(self) -> tuple[str, ...]:
return ('--ca', self.ca_path, return ('--ca', self.ca_path,
'--cert', self.cert_path, '--cert', self.cert_path,
'--key', self.key_path) '--key', self.key_path)
def maybeRenew(self, registry, crl): def maybeRenew(self, registry, crl) -> int:
self.cert, next_renew = maybe_renew(self.cert_path, self.cert, self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix), "Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl) self.cert.get_serial_number() in crl)
...@@ -163,7 +165,6 @@ class Cert: ...@@ -163,7 +165,6 @@ class Cert:
return r return r
def verify(self, sign: bytes, data: bytes): def verify(self, sign: bytes, data: bytes):
assert isinstance(data, bytes)
pub_key = self.ca_crypto.public_key() pub_key = self.ca_crypto.public_key()
pub_key.verify( pub_key.verify(
sign, sign,
...@@ -173,7 +174,6 @@ class Cert: ...@@ -173,7 +174,6 @@ class Cert:
) )
def sign(self, data: bytes) -> bytes: def sign(self, data: bytes) -> bytes:
assert isinstance(data, bytes)
return self.key_crypto.sign( return self.key_crypto.sign(
data, data,
padding.PKCS1v15(), padding.PKCS1v15(),
...@@ -230,6 +230,7 @@ class Peer: ...@@ -230,6 +230,7 @@ class Peer:
serial = None serial = None
stop_date = float('inf') stop_date = float('inf')
version = b'' version = b''
cert: crypto.X509
def __init__(self, prefix: str): def __init__(self, prefix: str):
self.prefix = prefix self.prefix = prefix
...@@ -247,7 +248,7 @@ class Peer: ...@@ -247,7 +248,7 @@ class Peer:
def __lt__(self, other): def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix) return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert): def hello0(self, cert: crypto.X509) -> bytes:
if self._hello < time.time(): if self._hello < time.time():
try: try:
# Always assume peer is not old, in case it has just upgraded, # Always assume peer is not old, in case it has just upgraded,
...@@ -262,7 +263,7 @@ class Peer: ...@@ -262,7 +263,7 @@ class Peer:
def hello0Sent(self): def hello0Sent(self):
self._hello = time.time() + 60 self._hello = time.time() + 60
def hello(self, cert, protocol): def hello(self, cert: Cert, protocol: int) -> bytes:
key = self._key = newHmacSecret() key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert), h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key) key)
...@@ -272,10 +273,10 @@ class Peer: ...@@ -272,10 +273,10 @@ class Peer:
return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'', return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h))) h, cert.sign(h)))
def _hmac(self, msg): def _hmac(self, msg: bytes) -> bytes:
return hmac.HMAC(self._key, msg, hashlib.sha1).digest() return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key: bytes, protocol): def newSession(self, key: bytes, protocol: int):
if key <= self._key: if key <= self._key:
raise NewSessionError(self._key, key) raise NewSessionError(self._key, key)
self._key = key self._key = key
...@@ -283,12 +284,12 @@ class Peer: ...@@ -283,12 +284,12 @@ class Peer:
self._last = None self._last = None
self.protocol = protocol self.protocol = protocol
def verify(self, sign, data): def verify(self, sign: bytes, data: bytes):
crypto.verify(self.cert, sign, data, 'sha512') crypto.verify(self.cert, sign, data, 'sha512')
seqno_struct = struct.Struct("!L") seqno_struct = struct.Struct("!L")
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> str: def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> tuple[int, bytes, int | None] | bytes:
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
msg = msg[4:] msg = msg[4:]
...@@ -302,7 +303,7 @@ class Peer: ...@@ -302,7 +303,7 @@ class Peer:
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
self._i = seqno self._i = seqno
return msg[4:i].decode() return msg[4:i]
def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes: def encode(self, msg: str | bytes, _pack=seqno_struct.pack) -> bytes:
self._j += 1 self._j += 1
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment