""" Authenticated communication: handshake (hello): C->S: CN S->C: X = Encrypt(CN)(secret), Sign(CA)(X) call: C->S: CN, ..., HMAC(secret+1)(path_info?query_string) S->C: result, HMAC(secret+2)(result) secret+1 = SHA1(secret) to protect from replay attacks HMAC in custom header, base64-encoded To prevent anyone from breaking an existing session, keep 2 secrets for each client: - the last one that was really used by the client (!hello) - the one of the last handshake (hello) """ import base64, hmac, hashlib, http.client, inspect, json, logging import mailbox, os, platform, random, select, smtplib, socket, sqlite3 import string, sys, threading, time, weakref, zlib from collections import defaultdict, deque from collections.abc import Iterator from datetime import datetime from http.server import HTTPServer, BaseHTTPRequestHandler from email.mime.text import MIMEText from operator import itemgetter from typing import Tuple from OpenSSL import crypto from urllib.parse import urlparse, unquote, urlencode from . import ctl, tunnel, utils, version, x509 HMAC_HEADER = "Re6stHMAC" RENEW_PERIOD = 30 * 86400 BABEL_HMAC = 'babel_hmac0', 'babel_hmac1', 'babel_hmac2' def rpc(f): argspec = inspect.getfullargspec(f) assert not (argspec.varargs or argspec.varkw), f sig = inspect.signature(f) sig = sig.replace(parameters=[v.replace(annotation=inspect.Parameter.empty) for v in sig.parameters.values()][1:], return_annotation=inspect.Signature.empty) f.getcallargs = eval("lambda %s: locals()" % str(sig)[1:-1]) return f def rpc_private(f): f._private = None return rpc(f) class HTTPError(Exception): pass class RegistryServer: peers = 0, () cert_duration = 365 * 86400 def _geoiplookup(self, ip): raise HTTPError(http.client.BAD_REQUEST) def __init__(self, config): self.config = config self.lock = threading.Lock() self.sessions = {} self.sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) # Parse community file self.community_map = {} if config.community: with open(config.community) as x: for x in x: x = x.strip() if x and not x.startswith('#'): x = x.split() self.community_map[x.pop(0)] = x if sum('*' in x for x in self.community_map.values()) != 1: sys.exit("Invalid community configuration: missing or multiple default location ('*')") else: self.community_map[''] = '*' # Database initializing db_dir = os.path.dirname(self.config.db) db_dir and utils.makedirs(db_dir) self.db = sqlite3.connect(self.config.db, isolation_level=None, check_same_thread=False) self.db.text_factory = str utils.sqliteCreateTable(self.db, "config", "name TEXT PRIMARY KEY NOT NULL", "value") self.prefix = self.getConfig("prefix", None) self.version = self.getConfig("version", b'\x00') utils.sqliteCreateTable(self.db, "token", "token TEXT PRIMARY KEY NOT NULL", "email TEXT NOT NULL", "prefix_len INTEGER NOT NULL", "date INTEGER NOT NULL") utils.sqliteCreateTable(self.db, "cert", "prefix TEXT PRIMARY KEY NOT NULL", "email TEXT", "cert TEXT") logging.debug("Checking for existing certs...") if not self.db.execute("SELECT 1 FROM cert LIMIT 1").fetchone(): logging.debug("No existing certs found, creating a blank one...") self.db.execute("INSERT INTO cert VALUES ('',null,null)") prev = '-' for community in sorted(self.community_map): if community.startswith(prev): err = "communities %s and %s overlap" % (prev, community) else: x = self.db.execute("SELECT prefix, cert FROM cert" " WHERE substr(?,1,length(prefix)) = prefix", (community,)).fetchone() if not x or x[1] is None: prev = community continue err = "prefix %s contains community %s" % (x[0], community) sys.exit("Invalid community configuration: " + err) utils.sqliteCreateTable(self.db, "crl", "serial INTEGER PRIMARY KEY NOT NULL", # Expiration date of revoked certificate. # TODO: purge rows with dates in the past. "date INTEGER NOT NULL") self.cert = x509.Cert(self.config.ca, self.config.key) # Get vpn network prefix self.network = self.cert.network logging.info("Network: %s/%u", utils.ipFromBin(self.network), len(self.network)) self.email = self.cert.ca.get_subject().emailAddress self.peers_lock = threading.Lock() self.ctl = ctl.Babel(os.path.join(config.run, 'babeld.sock'), weakref.proxy(self), self.network) self.geoip_db = os.getenv('GEOIP2_MMDB') if self.geoip_db: from geoip2 import database, errors country = database.Reader(self.geoip_db).country def geoiplookup(ip: str) -> Tuple[str, str]: try: req = country(ip) return req.country.iso_code, req.continent.code except (errors.AddressNotFoundError, ValueError): return '*', '*' self._geoiplookup = geoiplookup elif self.config.same_country: sys.exit("Can not respect 'same_country' network configuration" " (GEOIP2_MMDB not set)") self.onTimeout() if self.prefix: with self.db: self.updateNetworkConfig() else: self.newHMAC(0) def close(self): self.sock.close() self.db.close() self.ctl.close() def getConfig(self, name, *default): r, = next(self.db.execute( "SELECT value FROM config WHERE name=?", (name,)), default) return r def setConfig(self, *name_value): self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", name_value) def updateNetworkConfig(self, _it0=itemgetter(0)): kw = { 'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125', 'crl': list(map(_it0, self.db.execute( "SELECT serial FROM crl ORDER BY serial"))), 'protocol': version.protocol, 'registry_prefix': self.prefix, } if self.config.ipv4: kw['ipv4'], kw['ipv4_sublen'] = self.config.ipv4 if self.config.same_country: kw['same_country'] = self.config.same_country for x in ('client_count', 'encrypt', 'hello', 'max_clients', 'min_protocol', 'tunnel_refresh'): kw[x] = getattr(self.config, x) config = json.dumps(kw, sort_keys=True) if config != self.getConfig('last_config', None): self.increaseVersion() self.setConfig('version', self.version) self.setConfig('last_config', config) self.sendto(self.prefix, 0) # The following entry lists values that are base64-encoded. kw[''] = 'version', kw['version'] = base64.b64encode(self.version).decode("ascii") self.network_config = kw def increaseVersion(self): x = utils.packInteger(1 + utils.unpackInteger(self.version)[0]) self.version = x + self.cert.sign(x) def sendto(self, prefix: str, code: int): self.sock.sendto(prefix.encode() + bytes((0, code)), ('::1', tunnel.PORT)) def recv(self, code: int) -> (str, str): try: prefix, msg = self.sock.recv(1 << 16).split(b'\x00', 1) int(prefix, 2) except ValueError: pass else: if msg: if msg[0:1] == bytes([code]): return prefix.decode(), msg[1:].decode() else: logging.error("Unexpected code: %r", msg) else: logging.error("Empty message") return None, None def select(self, r, w, t): if self.timeout: t.append((self.timeout, self.onTimeout)) def request_dump(self): assert self.peers_lock.locked() def abort(): raise ctl.BabelException self._wait_dump = True for _ in 0, 1: self.ctl.request_dump() try: while self._wait_dump: args = {}, {}, ((time.time() + 5, abort),) self.ctl.select(*args) utils.select(*args) break except ctl.BabelException: self.ctl.reset() def babel_dump(self): self._wait_dump = False def iterCert(self) -> Iterator[Tuple[crypto.X509, str, str]]: for prefix, email, cert in self.db.execute( "SELECT * FROM cert WHERE cert IS NOT NULL"): try: yield (crypto.load_certificate(crypto.FILETYPE_PEM, cert), prefix, email) except crypto.Error: pass def onTimeout(self): # XXX: Because we use threads to process requests, the statements # 'self.timeout = 1' below have no effect as long as the # 'select' call does not return. Ideally, we should interrupt it. logging.info("Checking if there's any old entry in the database ...") not_after = None old = time.time() - self.config.grace_period q = self.db.execute with self.lock, self.db: q("BEGIN") for token, x in q("SELECT token, date FROM token"): if x <= old: q("DELETE FROM token WHERE token=?", (token,)) elif not_after is None or x < not_after: not_after = x for cert, prefix, email in self.iterCert(): x = x509.notAfter(cert) if x <= old: if prefix == self.prefix: logging.critical("Refuse to delete certificate" " of main node: wrong clock ?" " Alternatively, the database might be in an inconsistent state.") sys.exit(1) logging.info("Delete %s: %s (invalid since %s)", "certificate requested by '%s'" % email if email else "anonymous certificate", ", ".join("%s=%s" % x for x in cert.get_subject().get_components()), datetime.utcfromtimestamp(x).isoformat()) q("UPDATE cert SET email=null, cert=null WHERE prefix=?", (prefix,)) elif not_after is None or x < not_after: not_after = x self.mergePrefixes() self.timeout = not_after and not_after + self.config.grace_period def handle_request(self, request, method, kw): m = getattr(self, method) if hasattr(m, '_private'): authorized_origin = self.config.authorized_origin x_forwarded_for = request.headers.get('X-Forwarded-For') if request.client_address[0] not in authorized_origin or \ x_forwarded_for and x_forwarded_for not in authorized_origin: return request.send_error(http.client.FORBIDDEN) key = m.getcallargs(**kw).get('cn') if key: h = base64.b64decode(request.headers[HMAC_HEADER]) with self.lock: session = self.sessions[key] for key, protocol in session: if h == hmac.HMAC(key, request.path.encode(), hashlib.sha1).digest(): break else: raise Exception("Wrong HMAC") key = hashlib.sha1(key).digest() session[:] = (hashlib.sha1(key).digest(), protocol), else: logging.info("%s%s: %s, %s", method, '(' + utils.ipFromBin(x509.networkFromCa(self.cert.ca) + kw["client_prefix"]) + ')' if method == 'hello' else '', request.headers.get("X-Forwarded-For") or request.headers.get("host"), request.headers.get("user-agent")) if not kw.get('ip', True): kw['ip'] = (request.headers.get("X-Forwarded-For", "").split(',',1)[0].strip() or request.headers.get("host")) try: result = m(**kw) except HTTPError as e: return request.send_error(*e.args) except: logging.warning(request.requestline, exc_info=True) return request.send_error(http.client.INTERNAL_SERVER_ERROR) if result: if type(result) is str: result = result.encode("utf-8") request.send_response(http.client.OK) request.send_header("Content-Length", str(len(result))) else: request.send_response(http.client.NO_CONTENT) if key: request.send_header(HMAC_HEADER, base64.b64encode( hmac.HMAC(key, result, hashlib.sha1).digest()).decode("ascii")) request.end_headers() if result: request.wfile.write(result) def getPeerProtocol(self, cn): session, = self.sessions[cn] return session[1] @rpc def hello(self, client_prefix, protocol='1'): with self.lock: cert = self.getCert(client_prefix) key = utils.newHmacSecret() self.sessions.setdefault(client_prefix, [])[1:] = (key, int(protocol)), key = x509.encrypt(cert, key) sign = self.cert.sign(key) assert len(key) == len(sign) return key + sign def getCert(self, client_prefix: str) -> bytes: assert self.lock.locked() cert = self.db.execute("SELECT cert FROM cert" " WHERE prefix=? AND cert IS NOT NULL", (client_prefix,)).fetchone() assert cert, "Certificate query did not return any result; "\ "this indicates inconsistent state and should not happen" return cert[0] @rpc_private def isToken(self, token: str): with self.lock: if self.db.execute("SELECT 1 FROM token WHERE token = ?", (token,)).fetchone(): return b"1" @rpc_private def deleteToken(self, token: str): with self.lock: self.db.execute("DELETE FROM token WHERE token = ?", (token,)) @rpc_private def addToken(self, email: str, token: str | None) -> str: prefix_len = self.config.prefix_length if not prefix_len: raise HTTPError(http.client.FORBIDDEN) request = token is None with self.lock: while True: # Generating token if request: token = ''.join(random.sample(string.ascii_lowercase, 8)) args = token, email, prefix_len, int(time.time()) # Updating database try: self.db.execute("INSERT INTO token VALUES (?,?,?,?)", args) break except sqlite3.IntegrityError: if not request: raise HTTPError(http.client.CONFLICT) self.timeout = 1 if request: return token @rpc def requestToken(self, email): if not self.config.mailhost: raise HTTPError(http.client.FORBIDDEN) token = self.addToken(email, None) # Creating and sending email msg = MIMEText('Hello, your token to join re6st network is: %s\n' % token) msg['Subject'] = '[re6stnet] Token Request' if self.email: msg['From'] = self.email msg['To'] = email if os.path.isabs(self.config.mailhost) or \ os.path.isfile(self.config.mailhost): with self.lock: m = mailbox.mbox(self.config.mailhost) try: m.add(msg) finally: m.close() else: s = smtplib.SMTP(self.config.mailhost) if self.config.smtp_starttls: s.starttls() if self.config.smtp_user: s.login(self.config.smtp_user, self.config.smtp_pwd) s.sendmail(self.email, email, msg.as_string()) s.quit() def getCommunity(self, country, continent): for prefix, location_list in self.community_map.items(): if country in location_list: return prefix default = '' for prefix, location_list in self.community_map.items(): if continent in location_list: return prefix if '*' in location_list: default = prefix return default def mergePrefixes(self): logging.debug("Merging prefixes") q = self.db.execute prev_prefix = None max_len = 128, while True: max_len = q("SELECT max(length(prefix)) FROM cert" " WHERE cert is null AND length(prefix) < ?", max_len).fetchone() if not max_len[0]: break for prefix, in q("SELECT prefix FROM cert" " WHERE cert is null AND length(prefix) = ?" " ORDER BY prefix", max_len): if prev_prefix and prefix[:-1] == prev_prefix[:-1]: q("UPDATE cert SET prefix = ? WHERE prefix = ?", (prefix[:-1], prev_prefix)) q("DELETE FROM cert WHERE prefix = ?", (prefix,)) prev_prefix = None else: prev_prefix = prefix def newPrefix(self, prefix_len, community): logging.info("Allocating /%u prefix for %s", prefix_len, community) community_len = len(community) prefix_len += community_len max_len = 128 - len(self.network) assert 0 < prefix_len <= max_len q = self.db.execute while True: try: # Find longest free prefix whithin community. prefix, = next(q( "SELECT prefix FROM cert" " WHERE prefix LIKE ?" " AND length(prefix) <= ? AND cert is null" " ORDER BY length(prefix) DESC", (community + '%', prefix_len))) except StopIteration: # Community not yet allocated? # There should be exactly 1 row whose # prefix is the beginning of community. prefix, x = next(q("SELECT prefix, cert FROM cert" " WHERE substr(?,1,length(prefix)) = prefix", (community,))) if x is not None: logging.error('No more free /%u prefix available', prefix_len) raise # Split the tree until prefix has wanted length. for x in range(len(prefix), prefix_len): # Prefix starts with community, then we complete with 0. x = community[x] if x < community_len else '0' q("UPDATE cert SET prefix = ? WHERE prefix = ?", (prefix + str(1-int(x)), prefix)) prefix += x q("INSERT INTO cert VALUES (?,null,null)", (prefix,)) if len(prefix) < max_len or '1' in prefix[community_len:]: return prefix q("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,)) @rpc def requestCertificate(self, token: str | None, req: bytes, location: str='', ip: str=''): logging.debug("Requesting certificate with token %s", token) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req) with self.lock: with self.db: if token: if not self.config.prefix_length: raise HTTPError(http.client.FORBIDDEN) try: token, email, prefix_len, _ = next(self.db.execute( "SELECT * FROM token WHERE token = ?", (token,))) except StopIteration: return self.db.execute("DELETE FROM token WHERE token = ?", (token,)) else: prefix_len = self.config.anonymous_prefix_length if not prefix_len: raise HTTPError(http.client.FORBIDDEN) email = None country, continent = '*', '*' if self.geoip_db: country, continent = location.split(',') if location else self._geoiplookup(ip) if continent != '*': continent = '@' + continent prefix = self.newPrefix(prefix_len, self.getCommunity(country, continent)) self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?", (email, prefix)) if self.prefix is None: self.prefix = prefix self.setConfig('prefix', prefix) self.updateNetworkConfig() subject = req.get_subject() subject.serialNumber = str(self.getSubjectSerial()) return self.createCertificate(prefix, subject, req.get_pubkey()) def getSubjectSerial(self): # Smallest unique number, for IPv4 support. serials = [] for x in self.iterCert(): serial = x[0].get_subject().serialNumber if serial: serials.append(int(serial)) serials.sort() for serial, x in enumerate(serials): if serial != x: return serial return len(serials) def createCertificate(self, client_prefix, subject, pubkey, not_after=None): cert = crypto.X509() cert.gmtime_adj_notBefore(0) if not_after: cert.set_notAfter(not_after) else: cert.gmtime_adj_notAfter(self.cert_duration) cert.set_issuer(self.cert.ca.get_subject()) subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix)) cert.set_subject(subject) cert.set_pubkey(pubkey) # Certificate serial, for revocation support. Contrary to # subject serial, it does not need to be as small as possible. serial = 1 + self.getConfig('serial', 0) self.setConfig('serial', serial) cert.set_serial_number(serial) cert.sign(self.cert.key, 'sha512') cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?", (cert, client_prefix)) self.timeout = 1 return cert @rpc def renewCertificate(self, cn: str) -> bytes: with self.lock: with self.db as db: pem = self.getCert(cn) cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem) if x509.notAfter(cert) - RENEW_PERIOD < time.time(): not_after = None elif db.execute("SELECT count(*) FROM crl WHERE serial=?", (cert.get_serial_number(),)).fetchone()[0]: not_after = cert.get_notAfter() else: return pem return self.createCertificate(cn, cert.get_subject(), cert.get_pubkey(), not_after) @rpc def getCa(self) -> bytes: return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca) @rpc def getDh(self, cn: str) -> bytes: with open(self.config.dh, "rb") as f: return f.read() @rpc def getNetworkConfig(self, cn: str) -> bytes: with self.lock: cert = self.getCert(cn) config = self.network_config.copy() hmac = [self.getConfig(k, None) for k in BABEL_HMAC] for i, v in enumerate(v for v in hmac if v is not None): config[('babel_hmac_sign', 'babel_hmac_accept')[i]] = \ v and base64.b64encode(x509.encrypt(cert, v)).decode("ascii") return zlib.compress(json.dumps(config).encode("utf-8")) def _queryAddress(self, peer: str) -> str: logging.info("Querying address for %s/%s %r", int(peer, 2), len(peer), peer) self.sendto(peer, 1) s = self.sock, timeout = 3 end = timeout + time.time() # Loop because there may be answers from previous requests. while select.select(s, (), (), timeout)[0]: prefix, msg = self.recv(1) logging.info("* received: %r - %r", prefix, msg) if prefix == peer: return msg timeout = max(0, end - time.time()) logging.info("Timeout while querying address for %s/%s", int(peer, 2), len(peer)) @rpc def getCountry(self, cn: str, address: str) -> bytes | None: country = self._geoiplookup(address)[0] return None if country == '*' else country.encode() @rpc def getBootstrapPeer(self, cn: str) -> bytes | None: logging.info("Answering bootstrap peer for %s", cn) with self.peers_lock: age, peers = self.peers if age < time.time() or not peers: self.request_dump() peers = [prefix for neigh_routes in self.ctl.neighbours.values() for prefix in neigh_routes[1] if prefix] peers.append(self.prefix) random.shuffle(peers) self.peers = time.time() + 60, peers logging.debug("peers: %r", peers) peer = peers.pop() if peer == cn: # Very unlikely (e.g. peer restarted with empty cache), # so don't bother looping over above code # (in case 'peers' is empty). peer = self.prefix with self.lock: msg = self._queryAddress(peer) if msg is None: logging.info("No address for %s, returning None", peer) return # Remove country for old nodes if self.getPeerProtocol(cn) < 7: msg = ';'.join(','.join(a.split(',')[:3]) for a in msg.split(';')) cert = self.getCert(cn) msg = "%s %s" % (peer, msg) logging.info("Sending bootstrap peer: %s", msg) return x509.encrypt(cert, msg.encode()) @rpc_private def revoke(self, cn_or_serial: int | str): with self.lock, self.db: q = self.db.execute try: serial = int(cn_or_serial) except ValueError: prefix = utils.binFromSubnet(cn_or_serial) cert = self.getCert(prefix) q("UPDATE cert SET email=null, cert=null WHERE prefix=?", (prefix,)) cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) serial = cert.get_serial_number() self.sessions.pop(prefix, None) else: cert, = (cert for cert, prefix, email in self.iterCert() if cert.get_serial_number() == serial) not_after = x509.notAfter(cert) if time.time() < not_after: q("INSERT INTO crl VALUES (?,?)", (serial, not_after)) self.updateNetworkConfig() def newHMAC(self, i: int, key: bytes=None): if key is None: key = os.urandom(16) self.setConfig(BABEL_HMAC[i], key) def delHMAC(self, i: int): self.db.execute("DELETE FROM config WHERE name=?", (BABEL_HMAC[i],)) @rpc_private def updateHMAC(self): with self.lock, self.db: hmac = [self.getConfig(BABEL_HMAC[i], None) for i in (0,1,2)] if hmac[0]: if hmac[1]: self.newHMAC(2, hmac[0]) self.delHMAC(0) else: self.newHMAC(1) elif hmac[1]: self.newHMAC(0, hmac[1]) self.delHMAC(1) self.delHMAC(2) else: # Initialization of HMAC on the network self.newHMAC(1) self.newHMAC(2, b'') self.increaseVersion() self.setConfig('version', self.version) self.network_config['version'] = base64.b64encode(self.version) self.sendto(self.prefix, 0) @rpc_private def getNodePrefix(self, email: str) -> str | None: with self.lock, self.db: try: cert, = next(self.db.execute("SELECT cert FROM cert WHERE email = ?", (email,))) except StopIteration: return certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert) return x509.subnetFromCert(certificate) @rpc_private def getIPv6Address(self, email: str) -> str: cn = self.getNodePrefix(email) if cn: return utils.ipFromBin( x509.networkFromCa(self.cert.ca) + utils.binFromSubnet(cn)) @rpc_private def getIPv4Information(self, email: str) -> bytes | None: peer = self.getNodePrefix(email) if peer: peer = utils.binFromSubnet(peer) with self.peers_lock: self.request_dump() for neigh_routes in self.ctl.neighbours.values(): for prefix in neigh_routes[1]: if prefix == peer: break else: return logging.info("%s %s", email, peer) with self.lock: msg = self._queryAddress(peer).decode() if msg: return msg.split(',')[0].encode() @rpc_private def versions(self) -> str: with self.peers_lock: self.request_dump() peers = {prefix for neigh_routes in self.ctl.neighbours.values() for prefix in neigh_routes[1] if prefix} peers.add(self.prefix) peer_dict = {} s = self.sock, with self.lock: while True: r, w, _ = select.select(s, s if peers else (), (), 3) if r: prefix, ver = self.recv(4) if prefix: peer_dict[prefix] = ver if w: prefix = peers.pop() peer_dict[prefix] = None self.sendto(prefix, 4) elif not r: break return json.dumps(peer_dict) @rpc_private def topology(self) -> str: logging.info("Computing topology") p = lambda p: '%s/%s' % (int(p, 2), len(p)) peers = deque((p(self.prefix),)) graph = defaultdict(set) s = self.sock, with self.lock: while True: r, w, _ = select.select(s, s if peers else (), (), 3) if r: prefix, x = self.recv(5) logging.info("Received %s %s", prefix, x) if prefix and x: prefix = p(prefix) x = x.split() try: n = int(x.pop(0)) except ValueError: continue if n <= len(x) and prefix not in x: graph[prefix].update(x[:n]) peers += set(x).difference(graph) for x in x[n:]: graph[x].add(prefix) graph[''].add(prefix) if w: first = peers.popleft() logging.info("Sending %s", first) self.sendto(utils.binFromSubnet(first), 5) elif not r: logging.info("No more sockets, stopping") break return json.dumps({k: list(v) for k, v in graph.items()}) class RegistryClient: _hmac = None user_agent = "re6stnet/%s, %s" % (version.version, platform.platform()) def __init__(self, url: str, cert: x509.Cert=None, auto_close=True): self.cert = cert self.auto_close = auto_close url_parsed = urlparse(url) scheme, host, path = url_parsed.scheme, url_parsed.netloc, url_parsed.path self._conn = dict(http=http.client.HTTPConnection, https=http.client.HTTPSConnection, )[scheme](unquote(host), timeout=60) self._path = path.rstrip('/') def __getattr__(self, name: str): getcallargs = getattr(RegistryServer, name).getcallargs def rpc(*args, **kw) -> bytes: kw = getcallargs(*args, **kw) query = '/' + name if kw: if any(not isinstance(v, (str, bytes)) for v in kw.values()): raise TypeError(kw) query += '?' + urlencode(kw) url = self._path + query client_prefix = kw.get('cn') retry = True try: while retry: if client_prefix: key = self._hmac if not key: retry = False h = self.hello(client_prefix, str(version.protocol)) n = len(h) // 2 self.cert.verify(h[n:], h[:n]) key = self.cert.decrypt(h[:n]) h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest() key = hashlib.sha1(key).digest() self._hmac = hashlib.sha1(key).digest() else: retry = False self._conn.putrequest('GET', url, skip_accept_encoding=1) self._conn.putheader('User-Agent', self.user_agent) if client_prefix: self._conn.putheader(HMAC_HEADER, base64.b64encode(h)) self._conn.endheaders() response = self._conn.getresponse() body = response.read() #print(query, repr(body)) if response.status in (http.client.OK, http.client.NO_CONTENT): if (not client_prefix or hmac.HMAC(key, body, hashlib.sha1).digest() == base64.b64decode(response.msg[HMAC_HEADER])): if self.auto_close and name != 'hello': self._conn.close() return body elif response.status == http.client.FORBIDDEN: # XXX: We should improve error handling, while making # sure re6st nodes don't crash on temporary errors. # This is currently good enough for re6st-conf, to # inform the user when registration is disabled. raise HTTPError(response.status, response.reason) if client_prefix: self._hmac = None except HTTPError: raise except Exception: logging.info(url, exc_info=True) else: logging.info('%s\nUnexpected response %s %s', url, response.status, response.reason) self._conn.close() setattr(self, name, rpc) return rpc