Commit 9c772060 authored by Vincent Pelletier's avatar Vincent Pelletier

utils: Genericise getCertList and saveCertList.

So they can be reused for more PEM-encoded types.
parent 05ca7a95
...@@ -2917,13 +2917,17 @@ class CaucaseTest(unittest.TestCase): ...@@ -2917,13 +2917,17 @@ class CaucaseTest(unittest.TestCase):
self.assertTrue(os.path.exists(self._client_ca_crt)) self.assertTrue(os.path.exists(self._client_ca_crt))
self.assertTrue(os.path.isfile(self._client_ca_crt)) self.assertTrue(os.path.isfile(self._client_ca_crt))
self.assertItemsEqual(utils.getCertList(self._client_ca_crt), [crt0_pem]) self.assertItemsEqual(utils.getCertList(self._client_ca_crt), [crt0_pem])
# Invalid file gets deleted # Invalid file gets deleted only if it has expected extension (.ca.pem)
dummy_file_path = os.path.join(self._client_ca_dir, 'not_a_pem') kept_file_path = os.path.join(self._client_ca_dir, 'not_a_ca.pem')
with open(dummy_file_path, 'wb') as dummy: deleted_file_path = os.path.join(self._client_ca_dir, 'foo.ca.pem')
with open(kept_file_path, 'wb'), open(deleted_file_path, 'wb'):
pass pass
self.assertTrue(os.path.exists(dummy_file_path)) self.assertTrue(os.path.exists(kept_file_path))
self.assertTrue(os.path.exists(deleted_file_path))
utils.saveCertList(self._client_ca_dir, [crt0_pem]) utils.saveCertList(self._client_ca_dir, [crt0_pem])
self.assertFalse(os.path.exists(dummy_file_path)) self.assertTrue(os.path.exists(kept_file_path))
self.assertFalse(os.path.exists(deleted_file_path))
os.unlink(kept_file_path)
# Storing and loading multiple certificates # Storing and loading multiple certificates
utils.saveCertList(self._client_ca_dir, [crt0_pem, crt1_pem]) utils.saveCertList(self._client_ca_dir, [crt0_pem, crt1_pem])
crta, crtb = os.listdir(self._client_ca_dir) crta, crtb = os.listdir(self._client_ca_dir)
......
...@@ -24,7 +24,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices ...@@ -24,7 +24,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Small-ish functions needed in many places. Small-ish functions needed in many places.
""" """
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
from binascii import a2b_base64, b2a_base64 from binascii import a2b_base64, b2a_base64, hexlify
import calendar import calendar
import codecs import codecs
from collections import defaultdict from collections import defaultdict
...@@ -124,26 +124,25 @@ def _getPEMTypeDict(path, result=None): ...@@ -124,26 +124,25 @@ def _getPEMTypeDict(path, result=None):
def getCertList(crt_path): def getCertList(crt_path):
""" """
Return a list of certificates. Return a list of certificates.
Raises if there is anything else than a certificate.
""" """
if not os.path.exists(crt_path): return _getPEMListFromPath(crt_path, pem.Certificate)
def _getPEMListFromPath(path, pem_type):
if not os.path.exists(path):
return [] return []
if os.path.isdir(crt_path): return [
file_list = [os.path.join(crt_path, x) for x in os.listdir(crt_path)] pem_object.as_bytes()
else: for file_name in (
file_list = [crt_path] [os.path.join(path, x) for x in os.listdir(path)]
result = [] if os.path.isdir(path) else
for file_name in file_list: [path]
type_dict = _getPEMTypeDict(file_name) )
crt_list = type_dict.pop(pem.Certificate) for pem_object in _getPEMTypeDict(file_name).get(pem_type, ())
if type_dict: ]
raise ValueError('%s contains more than just certificates' % (file_name, ))
result.extend(x.as_bytes() for x in crt_list)
return result
def saveCertList(crt_path, cert_pem_list): def saveCertList(crt_path, cert_pem_list):
""" """
Store given list of PEm-encoded certificates in given path. Store given list of PEM-encoded certificates in given path.
crt_path (str) crt_path (str)
May point to a directory a file, or nothing. May point to a directory a file, or nothing.
...@@ -154,61 +153,70 @@ def saveCertList(crt_path, cert_pem_list): ...@@ -154,61 +153,70 @@ def saveCertList(crt_path, cert_pem_list):
cert_pem_list (list of bytes) cert_pem_list (list of bytes)
""" """
if os.path.exists(crt_path): _savePEMList(crt_path, cert_pem_list, load_ca_certificate, '.ca.pem')
if os.path.isfile(crt_path):
saveCertListTo = _saveCertListToFile def _savePEMList(path, pem_list, pem_loader, extension):
elif os.path.isdir(crt_path): if os.path.exists(path):
saveCertListTo = _saveCertListToDirectory if os.path.isfile(path):
savePEMList = _savePEMListToFile
elif os.path.isdir(path):
savePEMList = _savePEMListToDirectory
else: else:
raise TypeError('%s exist and is neither a directory nor a file' % ( raise TypeError('%s exist and is neither a directory nor a file' % (
crt_path, path,
)) ))
else: else:
saveCertListTo = ( savePEMList = (
_saveCertListToFile _savePEMListToFile
if os.path.splitext(crt_path)[1] else if os.path.splitext(path)[1] else
_saveCertListToDirectory _savePEMListToDirectory
) )
saveCertListTo(crt_path, cert_pem_list) savePEMList(path, pem_list, pem_loader, extension)
def _saveCertListToFile(ca_crt_path, cert_pem_list): def _savePEMListToFile(file_path, pem_list, pem_loader, extension):
with open(ca_crt_path, 'wb') as ca_crt_file: _ = pem_loader # Silence pylint
ca_crt_file.write(b''.join(cert_pem_list)) _ = extension # Silence pylint
with open(file_path, 'wb') as pem_file:
def _saveCertListToDirectory(crt_dir, cert_pem_list): for pem_chunk in pem_list:
if not os.path.exists(crt_dir): pem_file.write(pem_chunk)
os.mkdir(crt_dir)
ca_cert_dict = { def _savePEMListToDirectory(dir_path, pem_list, pem_loader, extension):
'%x.pem' % (load_ca_certificate(x).serial_number, ): x if not os.path.exists(dir_path):
for x in cert_pem_list os.mkdir(dir_path)
pem_dict = {
hexlify(
pem_loader(x).extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier,
).value.key_identifier,
).decode('ascii') + extension: x
for x in pem_list
} }
for cert_filename in os.listdir(crt_dir): for filename in os.listdir(dir_path):
ca_crt_path = os.path.join(crt_dir, cert_filename) filepath = os.path.join(dir_path, filename)
if not os.path.isfile(ca_crt_path): if not filepath.endswith(extension) or not os.path.isfile(filepath):
# Not a file and not a symlink to a file, ignore # Not a managed file name and not a symlink to a file, ignore
continue continue
if not os.path.islink(ca_crt_path) and cert_filename in ca_cert_dict: if not os.path.islink(filepath) and filename in pem_dict:
try: try:
# pylint: disable=unbalanced-tuple-unpacking # pylint: disable=unbalanced-tuple-unpacking
cert, = getCertList(ca_crt_path) file_pem_item, = _getPEMTypeDict(filepath).itervalues()
# pylint: enable=unbalanced-tuple-unpacking # pylint: enable=unbalanced-tuple-unpacking
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
# pylint: enable=broad-except # pylint: enable=broad-except
# Inconsistent content (multiple certificates, not CA certificates, # File contains multiple PEM items: overwrite
# ...): overwrite file
pass pass
else: else:
if cert == ca_cert_dict[cert_filename]: if file_pem_item == pem_dict[filename]:
# Already consistent, do not edit. # Already consistent, do not edit.
del ca_cert_dict[cert_filename] del pem_dict[filename]
else: else:
# Unknown file (ex: expired certificate), or a symlink to a file: delete # Unknown file (ex: expired certificate), or a symlink to a file: delete
os.unlink(ca_crt_path) os.unlink(filepath)
for cert_filename, cert_pem in ca_cert_dict.items(): for filename, pem_item in pem_dict.iteritems():
ca_crt_path = os.path.join(crt_dir, cert_filename) filepath = os.path.join(dir_path, filename)
with open(ca_crt_path, 'wb') as ca_crt_file: with open(filepath, 'wb') as pem_file:
ca_crt_file.write(cert_pem) pem_file.write(pem_item)
def getCert(crt_path): def getCert(crt_path):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment