Commit e9de51f0 authored by Vincent Pelletier's avatar Vincent Pelletier

all: Finalise python3 support.

Basically, wrap stdout and stderr whenever they do not have an encoding
with an ascii-encoding writer, and write unicode to stdout & stderr.
wsgi.errors is defined in the reference implementation as being a StringIO,
so follow that.
Stop using argparse.FileType to get rid of python3 "file not closed"
errors.
Also, fix setup access to CHANGES.txt .
Also, fix 2to3 involvement.
Also, replace test.captureStdout with extra tool arguments.
parent e9cd6586
This diff is collapsed.
......@@ -173,8 +173,12 @@ class CaucaseWSGIRequestHandler(WSGIRequestHandler):
remote_user_name = '-'
def __init__(self, *args, **kw):
self._log_file = kw.pop('log_file', sys.stdout)
self._error_file = kw.pop('error_file', sys.stderr)
self._log_file = utils.toUnicodeWritableStream(
kw.pop('log_file', sys.stdout),
)
self._error_file = utils.toUnicodeWritableStream(
kw.pop('error_file', sys.stderr),
)
WSGIRequestHandler.__init__(self, *args, **kw)
def log_date_time_string(self):
......@@ -580,8 +584,10 @@ def main(
help='Number of days between backups. default: %(default)s',
)
args = parser.parse_args(argv)
log_file = utils.toUnicodeWritableStream(log_file)
error_file = utils.toUnicodeWritableStream(error_file)
base_url = u'http://' + utils.toUnicode(args.netloc)
base_url = 'http://' + utils.toUnicode(args.netloc)
parsed_base_url = urlparse(base_url)
hostname = parsed_base_url.hostname
name_constraints_permited = []
......@@ -881,7 +887,7 @@ def main(
server.server_close()
server.shutdown()
def manage(argv=None):
def manage(argv=None, stdout=sys.stdout):
"""
caucased database management tool.
"""
......@@ -922,7 +928,6 @@ def manage(argv=None):
default=[],
metavar='PEM_FILE',
action='append',
type=argparse.FileType('rb'),
help='Import key pairs as initial service CA certificate. '
'May be provided multiple times to import multiple key pairs. '
'Keys and certificates may be in separate files. '
......@@ -948,7 +953,6 @@ def manage(argv=None):
default=[],
metavar='PEM_FILE',
action='append',
type=argparse.FileType('rb'),
help='Import service revocation list. Corresponding CA certificate must '
'be already present in the database (including added in the same run '
'using --import-ca).',
......@@ -956,11 +960,11 @@ def manage(argv=None):
parser.add_argument(
'--export-ca',
metavar='PEM_FILE',
type=argparse.FileType('wb'),
help='Export all CA certificates in a PEM file. Passphrase will be '
'prompted to protect all keys.',
)
args = parser.parse_args(argv)
stdout = utils.toUnicodeWritableStream(stdout)
db_path = args.db
if args.restore_backup:
(
......@@ -1008,9 +1012,11 @@ def manage(argv=None):
import_ca_dict = defaultdict(
(lambda: {'crt': None, 'key': None, 'from': []}),
)
for ca_file in args.import_ca:
for index, component in enumerate(pem.parse(ca_file.read())):
name = '%r, block %i' % (ca_file.name, index)
for import_ca in args.import_ca:
with open(import_ca, 'rb') as ca_file:
ca_data = ca_file.read()
for index, component in enumerate(pem.parse(ca_data)):
name = '%r, block %i' % (import_ca, index)
if isinstance(component, pem.Certificate):
component_name = 'crt'
component_value = x509.load_pem_x509_certificate(
......@@ -1053,11 +1059,16 @@ def manage(argv=None):
found_from = ', '.join(ca_pair['from'])
crt = ca_pair['crt']
if crt is None:
print(b'No certificate correspond to', found_from, b'- skipping')
print(
'No certificate correspond to',
found_from,
'- skipping',
file=stdout,
)
continue
expiration = utils.datetime2timestamp(crt.not_valid_after)
if expiration < now:
print(b'Skipping expired certificate from', found_from)
print('Skipping expired certificate from', found_from, file=stdout)
del import_ca_dict[identifier]
continue
if not args.import_bad_ca:
......@@ -1076,11 +1087,16 @@ def manage(argv=None):
or not key_usage.key_cert_sign or not key_usage.crl_sign
)
if failed:
print(b'Skipping non-CA certificate from', found_from)
print('Skipping non-CA certificate from', found_from, file=stdout)
continue
key = ca_pair['key']
if key is None:
print(b'No private key correspond to', found_from, b'- skipping')
print(
'No private key correspond to',
found_from,
'- skipping',
file=stdout,
)
continue
imported += 1
cas_db.appendCAKeyPair(
......@@ -1092,7 +1108,7 @@ def manage(argv=None):
)
if not imported:
raise ValueError('No CA certificate imported')
print(b'Imported %i CA certificates' % imported)
print('Imported %i CA certificates' % imported, file=stdout)
if args.import_crl:
db = SQLite3Storage(db_path, table_prefix='cas')
trusted_ca_crt_set = [
......@@ -1104,8 +1120,10 @@ def manage(argv=None):
for x in trusted_ca_crt_set
)
already_revoked_count = revoked_count = 0
for crl_file in args.import_crl:
for revoked in utils.load_crl(crl_file.read(), trusted_ca_crt_set):
for import_crl in args.import_crl:
with open(import_crl, 'rb') as crl_file:
crl_data = crl_file.read()
for revoked in utils.load_crl(crl_data, trusted_ca_crt_set):
try:
db.revoke(
revoked.serial_number,
......@@ -1115,28 +1133,31 @@ def manage(argv=None):
already_revoked_count += 1
else:
revoked_count += 1
print(b'Revoked %i certificates (%i were already revoked)' % (
revoked_count,
already_revoked_count,
))
print(
'Revoked %i certificates (%i were already revoked)' % (
revoked_count,
already_revoked_count,
),
file=stdout,
)
if args.export_ca is not None:
encryption_algorithm = serialization.BestAvailableEncryption(
getBytePass('CA export passphrase: ')
)
write = args.export_ca.write
for key_pair in SQLite3Storage(
db_path,
table_prefix='cas',
).getCAKeyPairList():
write(
key_pair['crt_pem'] + serialization.load_pem_private_key(
key_pair['key_pem'],
None,
_cryptography_backend,
).private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=encryption_algorithm,
),
)
args.export_ca.close()
with open(args.export_ca, 'wb') as export_ca_file:
write = export_ca_file.write
for key_pair in SQLite3Storage(
db_path,
table_prefix='cas',
).getCAKeyPairList():
write(
key_pair['crt_pem'] + serialization.load_pem_private_key(
key_pair['key_pem'],
None,
_cryptography_backend,
).private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=encryption_algorithm,
),
)
......@@ -25,7 +25,6 @@ Test suite
"""
# pylint: disable=too-many-lines, too-many-public-methods
from __future__ import absolute_import
import contextlib
from Cookie import SimpleCookie
import datetime
import errno
......@@ -305,25 +304,14 @@ def print_buffer_on_error(func):
try:
return func(self, *args, **kw)
except Exception: # pragma: no cover
sys.stdout.write(utils.toBytes(os.linesep))
sys.stdout.write(self.caucase_test_output.getvalue())
stdout = utils.toUnicodeWritableStream(sys.stdout)
stdout.write(os.linesep)
stdout.write(
self.caucase_test_output.getvalue().decode('ascii', 'replace'),
)
raise
return wrapper
@contextlib.contextmanager
def captureStdout():
"""
Replace stdout with a BytesIO object for the duration of the context manager,
and provide it to caller.
"""
orig_stdout = sys.stdout
sys.stdout = stdout = BytesIO()
try:
yield stdout
finally:
sys.stdout = orig_stdout
@unittest.skipIf(sys.version_info >= (3, ), 'Caucase currently supports python 2 only')
class CaucaseTest(unittest.TestCase):
"""
Test a complete caucase setup: spawn a caucase-http server on CAUCASE_NETLOC
......@@ -356,6 +344,9 @@ class CaucaseTest(unittest.TestCase):
self._server_backup_path = os.path.join(server_dir, 'backup')
self._server_cors_store = os.path.join(server_dir, 'cors.key')
# pylint: enable=bad-whitespace
# Using a BytesIO for caucased output here, because stdout/stderr do not
# necessarily have a known encoding, for example when output is a pipe
# (to a file, ...). caucased must deal with this.
self.caucase_test_output = BytesIO()
os.mkdir(self._server_backup_path)
......@@ -497,7 +488,7 @@ class CaucaseTest(unittest.TestCase):
)
def _getClientCRL(self):
with open(self._client_crl) as crl_pem_file:
with open(self._client_crl, 'rb') as crl_pem_file:
return x509.load_pem_x509_crl(
crl_pem_file.read(),
_cryptography_backend
......@@ -612,20 +603,24 @@ class CaucaseTest(unittest.TestCase):
Returns stdout.
"""
with captureStdout() as stdout:
try:
cli.main(
argv=(
'--ca-url', self._caucase_url,
'--ca-crt', self._client_ca_crt,
'--user-ca-crt', self._client_user_ca_crt,
'--crl', self._client_crl,
'--user-crl', self._client_user_crl,
) + argv,
)
except SystemExit:
pass
return stdout.getvalue()
# Using a BytesIO for caucased output here, because stdout/stderr do not
# necessarily have a known encoding, for example when output is a pipe
# (to a file, ...). caucase must deal with this.
stdout = BytesIO()
try:
cli.main(
argv=(
'--ca-url', self._caucase_url,
'--ca-crt', self._client_ca_crt,
'--user-ca-crt', self._client_user_ca_crt,
'--crl', self._client_crl,
'--user-crl', self._client_user_crl,
) + argv,
stdout=stdout,
)
except SystemExit:
pass
return stdout.getvalue().decode('ascii')
@staticmethod
def _setCertificateRemainingLifeTime(key, crt, delta):
......@@ -1676,7 +1671,10 @@ class CaucaseTest(unittest.TestCase):
"""
Non-standard shorthand for invoking the WSGI application.
"""
environ.setdefault('wsgi.errors', self.caucase_test_output)
environ.setdefault(
'wsgi.errors',
utils.toUnicodeWritableStream(self.caucase_test_output),
)
environ.setdefault('wsgi.url_scheme', 'http')
environ.setdefault('SERVER_NAME', server_name)
environ.setdefault('SERVER_PORT', str(server_http_port))
......@@ -2294,29 +2292,31 @@ class CaucaseTest(unittest.TestCase):
os.unlink(self._server_db)
os.unlink(self._server_key)
with captureStdout() as stdout:
cli.key_id([
'--private-key', user_key_path, user2_key_path, user2_new_key_path,
])
stdout = BytesIO()
cli.key_id(
['--private-key', user_key_path, user2_key_path, user2_new_key_path],
stdout=stdout,
)
key_id_dict = dict(
line.rsplit(' ', 1)
line.decode('ascii').rsplit(' ', 1)
for line in stdout.getvalue().splitlines()
)
key_id = key_id_dict.pop(user_key_path)
key2_id = key_id_dict.pop(user2_key_path)
new_key2_id = key_id_dict.pop(user2_new_key_path)
self.assertFalse(key_id_dict)
with captureStdout() as stdout:
cli.key_id([
'--backup', backup_path,
])
stdout = BytesIO()
cli.key_id(
['--backup', backup_path],
stdout=stdout,
)
self.assertItemsEqual(
[
backup_path,
' ' + key_id,
' ' + key2_id,
],
stdout.getvalue().splitlines(),
stdout.getvalue().decode('ascii').splitlines(),
)
try:
......@@ -2410,17 +2410,18 @@ class CaucaseTest(unittest.TestCase):
if not backup_path_list: # pragma: no cover
raise AssertionError('Backup file not created after 1 second')
backup_path, = glob.glob(backup_glob)
with captureStdout() as stdout:
cli.key_id([
'--backup', backup_path,
])
stdout = BytesIO()
cli.key_id(
['--backup', backup_path],
stdout=stdout,
)
self.assertItemsEqual(
[
backup_path,
' ' + key_id,
' ' + new_key2_id,
],
stdout.getvalue().splitlines(),
stdout.getvalue().decode('ascii').splitlines(),
)
# Now, push a lot of data to exercise chunked checksum in backup &
......@@ -2444,17 +2445,18 @@ class CaucaseTest(unittest.TestCase):
if not backup_path_list: # pragma: no cover
raise AssertionError('Backup file took too long to be created')
backup_path, = glob.glob(backup_glob)
with captureStdout() as stdout:
cli.key_id([
'--backup', backup_path,
])
stdout = BytesIO()
cli.key_id(
['--backup', backup_path],
stdout=stdout,
)
self.assertItemsEqual(
[
backup_path,
' ' + key_id,
' ' + new_key2_id,
],
stdout.getvalue().splitlines(),
stdout.getvalue().decode('ascii').splitlines(),
)
self._stopServer()
os.unlink(self._server_db)
......@@ -2510,23 +2512,24 @@ class CaucaseTest(unittest.TestCase):
self.assertTrue(os.path.exists(exported_ca), exported_ca)
server_db2 = self._server_db + '2'
self.assertFalse(os.path.exists(server_db2), server_db2)
with captureStdout() as stdout:
caucase.http.manage(
argv=(
'--db', server_db2,
'--import-ca', exported_ca,
'--import-crl', self._client_crl,
# Twice, for code coverage...
'--import-crl', self._client_crl,
),
)
stdout = BytesIO()
caucase.http.manage(
argv=(
'--db', server_db2,
'--import-ca', exported_ca,
'--import-crl', self._client_crl,
# Twice, for code coverage...
'--import-crl', self._client_crl,
),
stdout=stdout,
)
self.assertTrue(os.path.exists(server_db2), server_db2)
self.assertEqual(
[
'Imported 1 CA certificates',
'Revoked 1 certificates (1 were already revoked)',
],
stdout.getvalue().splitlines(),
stdout.getvalue().decode('ascii').splitlines(),
)
finally:
caucase.http.getBytePass = getBytePass_orig
......@@ -2729,7 +2732,7 @@ class CaucaseTest(unittest.TestCase):
until_network_issue = UntilEvent(network_issue_event)
# pylint: disable=protected-access
cli.RetryingCaucaseClient._until = until_network_issue
cli.RetryingCaucaseClient._log_file = self.caucase_test_output
cli.RetryingCaucaseClient._log_file = StringIO()
# pylint: enable=protected-access
until_network_issue.action = ON_EVENT_EXPIRE
original_HTTPConnection = cli.RetryingCaucaseClient.HTTPConnection
......
......@@ -26,6 +26,7 @@ Small-ish functions needed in many places.
from __future__ import absolute_import, print_function
from binascii import a2b_base64, b2a_base64
import calendar
import codecs
from collections import defaultdict
import datetime
import email
......@@ -499,6 +500,17 @@ def toBytes(value, encoding='ascii'):
"""
return value if isinstance(value, bytes) else value.encode(encoding)
def toUnicodeWritableStream(writable_stream, encoding='ascii'):
"""
Convert writable_stream into a writable stream accepting unicode.
If writable_stream already accepts unicode, returns it.
Otherwise, returns a writable stream accepting unicode, and sending it to
writable_stream encoded with given encoding.
"""
if getattr(writable_stream, 'encoding', None) is not None:
return writable_stream
return codecs.getwriter(encoding)(writable_stream)
def interruptibleSleep(duration): # pragma: no cover
"""
Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt
......
......@@ -20,15 +20,10 @@
# See https://www.nexedi.com/licensing for rationale and options.
from setuptools import setup, find_packages
import glob
import os
import sys
import versioneer
long_description = open("README.rst").read() + "\n"
for f in sorted(glob.glob(os.path.join('caucase', 'README.*.rst'))):
long_description += '\n' + open(f).read() + '\n'
long_description += open("CHANGES.txt").read() + "\n"
with open("README.rst") as readme, open("CHANGES.txt") as changes:
long_description = readme.read() + "\n" + changes.read() + "\n"
setup(
name='caucase',
......@@ -71,5 +66,5 @@ setup(
]
},
test_suite='caucase.test',
use_2to3=sys.version_info >= (3, ),
use_2to3=True,
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment