Commit 5fe1e86b authored by Vincent Pelletier's avatar Vincent Pelletier

all: Keep track of certificate issuances.

And use this tracking to to warn about surviving certificates which are
related to the one just revoked - they may need some attention too.

NOTE: While this should be correctly implemented, I think this is not
usable, and hence probably not worth the extra complexity: what can one
do when given a list of serials ? This version discards old tracking
entries, but even if it did not how is one supposed to browse these ?
parent 228e01d7
...@@ -38,7 +38,9 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes ...@@ -38,7 +38,9 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from . import utils from . import utils
from .exceptions import ( from .exceptions import (
CertificateVerificationError, CertificateVerificationError,
CertificateRevokedError,
NotACertificateSigningRequest, NotACertificateSigningRequest,
Found,
) )
__all__ = ('CertificateAuthority', 'UserCertificateAuthority', 'Extension') __all__ = ('CertificateAuthority', 'UserCertificateAuthority', 'Extension')
...@@ -245,7 +247,12 @@ class CertificateAuthority(object): ...@@ -245,7 +247,12 @@ class CertificateAuthority(object):
if requested_amount is not None and \ if requested_amount is not None and \
requested_amount <= self._auto_sign_csr_amount: requested_amount <= self._auto_sign_csr_amount:
# if allowed to sign this certificate automaticaly # if allowed to sign this certificate automaticaly
self._createCertificate(csr_id, auto_signed=_AUTO_SIGNED_YES) self._createCertificate(
csr_id,
auto_signed=_AUTO_SIGNED_YES,
is_renewal=False,
authorisation_serial=None,
)
return csr_id return csr_id
def deletePendingCertificateSigningRequest(self, csr_id): def deletePendingCertificateSigningRequest(self, csr_id):
...@@ -264,7 +271,7 @@ class CertificateAuthority(object): ...@@ -264,7 +271,7 @@ class CertificateAuthority(object):
""" """
return self._storage.getCertificateSigningRequestList() return self._storage.getCertificateSigningRequestList()
def createCertificate(self, csr_id, template_csr=None): def createCertificate(self, csr_id, template_csr=None, authorisation_serial=None):
""" """
Sign a pending certificate signing request, storing produced certificate. Sign a pending certificate signing request, storing produced certificate.
...@@ -274,39 +281,97 @@ class CertificateAuthority(object): ...@@ -274,39 +281,97 @@ class CertificateAuthority(object):
Copy extensions and subject from this CSR instead of stored one. Copy extensions and subject from this CSR instead of stored one.
Useful to renew a certificate. Useful to renew a certificate.
Public key is always copied from stored CSR. Public key is always copied from stored CSR.
authorisation_serial (int, None)
Serial of the certificate which authorised certificate issuance.
""" """
self._createCertificate( self._createCertificate(
csr_id=csr_id, csr_id=csr_id,
auto_signed=_AUTO_SIGNED_NO, auto_signed=_AUTO_SIGNED_NO,
template_csr=template_csr, template_csr=template_csr,
is_renewal=False,
authorisation_serial=authorisation_serial,
) )
def _createCertificate(self, csr_id, auto_signed, template_csr=None): def _createCertificate(
self,
csr_id,
auto_signed,
is_renewal,
authorisation_serial,
template_csr=None,
):
""" """
auto_signed (bool) auto_signed (bool)
When True, mark certificate as having been auto-signed. When True, mark certificate as having been auto-signed.
When False, prevent such mark from being set. When False, prevent such mark from being set.
When None, do not filter (useful when renewing). When None, do not filter (useful when renewing).
is_renewal (bool)
True to signal a renewal, False otherwise.
authorisation_serial (int, None)
If non-None, the serial of the certificate which authorised certificate
issuance:
- If is_renewal is True, this is the serial of the certificate being
renewed.
- Otherwise, this is the serial of user certificate who triggered the
renewal.
If None, this is an auto-issued certificate.
tempate_csr (None or X509Req) tempate_csr (None or X509Req)
Copy extensions and subject from this CSR instead of stored one. Copy extensions and subject from this CSR instead of stored one.
Useful to renew a certificate. Useful to renew a certificate.
Public key is always copied from stored CSR. Public key is always copied from stored CSR.
""" """
csr_pem = self._storage.getCertificateSigningRequest(csr_id) not_valid_before = datetime.datetime.utcnow()
csr = utils.load_certificate_request(csr_pem) not_valid_after = not_valid_before + self._crt_life_time
csr = utils.load_certificate_request(
self._storage.getCertificateSigningRequest(csr_id),
)
# Note: this is quite unlikely to loop even once, as
# x509.random_serial_number produces a 160-bits random number.
while True:
serial_number = x509.random_serial_number()
try:
with self._storage.trackIssuance(
is_renewal=is_renewal,
authorisation_serial=authorisation_serial,
serial=serial_number,
not_valid_after=utils.datetime2timestamp(not_valid_after),
) as storeCertificate:
return self.__createCertificate(
csr_id=csr_id,
csr=csr,
auto_signed=auto_signed,
template_csr=template_csr,
serial_number=serial_number,
not_valid_before=not_valid_before,
not_valid_after=not_valid_after,
storeCertificate=storeCertificate,
)
except Found: # pragma: no cover
pass
def __createCertificate(
self,
csr_id,
csr,
auto_signed,
template_csr,
serial_number,
not_valid_before,
not_valid_after,
storeCertificate,
):
if template_csr is None: if template_csr is None:
template_csr = csr template_csr = csr
ca_key_pair = self._getCurrentCAKeypair() ca_key_pair = self._getCurrentCAKeypair()
ca_crt = ca_key_pair['crt'] ca_crt = ca_key_pair['crt']
public_key = csr.public_key() public_key = csr.public_key()
now = datetime.datetime.utcnow()
builder = x509.CertificateBuilder( builder = x509.CertificateBuilder(
subject_name=template_csr.subject, subject_name=template_csr.subject,
issuer_name=ca_crt.subject, issuer_name=ca_crt.subject,
not_valid_before=now, not_valid_before=not_valid_before,
not_valid_after=now + self._crt_life_time, not_valid_after=not_valid_after,
serial_number=x509.random_serial_number(), serial_number=serial_number,
public_key=public_key, public_key=public_key,
extensions=[ extensions=[
Extension( Extension(
...@@ -459,7 +524,7 @@ class CertificateAuthority(object): ...@@ -459,7 +524,7 @@ class CertificateAuthority(object):
algorithm=self._default_digest_class(), algorithm=self._default_digest_class(),
backend=_cryptography_backend, backend=_cryptography_backend,
)) ))
self._storage.storeCertificate(csr_id, cert_pem) storeCertificate(csr_id, cert_pem)
return cert_pem return cert_pem
def getCertificate(self, csr_id): def getCertificate(self, csr_id):
...@@ -662,6 +727,7 @@ class CertificateAuthority(object): ...@@ -662,6 +727,7 @@ class CertificateAuthority(object):
crt_pem (str) crt_pem (str)
PEM-encoded certificat to revoke. PEM-encoded certificat to revoke.
""" """
try:
crt = utils.load_certificate( crt = utils.load_certificate(
crt_pem, crt_pem,
self.getCACertificateList(), self.getCACertificateList(),
...@@ -670,6 +736,8 @@ class CertificateAuthority(object): ...@@ -670,6 +736,8 @@ class CertificateAuthority(object):
_cryptography_backend, _cryptography_backend,
), ),
) )
except CertificateRevokedError:
raise Found
self._storage.revoke( self._storage.revoke(
serial=crt.serial_number, serial=crt.serial_number,
expiration_date=utils.datetime2timestamp(crt.not_valid_after), expiration_date=utils.datetime2timestamp(crt.not_valid_after),
...@@ -695,6 +763,18 @@ class CertificateAuthority(object): ...@@ -695,6 +763,18 @@ class CertificateAuthority(object):
)), )),
) )
def getIssuedBy(self, serial_list, renewal):
"""
(see .storage.getIssuedBy)
"""
return self._storage.getIssuedBy(serial_list, renewal)
def getNonRevokedCertificateSerialList(self, serial_list):
"""
(see .storage.getNonRevokedCertificateSerialList)
"""
return self._storage.getNonRevokedCertificateSerialList(serial_list)
def renew(self, crt_pem, csr_pem): def renew(self, crt_pem, csr_pem):
""" """
Renew certificate. Renew certificate.
...@@ -718,6 +798,8 @@ class CertificateAuthority(object): ...@@ -718,6 +798,8 @@ class CertificateAuthority(object):
override_limits=True, override_limits=True,
), ),
auto_signed=_AUTO_SIGNED_PASSTHROUGH, auto_signed=_AUTO_SIGNED_PASSTHROUGH,
is_renewal=True,
authorisation_serial=crt.serial_number,
# Do a dummy signature, just so we get a usable # Do a dummy signature, just so we get a usable
# x509.CertificateSigningRequest instance. Use latest CA private key just # x509.CertificateSigningRequest instance. Use latest CA private key just
# because it is available for free (unlike generating a new one). # because it is available for free (unlike generating a new one).
......
...@@ -193,6 +193,18 @@ class CLICaucaseClient(object): ...@@ -193,6 +193,18 @@ class CLICaucaseClient(object):
crt_file.write(crt_pem) crt_file.write(crt_pem)
return warning, error return warning, error
def _printRevocationResult(self, result):
if result:
self._print(
'WARNING: The certificate just revoked has been used to issue '
'certificates with the following serials:',
)
for mode, serial_list in result.iteritems():
self._print('mode:', mode)
for serial in serial_list:
self._print(' ', serial)
self._print('You may want to check these.')
def revokeCRT(self, error, crt_key_list): def revokeCRT(self, error, crt_key_list):
""" """
--revoke-crt --revoke-crt
...@@ -209,7 +221,14 @@ class CLICaucaseClient(object): ...@@ -209,7 +221,14 @@ class CLICaucaseClient(object):
) )
error = True error = True
continue continue
self._client.revokeCertificate(crt, key) try:
result = self._client.revokeCertificate(crt, key)
except CaucaseError as e:
if e.args[0] != httplib.CONFLICT:
raise
self._print('Certificate', crt_path, 'was already revoked')
result = e.args[2]
self._printRevocationResult(result)
return error return error
def renewCRT( def renewCRT(
...@@ -343,14 +362,29 @@ class CLICaucaseClient(object): ...@@ -343,14 +362,29 @@ class CLICaucaseClient(object):
), ),
file=self._stderr, file=self._stderr,
) )
self._client.revokeCertificate(crt_pem) try:
result = self._client.revokeCertificate(crt_pem)
except CaucaseError as e:
if e.args[0] != httplib.CONFLICT:
raise
self._print('Certificate', crt_path, 'was already revoked')
result = e.args[2]
self._printRevocationResult(result)
return error
def revokeSerial(self, serial_list): def revokeSerial(self, serial_list):
""" """
--revoke-serial --revoke-serial
""" """
for serial in serial_list: for serial in serial_list:
self._client.revokeSerial(serial) try:
result = self._client.revokeSerial(serial)
except CaucaseError as e:
if e.args[0] != httplib.CONFLICT:
raise
self._print('Certificate', serial, 'was already revoked')
result = e.args[2]
self._printRevocationResult(result)
def main(argv=None, stdout=sys.stdout, stderr=sys.stderr): def main(argv=None, stdout=sys.stdout, stderr=sys.stderr):
""" """
......
...@@ -330,12 +330,19 @@ class CaucaseClient(object): ...@@ -330,12 +330,19 @@ class CaucaseClient(object):
data = utils.nullWrap({ data = utils.nullWrap({
'revoke_crt_pem': crt, 'revoke_crt_pem': crt,
}) })
method( try:
return json.loads(method(
'PUT', 'PUT',
'/crt/revoke', '/crt/revoke',
json.dumps(data).encode('utf-8'), json.dumps(data).encode('utf-8'),
{'Content-Type': 'application/json'}, {'Content-Type': 'application/json'},
) ).decode('utf-8'))
except CaucaseError as e:
if e.args[0] != httplib.CONFLICT: # pragma: no cover
raise
args = list(e.args)
args[2] = json.loads(e.args[2].decode('utf-8'))
raise CaucaseError(*args)
def revokeSerial(self, serial): def revokeSerial(self, serial):
""" """
...@@ -345,12 +352,19 @@ class CaucaseClient(object): ...@@ -345,12 +352,19 @@ class CaucaseClient(object):
[AUTHENTICATED] [AUTHENTICATED]
""" """
self._https( try:
return json.loads(self._https(
'PUT', 'PUT',
'/crt/revoke', '/crt/revoke',
json.dumps(utils.nullWrap({'revoke_serial': serial})).encode('utf-8'), json.dumps(utils.nullWrap({'revoke_serial': serial})).encode('utf-8'),
{'Content-Type': 'application/json'}, {'Content-Type': 'application/json'},
) ).decode('utf-8'))
except CaucaseError as e:
if e.args[0] != httplib.CONFLICT: # pragma: no cover
raise
args = list(e.args)
args[2] = json.loads(e.args[2].decode('utf-8'))
raise CaucaseError(*args)
def createCertificate(self, csr_id, template_csr=''): def createCertificate(self, csr_id, template_csr=''):
""" """
......
...@@ -42,6 +42,10 @@ class CertificateVerificationError(CertificateAuthorityException): ...@@ -42,6 +42,10 @@ class CertificateVerificationError(CertificateAuthorityException):
"""Certificate is not valid, it was not signed by CA""" """Certificate is not valid, it was not signed by CA"""
pass pass
class CertificateRevokedError(CertificateVerificationError):
"""Certificate is revoked"""
pass
class NotACertificateSigningRequest(CertificateAuthorityException): class NotACertificateSigningRequest(CertificateAuthorityException):
"""Provided value is not a certificate signing request""" """Provided value is not a certificate signing request"""
pass pass
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
Caucase - Certificate Authority for Users, Certificate Authority for SErvices Caucase - Certificate Authority for Users, Certificate Authority for SErvices
""" """
from __future__ import absolute_import from __future__ import absolute_import
import contextlib
from random import getrandbits from random import getrandbits
import os import os
import sqlite3 import sqlite3
...@@ -53,7 +54,7 @@ class NoReentryConnection(sqlite3.Connection): ...@@ -53,7 +54,7 @@ class NoReentryConnection(sqlite3.Connection):
self.__entered = False self.__entered = False
return super(NoReentryConnection, self).__exit__(exc_type, exc_value, traceback) return super(NoReentryConnection, self).__exit__(exc_type, exc_value, traceback)
class SQLite3Storage(local): class SQLite3Storage(local): # pylint: disable=too-many-public-methods
""" """
CA data storage. CA data storage.
...@@ -106,10 +107,10 @@ class SQLite3Storage(local): ...@@ -106,10 +107,10 @@ class SQLite3Storage(local):
self._table_prefix = table_prefix self._table_prefix = table_prefix
db.row_factory = sqlite3.Row db.row_factory = sqlite3.Row
self._max_csr_amount = max_csr_amount self._max_csr_amount = max_csr_amount
self._crt_keep_time = crt_keep_time * DAY_IN_SECONDS self._crt_keep_time = int(crt_keep_time * DAY_IN_SECONDS)
self._crt_read_keep_time = crt_read_keep_time * DAY_IN_SECONDS self._crt_read_keep_time = crt_read_keep_time * DAY_IN_SECONDS
with db: with db:
# Note about revoked.serial: certificate serials exceed the 63 bits # Note about serials: certificate serials exceed the 63 bits
# sqlite can accept as integers, so store these as text. Use a trivial # sqlite can accept as integers, so store these as text. Use a trivial
# string serialisation: not very space efficient, but this should not be # string serialisation: not very space efficient, but this should not be
# a limiting issue for our use-cases anyway. # a limiting issue for our use-cases anyway.
...@@ -142,7 +143,13 @@ class SQLite3Storage(local): ...@@ -142,7 +143,13 @@ class SQLite3Storage(local):
CREATE TABLE IF NOT EXISTS %(prefix)sconfig_once ( CREATE TABLE IF NOT EXISTS %(prefix)sconfig_once (
name TEXT PRIMARY KEY, name TEXT PRIMARY KEY,
value TEXT value TEXT
) );
CREATE TABLE IF NOT EXISTS %(prefix)sissuance (
serial TEXT PRIMARY KEY,
is_renewal INTEGER,
authorisation_serial TEXT,
not_valid_after INTEGER
);
''' % { ''' % {
'prefix': table_prefix, 'prefix': table_prefix,
'key_id_constraint': 'UNIQUE' if enforce_unique_key_id else '', 'key_id_constraint': 'UNIQUE' if enforce_unique_key_id else '',
...@@ -374,29 +381,6 @@ class SQLite3Storage(local): ...@@ -374,29 +381,6 @@ class SQLite3Storage(local):
).fetchall() ).fetchall()
] ]
def storeCertificate(self, csr_id, crt):
"""
Store certificate for pre-existing CSR.
Raises NotFound if there is no matching CSR, or if a certificate was
already stored.
"""
with self._db as db:
c = db.cursor()
c.execute(
'UPDATE %scrt SET crt=?, expiration_date = ? '
'WHERE id = ? AND crt IS NULL' % (
self._table_prefix,
),
(
crt,
int(time() + self._crt_keep_time),
csr_id,
),
)
if c.rowcount == 0:
raise NotFound
def getCertificate(self, crt_id): def getCertificate(self, crt_id):
""" """
Retrieve a PEM-encoded certificate. Retrieve a PEM-encoded certificate.
...@@ -463,6 +447,82 @@ class SQLite3Storage(local): ...@@ -463,6 +447,82 @@ class SQLite3Storage(local):
break break
yield toBytes(row['crt']) yield toBytes(row['crt'])
@contextlib.contextmanager
def trackIssuance(self, is_renewal, authorisation_serial, serial, not_valid_after):
"""
Track certificate issuance.
is_renewal (bool)
If true, this is a renewal.
Otherwise, it is an original issuance.
authorisation_serial (int)
Serial of the certificate which authorised issuance:
- renewed certificate serial if is_renewal is true
- user certificate serial otherwise
- None if certificate was auto-signed
serial (int)
Serial of issued certificate.
not_valid_after (int)
Expiration timestamp of issued certificate.
Returns a context manager which, on entry raises Found if serial has
already been issued. Otherwise, returns a callable taking 2 parameters
storing certificate for a pre-existing CSR:
csr_id (int)
CSR identifier.
crt (str)
PEM-encoded certificate.
Raises NotFound if there is no matching CSR, or if a certificate was
already stored.
"""
with self._db as db:
now = int(time())
c = db.cursor()
c.execute(
'DELETE FROM %sissuance '
'WHERE not_valid_after < ?' % (
self._table_prefix,
),
(
now,
),
)
try:
c.execute(
'INSERT INTO %sissuance '
' (is_renewal, authorisation_serial, serial, not_valid_after) '
'VALUES (?, ?, ?, ?)' % (
self._table_prefix,
),
(
int(is_renewal),
str(authorisation_serial),
str(serial),
not_valid_after,
),
)
except sqlite3.IntegrityError: # pragma: no cover
raise Found
# Just to have a mutable
store_argument_list = []
yield lambda csr_id, crt: store_argument_list.append((csr_id, crt))
# pylint: disable=unbalanced-tuple-unpacking
(csr_id, crt), = store_argument_list
# pylint: enable=unbalanced-tuple-unpacking
c.execute(
'UPDATE %scrt SET crt=?, expiration_date = ? '
'WHERE id = ? AND crt IS NULL' % (
self._table_prefix,
),
(
crt,
now + self._crt_keep_time,
csr_id,
),
)
if c.rowcount == 0:
raise NotFound
def revoke(self, serial, expiration_date): def revoke(self, serial, expiration_date):
""" """
Add given certificate serial to the list of revoked certificates. Add given certificate serial to the list of revoked certificates.
...@@ -495,6 +555,74 @@ class SQLite3Storage(local): ...@@ -495,6 +555,74 @@ class SQLite3Storage(local):
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
raise Found raise Found
def getIssuedBy(self, serial_list, renewal):
"""
Returns the list of serials of certificates which were issued, directly or
not, using the certificates with given serials.
serial_list (list of int)
Serial of the certificates to query the descendants of.
renewal (bool)
If true, only renewals will be followed.
Otherwise, initial issuances by given serials are fetched, and then
the renewals of these are followed. Renewals of given serials are not
followed.
"""
# do not mutate parameter
serial_list = list(serial_list)
result = set()
with self._db as db:
c = db.cursor()
while serial_list:
for next_serial, in c.execute(
'SELECT serial FROM %sissuance '
'WHERE authorisation_serial = ?%s' % (
self._table_prefix,
' and is_renewal = 1' if renewal else '',
),
(
str(serial_list.pop()),
),
):
next_serial = int(next_serial)
# Unlikely to be false: serials are server-enforced and random
# enough to be very unlikely to be reused. But still prevents a
# possible infinite loop.
if next_serial not in result:
serial_list.append(next_serial)
result.add(next_serial)
else: # pragma: no cover
pass
return list(result)
def getNonRevokedCertificateSerialList(self, serial_list):
"""
Return the list of serials whose certificates are not revoked, out of
those provided.
serial_list (list of int)
List of certificate serials to check for non-revocation.
"""
if serial_list:
with self._db as db:
revoked_serial_set = {
int(x)
for x, in db.cursor().execute(
'SELECT serial FROM %srevoked WHERE serial IN (%s)' % (
self._table_prefix,
','.join('?' * len(serial_list)),
),
[str(x) for x in serial_list],
)
}
else:
revoked_serial_set = ()
return [
x
for x in serial_list
if x not in revoked_serial_set
]
def getCertificateRevocationList(self): def getCertificateRevocationList(self):
""" """
Get PEM-encoded current Certificate Revocation List. Get PEM-encoded current Certificate Revocation List.
......
...@@ -24,8 +24,9 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices ...@@ -24,8 +24,9 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Test suite Test suite
""" """
# pylint: disable=too-many-lines, too-many-public-methods # pylint: disable=too-many-lines, too-many-public-methods
from __future__ import absolute_import from __future__ import absolute_import, print_function
from Cookie import SimpleCookie from Cookie import SimpleCookie
from collections import defaultdict
import datetime import datetime
# pylint: disable=no-name-in-module, import-error # pylint: disable=no-name-in-module, import-error
from distutils.spawn import find_executable from distutils.spawn import find_executable
...@@ -331,7 +332,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -331,7 +332,7 @@ class CaucaseTest(unittest.TestCase):
Prepare test data directory and file paths, and start caucased as most Prepare test data directory and file paths, and start caucased as most
tests will need to interact with it. tests will need to interact with it.
""" """
global _clean_caucased_snapshot global _clean_caucased_snapshot # pylint: disable=global-statement
self._data_dir = data_dir = tempfile.mkdtemp(prefix='caucase_test_') self._data_dir = data_dir = tempfile.mkdtemp(prefix='caucase_test_')
self._client_dir = client_dir = os.path.join(data_dir, 'client') self._client_dir = client_dir = os.path.join(data_dir, 'client')
...@@ -1067,23 +1068,17 @@ class CaucaseTest(unittest.TestCase): ...@@ -1067,23 +1068,17 @@ class CaucaseTest(unittest.TestCase):
# and crt & key did not change # and crt & key did not change
self.assertEqual(service2_crt_after, service2_crt_after2) self.assertEqual(service2_crt_after, service2_crt_after2)
self.assertEqual(service2_key_after, service2_key_after2) self.assertEqual(service2_key_after, service2_key_after2)
# revoking again one's own certificate fails # revoking again one's own certificate does not fail
self.assertRaises( self._runClient(
CaucaseError,
self._runClient,
'--revoke-crt', service2_key_path, '', '--revoke-crt', service2_key_path, '',
) )
# as does revoking with an authenticated user # as does revoking with an authenticated user
self.assertRaises( self._runClient(
CaucaseError,
self._runClient,
'--user-key', user_key_path, '--user-key', user_key_path,
'--revoke-other-crt', service2_key_path, '--revoke-other-crt', service2_key_path,
) )
# and revoking by serial # and revoking by serial
self.assertRaises( self._runClient(
CaucaseError,
self._runClient,
'--user-key', user_key_path, '--user-key', user_key_path,
'--revoke-serial', str( '--revoke-serial', str(
utils.load_certificate( utils.load_certificate(
...@@ -2290,6 +2285,16 @@ class CaucaseTest(unittest.TestCase): ...@@ -2290,6 +2285,16 @@ class CaucaseTest(unittest.TestCase):
user_key_path, user_key_path,
'user', 'user',
) )
user2_key_renewal_query_fragment = (
",1,'%i'" % utils.load_certificate(
utils.getCert(user2_key_path),
[
utils.load_ca_certificate(x)
for x in utils.getCertList(self._client_user_ca_crt)
],
None,
).serial_number
).encode('utf-8')
# user2 sacrifice their private key, and prepare its replacement # user2 sacrifice their private key, and prepare its replacement
basename = self._getBaseName() basename = self._getBaseName()
user2_new_key_path = self._createPrivateKey(basename) user2_new_key_path = self._createPrivateKey(basename)
...@@ -2377,12 +2382,14 @@ class CaucaseTest(unittest.TestCase): ...@@ -2377,12 +2382,14 @@ class CaucaseTest(unittest.TestCase):
CRL_INSERT = b'INSERT INTO "caucrl" ' CRL_INSERT = b'INSERT INTO "caucrl" '
CRT_INSERT = b'INSERT INTO "caucrt" ' CRT_INSERT = b'INSERT INTO "caucrt" '
REV_INSERT = b'INSERT INTO "caurevoked" ' REV_INSERT = b'INSERT INTO "caurevoked" '
ISSUED_INSERT = b'INSERT INTO "cauissuance" '
def filterBackup(backup, expect_rev): def filterBackup(backup, expect_rev):
""" """
Remove all lines which are know to differ between original batabase and Remove all lines which are know to differ between original batabase and
post-restoration database, so the rest (which must be the majority of the post-restoration database, so the rest (which must be the majority of the
database) can be tested to be equal. database) can be tested to be equal.
""" """
renew_found = None
rev_found = not expect_rev rev_found = not expect_rev
new_backup = [] new_backup = []
crt_list = [] crt_list = []
...@@ -2395,6 +2402,13 @@ class CaucaseTest(unittest.TestCase): ...@@ -2395,6 +2402,13 @@ class CaucaseTest(unittest.TestCase):
if row.startswith(REV_INSERT): # pragma: no cover if row.startswith(REV_INSERT): # pragma: no cover
assert not rev_found, 'Unexpected revocation found' assert not rev_found, 'Unexpected revocation found'
continue continue
if (
row.startswith(ISSUED_INSERT) and
user2_key_renewal_query_fragment in row
):
assert renew_found is None, '\n%r\n%r' % (renew_found, row)
renew_found = row
continue
new_backup.append(row) new_backup.append(row)
return new_backup, crt_list return new_backup, crt_list
...@@ -3127,6 +3141,220 @@ class CaucaseTest(unittest.TestCase): ...@@ -3127,6 +3141,220 @@ class CaucaseTest(unittest.TestCase):
) )
self.assertFalse(user_certificate_policies.critical) self.assertFalse(user_certificate_policies.critical)
def testIssuanceTracking(self):
"""
Issue a few certificates, revoke a few and check that all their
descendents are listed.
"""
self._runClient()
cas_crt_list = [
utils.load_ca_certificate(x)
for x in utils.getCertList(self._client_ca_crt)
]
self._runClient('--mode', 'user')
cau_crt_list = [
utils.load_ca_certificate(x)
for x in utils.getCertList(self._client_user_ca_crt)
]
def getSerial(key_path, ca_crt_list):
'get serial as a string'
return str(utils.load_certificate(
utils.getCert(key_path),
ca_crt_list,
None,
).serial_number)
def approveUser(user_key_path):
'issue a new user certificate'
new_key_path = self._createAndApproveCertificate(
user_key_path,
'user',
)
return new_key_path, getSerial(new_key_path, cau_crt_list)
def approveService(user_key_path):
'issue a new service certificate'
new_key_path = self._createAndApproveCertificate(
user_key_path,
'service',
)
return new_key_path, getSerial(new_key_path, cas_crt_list)
def renew(mode, key_path):
'renew a <mode> certificate'
self._runClient(
'--mode', mode,
'--threshold', '100',
'--renew-crt', key_path, '',
)
def renewUser(key_path):
'renew a user certificate'
renew('user', key_path)
return getSerial(key_path, cau_crt_list)
def renewService(key_path):
'renew a service certificate'
renew('service', key_path)
return getSerial(key_path, cas_crt_list)
def assertRevokeResult(expected_dict, args, already_revoked=False):
'revoke a certificate and check displayed descendant certificates'
mode = None
output_dict = defaultdict(list)
output = self._runClient(*args)
already_revoked_found = False
for line in output.splitlines():
if line.endswith('was already revoked'):
already_revoked_found = True
elif line.startswith('mode: '):
_, mode = line.split()
elif mode is not None and line.startswith(' '):
output_dict[mode].append(line.strip())
self.assertEqual(already_revoked, already_revoked_found)
self.assertItemsEqual(expected_dict.keys(), output_dict.keys())
for mode, expected_value_list in expected_dict.iteritems():
self.assertItemsEqual(expected_value_list, output_dict[mode])
# Get first user certificate
user0_key_path = self._createFirstUser()
user0_serial0 = getSerial(user0_key_path, cau_crt_list)
# renew user0 certificate
user0_serial1 = renewUser(user0_key_path)
# Sanity check renewal
self.assertNotEqual(user0_serial0, user0_serial1)
# Get two more user certificates approved by user0
user1_key_path, user1_serial = approveUser(user0_key_path)
user2_key_path, user2_serial0 = approveUser(user0_key_path)
# renew user2 certificate
user2_serial1 = renewUser(user2_key_path)
# Get a service certificate approved by each user
service0_key_path, service0_serial0 = approveService(user0_key_path)
service1_key_path, service1_serial0 = approveService(user1_key_path)
service2_key_path, service2_serial0 = approveService(user2_key_path)
# Renew each
service0_serial1 = renewService(service0_key_path)
service1_serial1 = renewService(service1_key_path)
service2_serial1 = renewService(service2_key_path)
# because serials are unreadable, print a mapping to display if test fails
test_output = utils.toUnicodeWritableStream(self.caucase_test_output)
for serial, name in (
(user0_serial0, 'user0_serial0'),
(user0_serial1, 'user0_serial1'),
(user1_serial, 'user1_serial'),
(user2_serial0, 'user2_serial0'),
(user2_serial1, 'user2_serial1'),
(service0_serial0, 'service0_serial0'),
(service1_serial0, 'service1_serial0'),
(service2_serial0, 'service2_serial0'),
(service0_serial1, 'service0_serial1'),
(service1_serial1, 'service1_serial1'),
(service2_serial1, 'service2_serial1'),
):
print(serial, 'is', name, file=test_output)
assertRevokeResult(
{ # One renewal, which issued one service, which was renewed
'user': [user2_serial1],
'service': [service2_serial0, service2_serial1],
}, (
'--mode', 'user',
'--user-key', user0_key_path,
'--revoke-serial', user2_serial0,
),
)
assertRevokeResult(
{ # One renewal
'service': [service2_serial1],
}, (
'--mode', 'service',
'--user-key', user0_key_path,
'--revoke-serial', service2_serial0,
),
)
assertRevokeResult(
{}, ( # No renewals
'--mode', 'service',
'--revoke-crt', service1_key_path, '',
),
)
assertRevokeResult(
{}, ( # Renewal is already revoked
'--mode', 'service',
'--user-key', user0_key_path,
'--revoke-serial', service1_serial0,
),
)
assertRevokeResult(
{ # renewal, 2 issued user certificates (one being the renewal of an
# already-revoked certificate), one issued service and its renewal,
# and the renewal of and already-revoked issued service certificate
'user': [
user0_serial1,
user1_serial,
user2_serial1,
],
'service': [
service0_serial0, service0_serial1,
service2_serial1,
],
}, (
'--mode', 'user',
'--user-key', user1_key_path,
'--revoke-serial', user0_serial0,
),
)
assertRevokeResult(
{ # 2 issued user certificates (one being the renewal of an
# already-revoked certificate), one issued service and its renewal,
# and the renewal of and already-revoked issued service certificate
'user': [
user1_serial,
user2_serial1,
],
'service': [
service0_serial0, service0_serial1,
service2_serial1,
],
}, (
'--mode', 'user',
'--revoke-crt', user0_key_path, '',
),
)
# Duplicate revocations
# Must mention they are duplicate, but still display an updated view of
# descendant certificates.
assertRevokeResult(
{ # 2 issued user certificates (one being the renewal of an
# already-revoked certificate), one issued service and its renewal,
# and the renewal of and already-revoked issued service certificate
'user': [
user1_serial,
user2_serial1,
],
'service': [
service0_serial0, service0_serial1,
service2_serial1,
],
}, (
'--mode', 'user',
'--user-key', user1_key_path,
'--revoke-serial', user0_serial0,
),
already_revoked=True,
)
assertRevokeResult(
{ # 2 issued user certificates (one being the renewal of an
# already-revoked certificate), one issued service and its renewal,
# and the renewal of and already-revoked issued service certificate
'user': [
user1_serial,
user2_serial1,
],
'service': [
service0_serial0, service0_serial1,
service2_serial1,
],
}, (
'--mode', 'user',
'--revoke-crt', user0_key_path, '',
),
already_revoked=True,
)
for property_id, property_value in CaucaseTest.__dict__.iteritems(): for property_id, property_value in CaucaseTest.__dict__.iteritems():
if property_id.startswith('test') and callable(property_value): if property_id.startswith('test') and callable(property_value):
setattr(CaucaseTest, property_id, print_buffer_on_error(property_value)) setattr(CaucaseTest, property_id, print_buffer_on_error(property_value))
......
...@@ -44,6 +44,7 @@ import cryptography.exceptions ...@@ -44,6 +44,7 @@ import cryptography.exceptions
import pem import pem
from .exceptions import ( from .exceptions import (
CertificateVerificationError, CertificateVerificationError,
CertificateRevokedError,
NotJSON, NotJSON,
) )
...@@ -342,10 +343,12 @@ def _verifyCertificateChain(cert, trusted_cert_list, crl): ...@@ -342,10 +343,12 @@ def _verifyCertificateChain(cert, trusted_cert_list, crl):
Verifies whether certificate has been signed by any of the trusted Verifies whether certificate has been signed by any of the trusted
certificates, is not revoked and is whithin its validity period. certificates, is not revoked and is whithin its validity period.
Raises CertificateVerificationError if validation fails. Raises CertificateVerificationError if validation fails, or its
CertificateRevokedError subclass if it is because this certificate was
revoked.
""" """
# Note: this function (validating a certificate without an SSL connection) # Note: this function (validating a certificate without an SSL connection)
# does not seem to have many equivalents at all in python. OpenSSL module # does not seem to have any equivalents at all in python. OpenSSL module
# seems to be a rare implementation of it, so we keep using this module. # seems to be a rare implementation of it, so we keep using this module.
# BUT it MUST NOT be used anywhere outside this function (hence the # BUT it MUST NOT be used anywhere outside this function (hence the
# bad-style local import). Use "cryptography". # bad-style local import). Use "cryptography".
...@@ -362,13 +365,14 @@ def _verifyCertificateChain(cert, trusted_cert_list, crl): ...@@ -362,13 +365,14 @@ def _verifyCertificateChain(cert, trusted_cert_list, crl):
store, store,
crypto.X509.from_cryptography(cert), crypto.X509.from_cryptography(cert),
).verify_certificate() ).verify_certificate()
except ( except crypto.X509StoreContextError as e:
crypto.X509StoreContextError, error, depth, _ = e.args[0]
crypto.Error, # 23 is X509_V_ERR_CERT_REVOKED (include/openssl/x509_vfy.h)
) as e: if error == 23 and depth == 0:
raise CertificateVerificationError( raise CertificateRevokedError(repr(e))
'Certificate verification error: %s' % str(e), raise CertificateVerificationError(repr(e))
) except crypto.Error as e:
raise CertificateVerificationError(repr(e))
def wrap(payload, key, digest): def wrap(payload, key, digest):
""" """
......
...@@ -23,6 +23,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices ...@@ -23,6 +23,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
""" """
from __future__ import absolute_import from __future__ import absolute_import
from Cookie import SimpleCookie, CookieError from Cookie import SimpleCookie, CookieError
from functools import partial
import httplib import httplib
import json import json
import os import os
...@@ -180,6 +181,7 @@ STATUS_OK = _getStatus(httplib.OK) ...@@ -180,6 +181,7 @@ STATUS_OK = _getStatus(httplib.OK)
STATUS_CREATED = _getStatus(httplib.CREATED) STATUS_CREATED = _getStatus(httplib.CREATED)
STATUS_NO_CONTENT = _getStatus(httplib.NO_CONTENT) STATUS_NO_CONTENT = _getStatus(httplib.NO_CONTENT)
STATUS_FOUND = _getStatus(httplib.FOUND) STATUS_FOUND = _getStatus(httplib.FOUND)
STATUS_CONFLICT = Conflict.status
MAX_BODY_LENGTH = 10 * 1024 * 1024 # 10 MB MAX_BODY_LENGTH = 10 * 1024 * 1024 # 10 MB
class CORSTokenManager(object): class CORSTokenManager(object):
...@@ -297,6 +299,7 @@ class Application(object): ...@@ -297,6 +299,7 @@ class Application(object):
List of Origin values to always trust. List of Origin values to always trust.
""" """
self._cau = cau self._cau = cau
self._cas = cas
self._http_url = http_url.rstrip('/') self._http_url = http_url.rstrip('/')
self._https_url = https_url.rstrip('/') self._https_url = https_url.rstrip('/')
self._cors_cookie_id = cors_cookie_id self._cors_cookie_id = cors_cookie_id
...@@ -577,8 +580,6 @@ class Application(object): ...@@ -577,8 +580,6 @@ class Application(object):
raise raise
except exceptions.NotFound: except exceptions.NotFound:
raise NotFound raise NotFound
except exceptions.Found:
raise Conflict
except exceptions.NoStorage: except exceptions.NoStorage:
raise InsufficientStorage raise InsufficientStorage
except exceptions.NotJSON: except exceptions.NotJSON:
...@@ -604,12 +605,12 @@ class Application(object): ...@@ -604,12 +605,12 @@ class Application(object):
return result return result
@staticmethod @staticmethod
def _returnFile(data, content_type, header_list=None): def _returnFile(data, content_type, header_list=None, status=STATUS_OK):
if header_list is None: if header_list is None:
header_list = [] header_list = []
header_list.append(('Content-Type', content_type)) header_list.append(('Content-Type', content_type))
header_list.append(('Content-Length', str(len(data)))) header_list.append(('Content-Length', str(len(data))))
return (STATUS_OK, header_list, [data]) return (status, header_list, [data])
@staticmethod @staticmethod
def _getCSRID(subpath): def _getCSRID(subpath):
...@@ -644,11 +645,12 @@ class Application(object): ...@@ -644,11 +645,12 @@ class Application(object):
Verify user authentication. Verify user authentication.
Raises SSLUnauthorized if authentication does not pass checks. Raises SSLUnauthorized if authentication does not pass checks.
On success, appends a "Cache-Control" header. On success, appends a "Cache-Control" header and returns certificate
object.
""" """
try: try:
ca_list = self._cau.getCACertificateList() ca_list = self._cau.getCACertificateList()
utils.load_certificate( result = utils.load_certificate(
environ.get('SSL_CLIENT_CERT', b''), environ.get('SSL_CLIENT_CERT', b''),
trusted_cert_list=ca_list, trusted_cert_list=ca_list,
crl=utils.load_crl( crl=utils.load_crl(
...@@ -659,6 +661,7 @@ class Application(object): ...@@ -659,6 +661,7 @@ class Application(object):
except (exceptions.CertificateVerificationError, ValueError): except (exceptions.CertificateVerificationError, ValueError):
raise SSLUnauthorized raise SSLUnauthorized
header_list.append(('Cache-Control', 'private')) header_list.append(('Cache-Control', 'private'))
return result
def _readJSON(self, environ): def _readJSON(self, environ):
""" """
...@@ -1049,6 +1052,34 @@ class Application(object): ...@@ -1049,6 +1052,34 @@ class Application(object):
'application/pkix-cert', 'application/pkix-cert',
) )
def _getIssuedBy(self, context, serial):
"""
Return a dict of non-revoked certificates descending from given serial.
Keys are modes, values are list of serials.
"""
result = {}
if context is self._cau:
# Revoking a user certificate, list:
# - other user certificates it issued and their renewals
# - service certificates they issued and their renewals
user_serial_list = context.getIssuedBy([serial], False)
service_serial_list = self._cas.getNonRevokedCertificateSerialList(
self._cas.getIssuedBy([serial] + user_serial_list, False),
)
user_serial_list = context.getNonRevokedCertificateSerialList(
user_serial_list,
)
if user_serial_list:
result['user'] = user_serial_list
else:
# Revoking a service certificate, list its renewals.
service_serial_list = context.getNonRevokedCertificateSerialList(
context.getIssuedBy([serial], True),
)
if service_serial_list:
result['service'] = service_serial_list
return result
def revokeCertificate(self, context, environ): def revokeCertificate(self, context, environ):
""" """
Handle PUT /{context}/crt/revoke . Handle PUT /{context}/crt/revoke .
...@@ -1058,19 +1089,38 @@ class Application(object): ...@@ -1058,19 +1089,38 @@ class Application(object):
if data['digest'] is None: if data['digest'] is None:
self._authenticate(environ, header_list) self._authenticate(environ, header_list)
payload = utils.nullUnwrap(data) payload = utils.nullUnwrap(data)
if 'revoke_crt_pem' not in payload: if 'revoke_crt_pem' in payload:
context.revokeSerial(payload['revoke_serial']) crt_pem = utils.toBytes(payload['revoke_crt_pem'])
return (STATUS_NO_CONTENT, header_list, []) revoke = partial(context.revoke, crt_pem=crt_pem)
else: else:
payload = utils.unwrap( crt_pem = None
serial = payload['revoke_serial']
revoke = partial(context.revokeSerial, serial)
else:
crt_pem = utils.toBytes(utils.unwrap(
data, data,
lambda x: x['revoke_crt_pem'], lambda x: x['revoke_crt_pem'],
context.digest_list, context.digest_list,
)['revoke_crt_pem'])
revoke = partial(context.revoke, crt_pem=crt_pem)
if crt_pem is not None:
serial = utils.load_certificate(
crt_pem,
context.getCACertificateList(),
None, # context will check its revocation status
).serial_number
try:
revoke()
except exceptions.Found:
status = STATUS_CONFLICT
else:
status = STATUS_OK
return self._returnFile(
json.dumps(self._getIssuedBy(context, serial)).encode('utf-8'),
'application/json',
header_list,
status,
) )
context.revoke(
crt_pem=utils.toBytes(payload['revoke_crt_pem']),
)
return (STATUS_NO_CONTENT, header_list, [])
def renewCertificate(self, context, environ): def renewCertificate(self, context, environ):
""" """
...@@ -1104,9 +1154,10 @@ class Application(object): ...@@ -1104,9 +1154,10 @@ class Application(object):
else: else:
raise BadRequest(b'Bad Content-Type') raise BadRequest(b'Bad Content-Type')
header_list = [] header_list = []
self._authenticate(environ, header_list) user_crt = self._authenticate(environ, header_list)
context.createCertificate( context.createCertificate(
csr_id=crt_id, csr_id=crt_id,
template_csr=template_csr, template_csr=template_csr,
authorisation_serial=user_crt.serial_number,
) )
return (STATUS_NO_CONTENT, header_list, []) return (STATUS_NO_CONTENT, header_list, [])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment