Commit 76e8a196 authored by Łukasz Nowak's avatar Łukasz Nowak Committed by Łukasz Nowak

promise/plugin: Implement enhancements to check_url_available

Enhancements:

 * implement ignore-code
 * cover http_code
 * cover timeout
 * implement verify
 * correctly cover ca-cert-file
 * use SLAPOS_TEST_IPV4
parent c7c26219
...@@ -66,6 +66,7 @@ setup(name=name, ...@@ -66,6 +66,7 @@ setup(name=name,
}, },
tests_require = [ tests_require = [
'mock', 'mock',
'cryptography',
], ],
zip_safe=False, # proxy depends on Flask, which has issues with zip_safe=False, # proxy depends on Flask, which has issues with
# accessing templates # accessing templates
......
...@@ -25,9 +25,12 @@ class RunPromise(GenericPromise): ...@@ -25,9 +25,12 @@ class RunPromise(GenericPromise):
ca_cert_file = self.getConfig('ca-cert-file') ca_cert_file = self.getConfig('ca-cert-file')
cert_file = self.getConfig('cert-file') cert_file = self.getConfig('cert-file')
key_file = self.getConfig('key-file') key_file = self.getConfig('key-file')
verify = int(self.getConfig('verify', 0))
if ca_cert_file: if ca_cert_file:
verify = ca_cert_file verify = ca_cert_file
elif verify:
verify = True
else: else:
verify = False verify = False
...@@ -39,6 +42,14 @@ class RunPromise(GenericPromise): ...@@ -39,6 +42,14 @@ class RunPromise(GenericPromise):
try: try:
result = requests.get( result = requests.get(
url, verify=verify, allow_redirects=True, timeout=timeout, cert=cert) url, verify=verify, allow_redirects=True, timeout=timeout, cert=cert)
except requests.exceptions.SSLError as e:
if 'certificate verify failed' in str(e.message):
self.logger.error(
"ERROR SSL verify failed while accessing %r" % (url,))
else:
self.logger.error(
"ERROR Unknown SSL error %r while accessing %r" % (e, url))
return
except requests.ConnectionError as e: except requests.ConnectionError as e:
self.logger.error( self.logger.error(
"ERROR connection not possible while accessing %r" % (url, )) "ERROR connection not possible while accessing %r" % (url, ))
...@@ -49,13 +60,11 @@ class RunPromise(GenericPromise): ...@@ -49,13 +60,11 @@ class RunPromise(GenericPromise):
http_code = result.status_code http_code = result.status_code
check_secure = int(self.getConfig('check-secure', 0)) check_secure = int(self.getConfig('check-secure', 0))
ignore_code = int(self.getConfig('ignore-code', 0))
if http_code == 0: if http_code == 401 and check_secure == 1:
self.logger.error("%s is not available (server not reachable)." % url)
elif http_code == 401 and check_secure == 1:
self.logger.info("%r is protected (returned %s)." % (url, http_code)) self.logger.info("%r is protected (returned %s)." % (url, http_code))
elif not ignore_code and http_code != expected_http_code:
elif http_code != expected_http_code:
self.logger.error("%r is not available (returned %s, expected %s)." % ( self.logger.error("%r is not available (returned %s, expected %s)." % (
url, http_code, expected_http_code)) url, http_code, expected_http_code))
else: else:
......
...@@ -28,22 +28,118 @@ ...@@ -28,22 +28,118 @@
from slapos.grid.promise import PromiseError from slapos.grid.promise import PromiseError
from slapos.test.promise.plugin import TestPromisePluginMixin from slapos.test.promise.plugin import TestPromisePluginMixin
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
import BaseHTTPServer import BaseHTTPServer
import datetime
import ipaddress
import json import json
import multiprocessing import multiprocessing
import os
import ssl
import tempfile
import time import time
import unittest import unittest
SLAPOS_TEST_IPV4 = '127.0.0.1' SLAPOS_TEST_IPV4 = os.environ.get('SLAPOS_TEST_IPV4', '127.0.0.1')
SLAPOS_TEST_IPV4_PORT = 57965 SLAPOS_TEST_IPV4_PORT = 57965
HTTPS_ENDPOINT = "http://%s:%s/" % (SLAPOS_TEST_IPV4, SLAPOS_TEST_IPV4_PORT) HTTPS_ENDPOINT = "https://%s:%s/" % (SLAPOS_TEST_IPV4, SLAPOS_TEST_IPV4_PORT)
def createKey():
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend())
key_pem = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
return key, key_pem
def createCSR(common_name, ip=None):
key, key_pem = createKey()
subject_alternative_name_list = []
if ip is not None:
subject_alternative_name_list.append(
x509.IPAddress(ipaddress.ip_address(unicode(ip)))
)
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, unicode(common_name)),
]))
if len(subject_alternative_name_list):
csr = csr.add_extension(
x509.SubjectAlternativeName(subject_alternative_name_list),
critical=False
)
csr = csr.sign(key, hashes.SHA256(), default_backend())
csr_pem = csr.public_bytes(serialization.Encoding.PEM)
return key, key_pem, csr, csr_pem
class CertificateAuthority(object):
def __init__(self, common_name):
self.key, self.key_pem = createKey()
public_key = self.key.public_key()
builder = x509.CertificateBuilder()
builder = builder.subject_name(x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, unicode(common_name)),
]))
builder = builder.issuer_name(x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, unicode(common_name)),
]))
builder = builder.not_valid_before(
datetime.datetime.utcnow() - datetime.timedelta(days=2))
builder = builder.not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=30))
builder = builder.serial_number(x509.random_serial_number())
builder = builder.public_key(public_key)
builder = builder.add_extension(
x509.BasicConstraints(ca=True, path_length=None), critical=True,
)
self.certificate = builder.sign(
private_key=self.key, algorithm=hashes.SHA256(),
backend=default_backend()
)
self.certificate_pem = self.certificate.public_bytes(
serialization.Encoding.PEM)
def signCSR(self, csr):
builder = x509.CertificateBuilder(
subject_name=csr.subject,
extensions=csr.extensions,
issuer_name=self.certificate.subject,
not_valid_before=datetime.datetime.utcnow() - datetime.timedelta(days=1),
not_valid_after=datetime.datetime.utcnow() + datetime.timedelta(days=30),
serial_number=x509.random_serial_number(),
public_key=csr.public_key(),
)
certificate = builder.sign(
private_key=self.key,
algorithm=hashes.SHA256(),
backend=default_backend()
)
return certificate, certificate.public_bytes(serialization.Encoding.PEM)
class TestHandler(BaseHTTPServer.BaseHTTPRequestHandler): class TestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
timeout = int(self.headers.dict.get('timeout', '0')) path = self.path.split('/')[-1]
if '_' in path:
response, timeout = path.split('_')
response = int(response)
timeout = int(timeout)
else:
timeout = 0
response = int(path)
time.sleep(timeout) time.sleep(timeout)
response = int(self.path.split('/')[-1])
self.send_response(response) self.send_response(response)
self.send_header("Content-type", "application/json") self.send_header("Content-type", "application/json")
...@@ -54,14 +150,40 @@ class TestHandler(BaseHTTPServer.BaseHTTPRequestHandler): ...@@ -54,14 +150,40 @@ class TestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
self.wfile.write(json.dumps(response, indent=2)) self.wfile.write(json.dumps(response, indent=2))
class TestCheckUrlAvailable(TestPromisePluginMixin): class CheckUrlAvailableMixin(TestPromisePluginMixin):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.another_server_ca = CertificateAuthority("Another Server Root CA")
cls.test_server_ca = CertificateAuthority("Test Server Root CA")
key, key_pem, csr, csr_pem = createCSR(
"testserver.example.com", SLAPOS_TEST_IPV4)
_, cls.test_server_certificate_pem = cls.test_server_ca.signCSR(csr)
cls.test_server_certificate_file = tempfile.NamedTemporaryFile(
delete=False
)
cls.test_server_certificate_file.write(
cls.test_server_certificate_pem + key_pem
)
cls.test_server_certificate_file.close()
cls.test_server_ca_certificate_file = tempfile.NamedTemporaryFile(
delete=False
)
cls.test_server_ca_certificate_file.write(
cls.test_server_ca.certificate_pem)
cls.test_server_ca_certificate_file.close()
server = BaseHTTPServer.HTTPServer( server = BaseHTTPServer.HTTPServer(
(SLAPOS_TEST_IPV4, SLAPOS_TEST_IPV4_PORT), (SLAPOS_TEST_IPV4, SLAPOS_TEST_IPV4_PORT),
TestHandler) TestHandler)
server.socket = ssl.wrap_socket(
server.socket,
certfile=cls.test_server_certificate_file.name,
server_side=True)
cls.server_process = multiprocessing.Process( cls.server_process = multiprocessing.Process(
target=server.serve_forever) target=server.serve_forever)
cls.server_process.start() cls.server_process.start()
...@@ -70,6 +192,14 @@ class TestCheckUrlAvailable(TestPromisePluginMixin): ...@@ -70,6 +192,14 @@ class TestCheckUrlAvailable(TestPromisePluginMixin):
def tearDownClass(cls): def tearDownClass(cls):
cls.server_process.terminate() cls.server_process.terminate()
cls.server_process.join() cls.server_process.join()
for p in [
cls.test_server_certificate_file.name,
cls.test_server_ca_certificate_file.name,
]:
try:
os.unlink(p)
except Exception:
pass
def setUp(self): def setUp(self):
TestPromisePluginMixin.setUp(self) TestPromisePluginMixin.setUp(self)
...@@ -81,17 +211,55 @@ extra_config_dict = { ...@@ -81,17 +211,55 @@ extra_config_dict = {
'url': '%(url)s', 'url': '%(url)s',
'timeout': %(timeout)s, 'timeout': %(timeout)s,
'check-secure': %(check_secure)s, 'check-secure': %(check_secure)s,
'ignore-code': %(ignore_code)s,
}
"""
self.base_content_verify = """from slapos.promise.plugin.check_url_available import RunPromise
extra_config_dict = {
'url': '%(url)s',
'timeout': %(timeout)s,
'check-secure': %(check_secure)s,
'ignore-code': %(ignore_code)s,
'verify': %(verify)s,
}
"""
self.base_content_ca_cert = """from slapos.promise.plugin.check_url_available import RunPromise
extra_config_dict = {
'url': '%(url)s',
'timeout': %(timeout)s,
'check-secure': %(check_secure)s,
'ignore-code': %(ignore_code)s,
'ca-cert-file': %(ca_cert_file)r,
}
"""
self.base_content_http_code = """from slapos.promise.plugin.check_url_available import RunPromise
extra_config_dict = {
'url': '%(url)s',
'timeout': %(timeout)s,
'check-secure': %(check_secure)s,
'ignore-code': %(ignore_code)s,
'http_code': %(http_code)s
} }
""" """
def tearDown(self): def tearDown(self):
TestPromisePluginMixin.tearDown(self) TestPromisePluginMixin.tearDown(self)
class TestCheckUrlAvailable(CheckUrlAvailableMixin):
def test_check_url_bad(self): def test_check_url_bad(self):
content = self.base_content % { content = self.base_content % {
'url': 'https://', 'url': 'https://',
'timeout': 10, 'timeout': 10,
'check_secure': 0 'check_secure': 0,
'ignore_code': 0,
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -108,7 +276,8 @@ extra_config_dict = { ...@@ -108,7 +276,8 @@ extra_config_dict = {
content = self.base_content % { content = self.base_content % {
'url': '', 'url': '',
'timeout': 10, 'timeout': 10,
'check_secure': 0 'check_secure': 0,
'ignore_code': 0,
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -122,10 +291,11 @@ extra_config_dict = { ...@@ -122,10 +291,11 @@ extra_config_dict = {
) )
def test_check_url_site_off(self): def test_check_url_site_off(self):
content = content = self.base_content % { content = self.base_content % {
'url': 'https://localhost:56789/site', 'url': 'https://localhost:56789/site',
'timeout': 10, 'timeout': 10,
'check_secure': 0 'check_secure': 0,
'ignore_code': 0,
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -141,10 +311,80 @@ extra_config_dict = { ...@@ -141,10 +311,80 @@ extra_config_dict = {
def test_check_200(self): def test_check_200(self):
url = HTTPS_ENDPOINT + '200' url = HTTPS_ENDPOINT + '200'
content = content = self.base_content % { content = self.base_content % {
'url': url, 'url': url,
'timeout': 10, 'timeout': 10,
'check_secure': 0 'check_secure': 0,
'ignore_code': 0,
}
self.writePromise(self.promise_name, content)
self.configureLauncher()
self.launcher.run()
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], False)
self.assertEqual(
result['result']['message'],
"%r is available" % (url,)
)
def test_check_200_verify(self):
url = HTTPS_ENDPOINT + '200'
content = self.base_content_verify % {
'url': url,
'timeout': 10,
'check_secure': 0,
'ignore_code': 0,
'verify': 1,
}
try:
old = os.environ.get('REQUESTS_CA_BUNDLE')
# simulate system provided CA bundle
os.environ[
'REQUESTS_CA_BUNDLE'] = self.test_server_ca_certificate_file.name
self.writePromise(self.promise_name, content)
self.configureLauncher()
self.launcher.run()
finally:
if old is None:
del os.environ['REQUESTS_CA_BUNDLE']
else:
os.environ['REQUESTS_CA_BUNDLE'] = old
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], False)
self.assertEqual(
result['result']['message'],
"%r is available" % (url,)
)
def test_check_200_verify_fail(self):
url = HTTPS_ENDPOINT + '200'
content = self.base_content_verify % {
'url': url,
'timeout': 10,
'check_secure': 0,
'ignore_code': 0,
'verify': 1,
}
self.writePromise(self.promise_name, content)
self.configureLauncher()
with self.assertRaises(PromiseError):
self.launcher.run()
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], True)
self.assertEqual(
result['result']['message'],
"ERROR SSL verify failed while accessing %r" % (url,)
)
def test_check_200_verify_own(self):
url = HTTPS_ENDPOINT + '200'
content = self.base_content_ca_cert % {
'url': url,
'timeout': 10,
'check_secure': 0,
'ignore_code': 0,
'ca_cert_file': self.test_server_ca_certificate_file.name
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -158,10 +398,11 @@ extra_config_dict = { ...@@ -158,10 +398,11 @@ extra_config_dict = {
def test_check_401(self): def test_check_401(self):
url = HTTPS_ENDPOINT + '401' url = HTTPS_ENDPOINT + '401'
content = content = self.base_content % { content = self.base_content % {
'url': url, 'url': url,
'timeout': 10, 'timeout': 10,
'check_secure': 0 'check_secure': 0,
'ignore_code': 0,
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -174,12 +415,31 @@ extra_config_dict = { ...@@ -174,12 +415,31 @@ extra_config_dict = {
"%r is not available (returned 401, expected 200)." % (url,) "%r is not available (returned 401, expected 200)." % (url,)
) )
def test_check_401_secure(self): def test_check_401_ignore_code(self):
url = HTTPS_ENDPOINT + '401'
content = self.base_content % {
'url': url,
'timeout': 10,
'check_secure': 0,
'ignore_code': 1,
}
self.writePromise(self.promise_name, content)
self.configureLauncher()
self.launcher.run()
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], False)
self.assertEqual(
result['result']['message'],
"%r is available" % (url,)
)
def test_check_401_check_secure(self):
url = HTTPS_ENDPOINT + '401' url = HTTPS_ENDPOINT + '401'
content = content = self.base_content % { content = self.base_content % {
'url': url, 'url': url,
'timeout': 10, 'timeout': 10,
'check_secure': 1 'check_secure': 1,
'ignore_code': 0,
} }
self.writePromise(self.promise_name, content) self.writePromise(self.promise_name, content)
self.configureLauncher() self.configureLauncher()
...@@ -191,6 +451,46 @@ extra_config_dict = { ...@@ -191,6 +451,46 @@ extra_config_dict = {
"%r is protected (returned 401)." % (url,) "%r is protected (returned 401)." % (url,)
) )
def test_check_512_http_code(self):
url = HTTPS_ENDPOINT + '512'
content = self.base_content_http_code % {
'url': url,
'timeout': 10,
'check_secure': 0,
'ignore_code': 0,
'http_code': 512,
}
self.writePromise(self.promise_name, content)
self.configureLauncher()
self.launcher.run()
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], False)
self.assertEqual(
result['result']['message'],
"%r is available" % (url,)
)
class TestCheckUrlAvailableTimeout(CheckUrlAvailableMixin):
def test_check_200_timeout(self):
url = HTTPS_ENDPOINT + '200_5'
content = self.base_content % {
'url': url,
'timeout': 1,
'check_secure': 0,
'ignore_code': 0,
}
self.writePromise(self.promise_name, content)
self.configureLauncher()
with self.assertRaises(PromiseError):
self.launcher.run()
result = self.getPromiseResult(self.promise_name)
self.assertEqual(result['result']['failed'], True)
self.assertEqual(
result['result']['message'],
"Error: Promise timed out after 0.5 seconds",
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment