Commit 3397afd1 authored by Tom Niget's avatar Tom Niget

py2to3: add lots of type hints and fix some python 2 migration bugs

parent 621842fa
......@@ -5,7 +5,7 @@ from . import utils, version, x509
class Cache:
def __init__(self, db_path, registry, cert: x509.Cert, db_size=200):
def __init__(self, db_path: str, registry, cert: x509.Cert, db_size=200):
self._prefix = cert.prefix
self._db_size = db_size
self._decrypt = cert.decrypt
......@@ -50,7 +50,7 @@ class Cache:
self.warnProtocol()
logging.info("Cache initialized.")
def _open(self, path):
def _open(self, path: str) -> sqlite3.Connection:
db = sqlite3.connect(path, isolation_level=None)
db.text_factory = str
db.execute("PRAGMA synchronous = OFF")
......@@ -147,7 +147,7 @@ class Cache:
logging.warning("There's a new version of re6stnet:"
" you should update.")
def getDh(self, path):
def getDh(self, path: str):
# We'd like to do a full check here but
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
......@@ -179,11 +179,11 @@ class Cache:
logging.trace("- %s: %s%s", prefix, address,
' (blacklisted)' if _try else '')
def cacheMinimize(self, size):
def cacheMinimize(self, size: int):
with self._db:
self._cacheMinimize(size)
def _cacheMinimize(self, size):
def _cacheMinimize(self, size: int):
a = self._db.execute(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1",
(size,)).fetchall()
......@@ -192,26 +192,27 @@ class Cache:
q("DELETE FROM peer WHERE prefix IN (?)", a)
q("DELETE FROM volatile.stat WHERE peer IN (?)", a)
def connecting(self, prefix, connecting):
def connecting(self, prefix: str, connecting: int):
# TODO: is `connecting` a bool?
self._db.execute("UPDATE volatile.stat SET try=? WHERE peer=?",
(connecting, prefix))
def resetConnecting(self):
self._db.execute("UPDATE volatile.stat SET try=0")
def getAddress(self, prefix):
def getAddress(self, prefix: str) -> bool:
r = self._db.execute("SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0",
(prefix,)).fetchone()
return r and r[0]
@property
def my_address(self):
def my_address(self) -> str:
for x, in self._db.execute("SELECT address FROM peer WHERE prefix=''"):
return x
@my_address.setter
def my_address(self, value):
def my_address(self, value: str):
if value:
with self._db as db:
db.execute("INSERT OR REPLACE INTO peer VALUES ('', ?)",
......@@ -236,7 +237,7 @@ class Cache:
def getPeerCount(self, failed=0, __sql=_get_peer_sql % "COUNT(*)") -> int:
return self._db.execute(__sql, (self._prefix, failed)).next()[0]
def getBootstrapPeer(self):
def getBootstrapPeer(self) -> tuple[str, str]:
logging.info('Getting Boot peer...')
try:
bootpeer = self._registry.getBootstrapPeer(self._prefix)
......@@ -250,7 +251,7 @@ class Cache:
return prefix, address
logging.warning('Buggy registry sent us our own address')
def addPeer(self, prefix, address, set_preferred=False):
def addPeer(self, prefix: str, address: str, set_preferred=False):
logging.debug('Adding peer %s: %s', prefix, address)
with self._db:
q = self._db.execute
......@@ -274,7 +275,7 @@ class Cache:
q("INSERT OR REPLACE INTO peer VALUES (?,?)", (prefix, address))
q("INSERT OR REPLACE INTO volatile.stat VALUES (?,0)", (prefix,))
def getCountry(self, ip):
def getCountry(self, ip: str) -> str:
try:
return self._registry.getCountry(self._prefix, ip).decode()
except socket.error as e:
......
......@@ -272,7 +272,7 @@ def main():
call(args)
args[3] = 'del'
cleanup.append(lambda: subprocess.call(args))
def ip(object, *args):
def ip(object: str, *args):
args = ['ip', '-6', object, 'add'] + list(args)
call(args)
args[3] = 'del'
......
......@@ -171,7 +171,7 @@ class Babel:
_decode = None
def __init__(self, socket_path, handler, network):
def __init__(self, socket_path: str, handler, network: str):
self.socket_path = socket_path
self.handler = handler
self.network = network
......@@ -304,7 +304,7 @@ class iterRoutes:
_waiting = True
def __new__(cls, control_socket, network):
def __new__(cls, control_socket: str, network: str):
self = object.__new__(cls)
c = Babel(control_socket, self, network)
c.request_dump()
......
......@@ -3,30 +3,30 @@ import errno, os, socket, stat, threading
class Socket:
def __init__(self, socket):
def __init__(self, socket: socket.socket):
# In case that the default timeout is not None.
socket.settimeout(None)
self._socket = socket
self._buf = ''
self._buf = b''
def close(self):
self._socket.close()
def write(self, data):
def write(self, data: bytes):
self._socket.send(data)
def readline(self):
def readline(self) -> bytes:
recv = self._socket.recv
data = self._buf
while True:
i = 1 + data.find('\n')
i = 1 + data.find(b'\n')
if i:
self._buf = data[i:]
return data[:i]
d = recv(4096)
data += d
if not d:
self._buf = ''
self._buf = b''
return data
def flush(self):
......
......@@ -8,7 +8,7 @@ ovpn_server = os.path.join(here, 'ovpn-server')
ovpn_client = os.path.join(here, 'ovpn-client')
ovpn_log: Optional[str] = None
def openvpn(iface, encrypt, *args, **kw):
def openvpn(iface: str, encrypt, *args, **kw) -> utils.Popen:
args = ['openvpn',
'--dev-type', 'tap',
'--dev', iface,
......@@ -28,7 +28,7 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict = {'udp4': 1432, 'udp6': 1450}
def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
def server(iface: str, max_clients: int, dh_path: str, fd: int, port: int, proto: str, encrypt: bool, *args, **kw) -> utils.Popen:
if proto == 'udp':
proto = 'udp4'
client_script = '%s %s' % (ovpn_server, fd)
......@@ -49,7 +49,7 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
*args, pass_fds=[fd], **kw)
def client(iface, address_list, encrypt, *args, **kw):
def client(iface: str, address_list: list[tuple[str, int, str]], encrypt: bool, *args, **kw) -> utils.Popen:
remote = ['--nobind', '--client']
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu = set()
......@@ -65,8 +65,8 @@ def client(iface, address_list, encrypt, *args, **kw):
return openvpn(iface, encrypt, *remote, **kw)
def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
control_socket, default, hmac, *args, **kw):
def router(ip: tuple[str, int], ip4, rt6: tuple[str, bool, bool], hello_interval: int, log_path: str, state_path: str, pidfile: str,
control_socket: str, default: str, hmac: tuple[bytes | None, bytes | None], *args, **kw) -> utils.Popen:
network, gateway, has_ipv6_subtrees = rt6
network_mask = int(network[network.index('/')+1:])
ip, n = ip
......@@ -83,7 +83,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C', 'redistribute local deny',
'-C', 'redistribute ip %s/%s eq %s' % (ip, n, n)]
if hmac_sign:
def key(cmd, id: str, value):
def key(cmd: list[str], id: str, value: bytes):
cmd += '-C', ('key type blake2s128 id %s value %s' %
(id, binascii.hexlify(value).decode()))
key(cmd, 'sign', hmac_sign)
......
......@@ -22,10 +22,13 @@ import base64, hmac, hashlib, http.client, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref, zlib
from collections import defaultdict, deque
from collections.abc import Iterator
from datetime import datetime
from http.server import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText
from operator import itemgetter
from typing import Tuple
from OpenSSL import crypto
from urllib.parse import urlparse, unquote, urlencode
from . import ctl, tunnel, utils, version, x509
......@@ -139,10 +142,10 @@ class RegistryServer:
if self.geoip_db:
from geoip2 import database, errors
country = database.Reader(self.geoip_db).country
def geoiplookup(ip):
def geoiplookup(ip: str) -> Tuple[str, str]:
try:
req = country(ip)
return req.country.iso_code.encode(), req.continent.code.encode()
return req.country.iso_code, req.continent.code
except (errors.AddressNotFoundError, ValueError):
return '*', '*'
self._geoiplookup = geoiplookup
......@@ -243,7 +246,7 @@ class RegistryServer:
def babel_dump(self):
self._wait_dump = False
def iterCert(self):
def iterCert(self) -> Iterator[Tuple[crypto.X509, str, str]]:
for prefix, email, cert in self.db.execute(
"SELECT * FROM cert WHERE cert IS NOT NULL"):
try:
......@@ -356,7 +359,7 @@ class RegistryServer:
assert len(key) == len(sign)
return key + sign
def getCert(self, client_prefix):
def getCert(self, client_prefix: str) -> bytes:
assert self.lock.locked()
cert = self.db.execute("SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL", (client_prefix,)).fetchone()
......@@ -365,19 +368,19 @@ class RegistryServer:
return cert[0]
@rpc_private
def isToken(self, token):
def isToken(self, token: str):
with self.lock:
if self.db.execute("SELECT 1 FROM token WHERE token = ?",
(token,)).fetchone():
return b"1"
@rpc_private
def deleteToken(self, token):
def deleteToken(self, token: str):
with self.lock:
self.db.execute("DELETE FROM token WHERE token = ?", (token,))
@rpc_private
def addToken(self, email, token):
def addToken(self, email: str, token: str | None) -> str:
prefix_len = self.config.prefix_length
if not prefix_len:
raise HTTPError(http.client.FORBIDDEN)
......@@ -505,7 +508,7 @@ class RegistryServer:
q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
@rpc
def requestCertificate(self, token, req, location='', ip=''):
def requestCertificate(self, token: str | None, req: bytes, location: str='', ip: str=''):
logging.debug("Requesting certificate with token %s", token)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req)
with self.lock:
......@@ -579,7 +582,7 @@ class RegistryServer:
return cert
@rpc
def renewCertificate(self, cn):
def renewCertificate(self, cn: str) -> bytes:
with self.lock:
with self.db as db:
pem = self.getCert(cn)
......@@ -595,16 +598,16 @@ class RegistryServer:
cert.get_subject(), cert.get_pubkey(), not_after)
@rpc
def getCa(self):
def getCa(self) -> bytes:
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca)
@rpc
def getDh(self, cn):
with open(self.config.dh) as f:
def getDh(self, cn: str) -> bytes:
with open(self.config.dh, "rb") as f:
return f.read()
@rpc
def getNetworkConfig(self, cn):
def getNetworkConfig(self, cn: str) -> bytes:
with self.lock:
cert = self.getCert(cn)
config = self.network_config.copy()
......@@ -614,7 +617,7 @@ class RegistryServer:
v and base64.b64encode(x509.encrypt(cert, v)).decode()
return zlib.compress(json.dumps(config).encode("utf-8"))
def _queryAddress(self, peer) -> str:
def _queryAddress(self, peer: str) -> str:
logging.info("Querying address for %s/%s %r", int(peer, 2), len(peer), peer)
self.sendto(peer, 1)
s = self.sock,
......@@ -631,12 +634,12 @@ class RegistryServer:
int(peer, 2), len(peer))
@rpc
def getCountry(self, cn, address) -> bytes:
def getCountry(self, cn: str, address: str) -> bytes | None:
country = self._geoiplookup(address)[0]
return None if country == '*' else country.encode()
@rpc
def getBootstrapPeer(self, cn):
def getBootstrapPeer(self, cn: str) -> bytes | None:
logging.info("Answering bootstrap peer for %s", cn)
with self.peers_lock:
age, peers = self.peers
......@@ -671,7 +674,7 @@ class RegistryServer:
return x509.encrypt(cert, msg.encode())
@rpc_private
def revoke(self, cn_or_serial):
def revoke(self, cn_or_serial: int | str):
with self.lock, self.db:
q = self.db.execute
try:
......@@ -692,12 +695,12 @@ class RegistryServer:
q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
self.updateNetworkConfig()
def newHMAC(self, i, key=None):
def newHMAC(self, i: int, key: bytes=None):
if key is None:
key = os.urandom(16)
self.setConfig(BABEL_HMAC[i], key)
def delHMAC(self, i):
def delHMAC(self, i: int):
self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],))
@rpc_private
......@@ -724,7 +727,7 @@ class RegistryServer:
self.sendto(self.prefix, 0)
@rpc_private
def getNodePrefix(self, email):
def getNodePrefix(self, email: str) -> str | None:
with self.lock, self.db:
try:
cert, = next(self.db.execute("SELECT cert FROM cert WHERE email = ?",
......@@ -735,7 +738,7 @@ class RegistryServer:
return x509.subnetFromCert(certificate)
@rpc_private
def getIPv6Address(self, email):
def getIPv6Address(self, email: str) -> str:
cn = self.getNodePrefix(email)
if cn:
return utils.ipFromBin(
......@@ -743,7 +746,7 @@ class RegistryServer:
+ utils.binFromSubnet(cn))
@rpc_private
def getIPv4Information(self, email):
def getIPv4Information(self, email: str) -> bytes | None:
peer = self.getNodePrefix(email)
if peer:
peer = utils.binFromSubnet(peer)
......@@ -762,7 +765,7 @@ class RegistryServer:
return msg.split(',')[0].encode()
@rpc_private
def versions(self):
def versions(self) -> str:
with self.peers_lock:
self.request_dump()
peers = {prefix
......@@ -788,7 +791,7 @@ class RegistryServer:
return json.dumps(peer_dict)
@rpc_private
def topology(self):
def topology(self) -> str:
logging.info("Computing topology")
p = lambda p: '%s/%s' % (int(p, 2), len(p))
peers = deque((p(self.prefix),))
......@@ -828,7 +831,7 @@ class RegistryClient:
_hmac = None
user_agent = "re6stnet/%s, %s" % (version.version, platform.platform())
def __init__(self, url, cert: x509.Cert=None, auto_close=True):
def __init__(self, url: str, cert: x509.Cert=None, auto_close=True):
self.cert = cert
self.auto_close = auto_close
url_parsed = urlparse(url)
......@@ -838,7 +841,7 @@ class RegistryClient:
)[scheme](unquote(host), timeout=60)
self._path = path.rstrip('/')
def __getattr__(self, name):
def __getattr__(self, name: str):
getcallargs = getattr(RegistryServer, name).getcallargs
def rpc(*args, **kw) -> bytes:
kw = getcallargs(*args, **kw)
......
......@@ -11,11 +11,13 @@ import hashlib
import time
import tempfile
from argparse import Namespace
from sqlite3 import Cursor
from OpenSSL import crypto
from mock import Mock, patch
from pathlib import Path
from re6st import registry
from re6st import registry, x509
from re6st.tests.tools import *
from re6st.tests import DEMO_PATH
......@@ -23,7 +25,7 @@ from re6st.tests import DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
def load_config(filename="registry.json"):
def load_config(filename: str="registry.json") -> Namespace:
with open(filename) as f:
config = json.load(f)
config["dh"] = DEMO_PATH / "dh2048.pem"
......@@ -37,13 +39,13 @@ def load_config(filename="registry.json"):
return Namespace(**config)
def get_cert(cur, prefix):
def get_cert(cur: Cursor, prefix: str):
res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None):
def insert_cert(cur: Cursor, ca: x509.Cert, prefix: str, not_after=None, email=None):
key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
......@@ -54,7 +56,7 @@ def insert_cert(cur, ca, prefix, not_after=None, email=None):
insert_cert.serial = 0
def delete_cert(cur, prefix):
def delete_cert(cur: Cursor, prefix: str):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
......
......@@ -92,18 +92,15 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
return key, cert
def prefix2cn(prefix):
def prefix2cn(prefix: str) -> str:
return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial):
def serial2prefix(serial: int) -> str:
return bin(serial)[2:].rjust(16, '0')
# pkey: private key
def decrypt(pkey, incontent):
with open("node.key", 'w') as f:
f.write(pkey.decode())
def decrypt(pkey: bytes, incontent: bytes) -> bytes:
with open("node.key", 'wb') as f:
f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split()
with subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
outcontent, err = p.communicate(incontent)
return outcontent
return subprocess.run(args, input=incontent, stdout=subprocess.PIPE).stdout
......@@ -2,8 +2,11 @@ import errno, json, logging, os, platform, random, socket
import subprocess, struct, sys, time, weakref
from collections import defaultdict, deque
from bisect import bisect, insort
from collections.abc import Iterator, Sequence
from typing import Callable
from OpenSSL import crypto
from . import ctl, plib, utils, version, x509
from . import cache, ctl, plib, utils, version, x509
PORT = 326
......@@ -21,7 +24,7 @@ proto_dict = {
proto_dict['tcp'] = proto_dict['tcp4']
proto_dict['udp'] = proto_dict['udp4']
def resolve(ip, port, proto):
def resolve(ip, port, proto: str) -> tuple[socket.AddressFamily | None, Iterator[str]]:
try:
family, proto = proto_dict[proto]
except KeyError:
......@@ -31,16 +34,16 @@ def resolve(ip, port, proto):
class MultiGatewayManager(dict):
def __init__(self, gateway):
def __init__(self, gateway: Callable[[str], str]):
self._gw = gateway
def _route(self, cmd, dest, gw):
def _route(self, cmd: str, dest: str, gw: str):
if gw:
cmd = 'ip', '-4', 'route', cmd, '%s/32' % dest, 'via', gw
logging.trace('%r', cmd)
subprocess.check_call(cmd)
def add(self, dest, route):
def add(self, dest: str, route: bool):
try:
self[dest][1] += 1
except KeyError:
......@@ -48,7 +51,7 @@ class MultiGatewayManager(dict):
self[dest] = [gw, 0]
self._route('add', dest, gw)
def remove(self, dest):
def remove(self, dest: str):
gw, count = self[dest]
if count:
self[dest][1] = count - 1
......@@ -198,7 +201,7 @@ class BaseTunnelManager:
_geoiplookup = None
_forward = None
def __init__(self, control_socket, cache, cert, conf_country, address=()):
def __init__(self, control_socket, cache: cache.Cache, cert: x509.Cert, conf_country, address=()):
self.cert = cert
self._network = cert.network
self._prefix = cert.prefix
......@@ -665,7 +668,7 @@ class TunnelManager(BaseTunnelManager):
def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, conf_country, address,
ip_changed, remote_gateway, disable_proto, neighbour_list=()):
ip_changed, remote_gateway: Callable[[str], str], disable_proto: Sequence[str], neighbour_list=()):
super(TunnelManager, self).__init__(control_socket,
cache, cert, conf_country, address)
self.ovpn_args = openvpn_args
......
......@@ -17,7 +17,7 @@ class Forwarder:
_lcg_n = 0
@classmethod
def _getExternalPort(cls):
def _getExternalPort(cls) -> int:
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
......@@ -35,7 +35,7 @@ class Forwarder:
self._u.discoverdelay = 200
self._rules = []
def __getattr__(self, name):
def __getattr__(self, name: str):
wrapped = getattr(self._u, name)
def wrapper(*args, **kw):
try:
......
import argparse, errno, fcntl, hashlib, logging, os, select as _select
import shlex, signal, socket, sqlite3, struct, subprocess
import sys, textwrap, threading, time, traceback
from typing import Optional
from collections.abc import Iterator, Mapping
from typing import Optional, Callable
HMAC_LEN = len(hashlib.sha1(b'').digest())
......@@ -40,7 +41,7 @@ class FileHandler(logging.FileHandler):
if self.lock.acquire(False):
self.release()
def setupLog(log_level, filename=None, **kw):
def setupLog(log_level: int, filename: str | None=None, **kw):
if log_level and filename:
makedirs(os.path.dirname(filename))
handler = FileHandler(filename)
......@@ -184,7 +185,7 @@ def setCloexec(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def select(R, W, T):
def select(R: Mapping, W: Mapping, T):
try:
r, w, _ = _select.select(R, W, (),
max(0, min(T)[0] - time.time()) if T else None)
......@@ -208,15 +209,15 @@ def makedirs(*args):
if e.errno != errno.EEXIST:
raise
def binFromIp(ip):
def binFromIp(ip: str) -> str:
return binFromRawIp(socket.inet_pton(socket.AF_INET6, ip))
def binFromRawIp(ip):
def binFromRawIp(ip: bytes) -> str:
ip1, ip2 = struct.unpack('>QQ', ip)
return bin(ip1)[2:].rjust(64, '0') + bin(ip2)[2:].rjust(64, '0')
def ipFromBin(ip, suffix=''):
def ipFromBin(ip: str, suffix='') -> str:
suffix_len = 128 - len(ip)
if suffix_len > 0:
ip += suffix.rjust(suffix_len, '0')
......@@ -225,11 +226,11 @@ def ipFromBin(ip, suffix=''):
return socket.inet_ntop(socket.AF_INET6,
struct.pack('>QQ', int(ip[:64], 2), int(ip[64:], 2)))
def dump_address(address):
def dump_address(address: str) -> str:
return ';'.join(map(','.join, address))
# Yield ip, port, protocol, and country if it is in the address
def parse_address(address_list):
def parse_address(address_list: str) -> Iterator[tuple[str, str, str, str]]:
for address in address_list.split(';'):
try:
a = address.split(',')
......@@ -239,17 +240,17 @@ def parse_address(address_list):
logging.warning("Failed to parse node address %r (%s)",
address, e)
def binFromSubnet(subnet):
def binFromSubnet(subnet: str) -> str:
p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret():
def _newHmacSecret() -> Callable[[Optional[int]], bytes]:
"""returns bytes"""
from random import getrandbits as g
pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
newHmacSecret = _newHmacSecret()
### Integer serialization
# - supports values from 0 to 0x202020202020201f
......
# -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from typing import Callable, Any
from OpenSSL import crypto
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
......@@ -9,29 +11,29 @@ from cryptography.x509 import load_pem_x509_certificate
from . import utils
from .version import protocol
def newHmacSecret():
def newHmacSecret() -> bytes:
return utils.newHmacSecret(int(time.time() * 1000000))
def networkFromCa(ca):
def networkFromCa(ca: crypto.X509) -> str:
# TODO: will be ca.serial_number after migration to cryptography
return bin(ca.get_serial_number())[3:]
def subnetFromCert(cert):
def subnetFromCert(cert: crypto.X509) -> str:
return cert.get_subject().CN
def notBefore(cert):
def notBefore(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert):
def notAfter(cert: crypto.X509) -> int:
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args, fds=[]):
def openssl(*args: str, fds=[]) -> utils.Popen:
return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE, pass_fds=fds)
def encrypt(cert, data: bytes) -> bytes:
def encrypt(cert: bytes, data: bytes) -> bytes:
assert isinstance(data, bytes)
r, w = os.pipe()
try:
......@@ -46,10 +48,10 @@ def encrypt(cert, data: bytes) -> bytes:
raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out
def fingerprint(cert, alg='sha1'):
def fingerprint(cert: crypto.X509, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew, force=False):
def maybe_renew(path: str, cert: crypto.X509, info: str, renew: Callable[[], Any], force=False) -> tuple[crypto.X509, int]:
from .registry import RENEW_PERIOD
while True:
if force:
......@@ -93,7 +95,7 @@ class NewSessionError(Exception):
class Cert:
def __init__(self, ca, key, cert=None):
def __init__(self, ca: str, key: str, cert: str | None=None):
self.ca_path = ca
self.cert_path = cert
self.key_path = key
......@@ -111,24 +113,24 @@ class Cert:
self.cert = self.loadVerify(f.read().encode())
@property
def prefix(self):
def prefix(self) -> str:
return utils.binFromSubnet(subnetFromCert(self.cert))
@property
def network(self):
def network(self) -> str:
return networkFromCa(self.ca)
@property
def subject_serial(self):
def subject_serial(self) -> int:
return int(self.cert.get_subject().serialNumber)
@property
def openvpn_args(self):
def openvpn_args(self) -> tuple[str, ...]:
return ('--ca', self.ca_path,
'--cert', self.cert_path,
'--key', self.key_path)
def maybeRenew(self, registry, crl):
def maybeRenew(self, registry, crl) -> int:
self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl)
......@@ -232,6 +234,7 @@ class Peer:
serial = None
stop_date = float('inf')
version = b''
cert: crypto.X509
def __init__(self, prefix: str):
self.prefix = prefix
......@@ -249,7 +252,7 @@ class Peer:
def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert):
def hello0(self, cert: crypto.X509) -> bytes:
if self._hello < time.time():
try:
# Always assume peer is not old, in case it has just upgraded,
......@@ -264,7 +267,7 @@ class Peer:
def hello0Sent(self):
self._hello = time.time() + 60
def hello(self, cert, protocol):
def hello(self, cert: Cert, protocol: int) -> bytes:
key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key)
......@@ -274,10 +277,10 @@ class Peer:
return b''.join((b'\0\0\0\2', PACKED_PROTOCOL if protocol else b'',
h, cert.sign(h)))
def _hmac(self, msg):
def _hmac(self, msg: bytes) -> bytes:
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key: bytes, protocol):
def newSession(self, key: bytes, protocol: int):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
......@@ -285,12 +288,12 @@ class Peer:
self._last = None
self.protocol = protocol
def verify(self, sign, data):
def verify(self, sign: bytes, data: bytes):
crypto.verify(self.cert, sign, data, 'sha512')
seqno_struct = struct.Struct("!L")
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> bytes:
def decode(self, msg: bytes, _unpack=seqno_struct.unpack) -> tuple[int, bytes, int | None] | bytes:
assert isinstance(msg, bytes)
seqno, = _unpack(msg[:4])
if seqno <= 2:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment