Commit 4dc7faf6 authored by Romain Courteaud's avatar Romain Courteaud

More tests

parent 11840882
import unittest
from urlchecker_db import LogDB
import peewee
class UrlCheckerNetworkTestCase(unittest.TestCase):
def setUp(self):
self.db = LogDB(":memory:")
def test_createTable(self):
assert self.db._db.pragma("user_version") == 0
self.db.createTables()
assert self.db._db.pragma("user_version") == 1
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(UrlCheckerDBTestCase))
return suite
if __name__ == "__main__":
unittest.main(defaultTest="suite")
import unittest import unittest
from urlchecker_db import LogDB from urlchecker_db import LogDB
import urlchecker_http import urlchecker_http
from urlchecker_http import getUrlHostname, getUserAgent, request, logHttpStatus, checkHttpStatus from urlchecker_http import (
getUrlHostname,
getUserAgent,
request,
logHttpStatus,
checkHttpStatus,
)
from urlchecker_status import logStatus from urlchecker_status import logStatus
import httpretty import httpretty
import mock import mock
...@@ -25,41 +31,51 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -25,41 +31,51 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
################################################ ################################################
def test_getUserAgent_default(self): def test_getUserAgent_default(self):
result = getUserAgent() result = getUserAgent()
assert result == "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)" assert (
result
== "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)"
)
def test_getUserAgent_default(self): def test_getUserAgent_default(self):
result = getUserAgent(None) result = getUserAgent(None)
assert result == "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)" assert (
result
== "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)"
)
def test_getUserAgent_default(self): def test_getUserAgent_default(self):
result = getUserAgent("0.0.3") result = getUserAgent("0.0.3")
assert result == "URLCHECKER/0.0.3 (+https://lab.nexedi.com/romain/url-checker)" assert (
result
== "URLCHECKER/0.0.3 (+https://lab.nexedi.com/romain/url-checker)"
)
################################################ ################################################
# request # request
################################################ ################################################
def test_request_arguments(self): def test_request_arguments(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
with mock.patch("urlchecker_http.requests.request") as mock_request: with mock.patch("urlchecker_http.requests.request") as mock_request:
response = request( response = request(url_to_proxy)
url_to_proxy
)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
mock_request.assert_called_with('GET', url_to_proxy, allow_redirects=False, mock_request.assert_called_with(
headers={'Accept': 'text/html;q=0.9,*/*;q=0.8', 'User-Agent': 'URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)'}, "GET",
stream=False, timeout=2, verify=True) url_to_proxy,
allow_redirects=False,
headers={
"Accept": "text/html;q=0.9,*/*;q=0.8",
"User-Agent": "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)",
},
stream=False,
timeout=2,
verify=True,
)
@httpretty.activate @httpretty.activate
def test_request_defaultHeaders(self): def test_request_defaultHeaders(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy, status=418)
httpretty.GET, response = request(url_to_proxy)
url_to_proxy,
status=418
)
response = request(
url_to_proxy
)
last_request = httpretty.last_request() last_request = httpretty.last_request()
assert len(last_request.headers) == 5, last_request.headers.keys() assert len(last_request.headers) == 5, last_request.headers.keys()
...@@ -67,24 +83,24 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -67,24 +83,24 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
assert last_request.headers["Accept-Encoding"] == "gzip, deflate" assert last_request.headers["Accept-Encoding"] == "gzip, deflate"
assert last_request.headers["Connection"] == "keep-alive" assert last_request.headers["Connection"] == "keep-alive"
assert last_request.headers["Host"] == "example.org" assert last_request.headers["Host"] == "example.org"
assert last_request.headers["User-Agent"] == "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)" assert (
last_request.headers["User-Agent"]
== "URLCHECKER/0 (+https://lab.nexedi.com/romain/url-checker)"
)
assert len(last_request.body) == 0 assert len(last_request.body) == 0
assert response.status_code == 418 assert response.status_code == 418
@httpretty.activate @httpretty.activate
def test_request_customHeaders(self): def test_request_customHeaders(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy)
httpretty.GET,
url_to_proxy,
)
request( request(
url_to_proxy, url_to_proxy,
headers={ headers={
'foo': 'bar', "foo": "bar",
'User-Agent': 'foouseragent', "User-Agent": "foouseragent",
'Accept': 'fooaccept' "Accept": "fooaccept",
} },
) )
last_request = httpretty.last_request() last_request = httpretty.last_request()
assert len(last_request.headers) == 6, last_request.headers.keys() assert len(last_request.headers) == 6, last_request.headers.keys()
...@@ -97,66 +113,54 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -97,66 +113,54 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
assert len(last_request.body) == 0 assert len(last_request.body) == 0
def test_request_connectionError(self): def test_request_connectionError(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy)
httpretty.GET,
url_to_proxy
)
with mock.patch("urlchecker_http.requests.request") as mock_request: with mock.patch("urlchecker_http.requests.request") as mock_request:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise urlchecker_http.requests.exceptions.ConnectionError() raise urlchecker_http.requests.exceptions.ConnectionError()
mock_request.side_effect = sideEffect mock_request.side_effect = sideEffect
response = request( response = request(url_to_proxy)
url_to_proxy
)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
assert response.status_code == 523, response.status_code assert response.status_code == 523, response.status_code
def test_request_timeout(self): def test_request_timeout(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy)
httpretty.GET,
url_to_proxy
)
with mock.patch("urlchecker_http.requests.request") as mock_request: with mock.patch("urlchecker_http.requests.request") as mock_request:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise urlchecker_http.requests.exceptions.Timeout() raise urlchecker_http.requests.exceptions.Timeout()
mock_request.side_effect = sideEffect mock_request.side_effect = sideEffect
response = request( response = request(url_to_proxy)
url_to_proxy
)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
assert response.status_code == 524, response.status_code assert response.status_code == 524, response.status_code
def test_request_tooManyRedirect(self): def test_request_tooManyRedirect(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy)
httpretty.GET,
url_to_proxy
)
with mock.patch("urlchecker_http.requests.request") as mock_request: with mock.patch("urlchecker_http.requests.request") as mock_request:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise urlchecker_http.requests.exceptions.TooManyRedirects() raise urlchecker_http.requests.exceptions.TooManyRedirects()
mock_request.side_effect = sideEffect mock_request.side_effect = sideEffect
response = request( response = request(url_to_proxy)
url_to_proxy
)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
assert response.status_code == 520, response.status_code assert response.status_code == 520, response.status_code
def test_request_sslError(self): def test_request_sslError(self):
url_to_proxy = 'http://example.org/' url_to_proxy = "http://example.org/"
httpretty.register_uri( httpretty.register_uri(httpretty.GET, url_to_proxy)
httpretty.GET,
url_to_proxy
)
with mock.patch("urlchecker_http.requests.request") as mock_request: with mock.patch("urlchecker_http.requests.request") as mock_request:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise urlchecker_http.requests.exceptions.SSLError() raise urlchecker_http.requests.exceptions.SSLError()
mock_request.side_effect = sideEffect mock_request.side_effect = sideEffect
response = request( response = request(url_to_proxy)
url_to_proxy
)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
assert response.status_code == 526, response.status_code assert response.status_code == 526, response.status_code
...@@ -187,7 +191,7 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -187,7 +191,7 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
assert self.db.HttpCodeChange.select().count() == 1 assert self.db.HttpCodeChange.select().count() == 1
assert self.db.HttpCodeChange.get().status == result assert self.db.HttpCodeChange.get().status == result
else: else:
raise NotImplementedError('Expected IntegrityError') raise NotImplementedError("Expected IntegrityError")
def test_logHttpStatus_skipIdenticalPreviousValues(self): def test_logHttpStatus_skipIdenticalPreviousValues(self):
ip = "127.0.0.1" ip = "127.0.0.1"
...@@ -215,12 +219,42 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -215,12 +219,42 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
result_2 = logHttpStatus(self.db, ip, url, status_code_2, status_id_2) result_2 = logHttpStatus(self.db, ip, url, status_code_2, status_id_2)
assert result_2 != result assert result_2 != result
assert self.db.HttpCodeChange.select().count() == 2 assert self.db.HttpCodeChange.select().count() == 2
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id).ip == ip assert (
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id).url == url self.db.HttpCodeChange.get(
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id).status_code == status_code self.db.HttpCodeChange.status == status_id
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id_2).ip == ip ).ip
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id_2).url == url == ip
assert self.db.HttpCodeChange.get(self.db.HttpCodeChange.status == status_id_2).status_code == status_code_2 )
assert (
self.db.HttpCodeChange.get(
self.db.HttpCodeChange.status == status_id
).url
== url
)
assert (
self.db.HttpCodeChange.get(
self.db.HttpCodeChange.status == status_id
).status_code
== status_code
)
assert (
self.db.HttpCodeChange.get(
self.db.HttpCodeChange.status == status_id_2
).ip
== ip
)
assert (
self.db.HttpCodeChange.get(
self.db.HttpCodeChange.status == status_id_2
).url
== url
)
assert (
self.db.HttpCodeChange.get(
self.db.HttpCodeChange.status == status_id_2
).status_code
== status_code_2
)
def test_logHttpStatus_insertDifferentUrl(self): def test_logHttpStatus_insertDifferentUrl(self):
ip = "127.0.0.1" ip = "127.0.0.1"
...@@ -244,9 +278,7 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -244,9 +278,7 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
url = "http://example.org/foo?bar=1" url = "http://example.org/foo?bar=1"
bot_version = 1 bot_version = 1
httpretty.register_uri( httpretty.register_uri(
httpretty.GET, httpretty.GET, "http://127.0.0.1/foo?bar=1", status=418
"http://127.0.0.1/foo?bar=1",
status=418
) )
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
...@@ -258,7 +290,10 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -258,7 +290,10 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
assert last_request.headers["Accept-Encoding"] == "gzip, deflate" assert last_request.headers["Accept-Encoding"] == "gzip, deflate"
assert last_request.headers["Connection"] == "keep-alive" assert last_request.headers["Connection"] == "keep-alive"
assert last_request.headers["Host"] == "example.org" assert last_request.headers["Host"] == "example.org"
assert last_request.headers["User-Agent"] == "URLCHECKER/1 (+https://lab.nexedi.com/romain/url-checker)" assert (
last_request.headers["User-Agent"]
== "URLCHECKER/1 (+https://lab.nexedi.com/romain/url-checker)"
)
assert len(last_request.body) == 0 assert len(last_request.body) == 0
assert self.db.HttpCodeChange.select().count() == 1 assert self.db.HttpCodeChange.select().count() == 1
...@@ -277,11 +312,17 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -277,11 +312,17 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
checkHttpStatus(self.db, status_id, url, ip, bot_version) checkHttpStatus(self.db, status_id, url, ip, bot_version)
assert mock_request.call_count == 1 assert mock_request.call_count == 1
assert mock_request.call_args.args == ('https://example.org/foo?bar=1',) assert mock_request.call_args.args == (
assert len(mock_request.call_args.kwargs) == 3, mock_request.call_args.kwargs "https://example.org/foo?bar=1",
assert mock_request.call_args.kwargs['headers'] == {'Host': 'example.org'} )
assert mock_request.call_args.kwargs['session'] is not None assert (
assert mock_request.call_args.kwargs['version'] == 2 len(mock_request.call_args.kwargs) == 3
), mock_request.call_args.kwargs
assert mock_request.call_args.kwargs["headers"] == {
"Host": "example.org"
}
assert mock_request.call_args.kwargs["session"] is not None
assert mock_request.call_args.kwargs["version"] == 2
assert self.db.HttpCodeChange.select().count() == 1 assert self.db.HttpCodeChange.select().count() == 1
assert self.db.HttpCodeChange.get().ip == ip assert self.db.HttpCodeChange.get().ip == ip
...@@ -299,9 +340,9 @@ class UrlCheckerHttpTestCase(unittest.TestCase): ...@@ -299,9 +340,9 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
try: try:
checkHttpStatus(self.db, status_id, url, ip, bot_version) checkHttpStatus(self.db, status_id, url, ip, bot_version)
except NotImplementedError as err: except NotImplementedError as err:
assert str(err) == 'Unhandled url: foo?bar=1' assert str(err) == "Unhandled url: foo?bar=1"
else: else:
raise NotImplementedError('Expected NotImplementedError') raise NotImplementedError("Expected NotImplementedError")
def suite(): def suite():
......
...@@ -39,12 +39,12 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -39,12 +39,12 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
result = logNetwork(self.db, ip, transport, port, state, status_id) result = logNetwork(self.db, ip, transport, port, state, status_id)
try: try:
logNetwork(self.db, ip, transport, port, state + '.', status_id) logNetwork(self.db, ip, transport, port, state + ".", status_id)
except peewee.IntegrityError: except peewee.IntegrityError:
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
else: else:
raise NotImplementedError('Expected IntegrityError') raise NotImplementedError("Expected IntegrityError")
def test_logNetwork_skipIdenticalPreviousValues(self): def test_logNetwork_skipIdenticalPreviousValues(self):
ip = "127.0.0.1" ip = "127.0.0.1"
...@@ -74,19 +74,71 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -74,19 +74,71 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
state_2 = state + "." state_2 = state + "."
status_id_2 = logStatus(self.db, "foo") status_id_2 = logStatus(self.db, "foo")
result_2 = logNetwork(self.db, ip, transport, port, state_2, status_id_2) result_2 = logNetwork(
self.db, ip, transport, port, state_2, status_id_2
)
assert result_2 != result assert result_2 != result
assert self.db.NetworkChange.select().count() == 2 assert self.db.NetworkChange.select().count() == 2
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id).ip == ip assert (
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id).port == port self.db.NetworkChange.get(
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id).transport == transport self.db.NetworkChange.status == status_id
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id).state == state ).ip
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id).status_id == status_id == ip
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id_2).ip == ip )
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id_2).port == port assert (
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id_2).transport == transport self.db.NetworkChange.get(
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id_2).state == state_2 self.db.NetworkChange.status == status_id
assert self.db.NetworkChange.get(self.db.NetworkChange.status == status_id_2).status_id == status_id_2 ).port
== port
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id
).transport
== transport
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id
).state
== state
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id
).status_id
== status_id
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id_2
).ip
== ip
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id_2
).port
== port
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id_2
).transport
== transport
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id_2
).state
== state_2
)
assert (
self.db.NetworkChange.get(
self.db.NetworkChange.status == status_id_2
).status_id
== status_id_2
)
def test_logNetwork_insertDifferentKeys(self): def test_logNetwork_insertDifferentKeys(self):
ip = "127.0.0.1" ip = "127.0.0.1"
...@@ -94,7 +146,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -94,7 +146,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
port = 1234 port = 1234
port_2 = port + 1 port_2 = port + 1
transport = "foobar" transport = "foobar"
transport_2 = transport + '.' transport_2 = transport + "."
state = "bar" state = "bar"
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
...@@ -132,8 +184,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -132,8 +184,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == ip assert self.db.NetworkChange.get().ip == ip
assert self.db.NetworkChange.get().port == port assert self.db.NetworkChange.get().port == port
assert self.db.NetworkChange.get().transport == 'TCP' assert self.db.NetworkChange.get().transport == "TCP"
assert self.db.NetworkChange.get().state == 'open' assert self.db.NetworkChange.get().state == "open"
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
assert result == True assert result == True
...@@ -143,8 +195,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -143,8 +195,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise ConnectionRefusedError() raise ConnectionRefusedError()
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
result = isTcpPortOpen(self.db, ip, port, status_id) result = isTcpPortOpen(self.db, ip, port, status_id)
...@@ -161,8 +215,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -161,8 +215,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == ip assert self.db.NetworkChange.get().ip == ip
assert self.db.NetworkChange.get().port == port assert self.db.NetworkChange.get().port == port
assert self.db.NetworkChange.get().transport == 'TCP' assert self.db.NetworkChange.get().transport == "TCP"
assert self.db.NetworkChange.get().state == 'closed' assert self.db.NetworkChange.get().state == "closed"
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
assert result == False assert result == False
...@@ -172,8 +226,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -172,8 +226,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise urlchecker_network.socket.timeout() raise urlchecker_network.socket.timeout()
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
result = isTcpPortOpen(self.db, ip, port, status_id) result = isTcpPortOpen(self.db, ip, port, status_id)
...@@ -190,8 +246,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -190,8 +246,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == ip assert self.db.NetworkChange.get().ip == ip
assert self.db.NetworkChange.get().port == port assert self.db.NetworkChange.get().port == port
assert self.db.NetworkChange.get().transport == 'TCP' assert self.db.NetworkChange.get().transport == "TCP"
assert self.db.NetworkChange.get().state == 'filtered' assert self.db.NetworkChange.get().state == "filtered"
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
assert result == False assert result == False
...@@ -201,8 +257,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -201,8 +257,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise OSError(urlchecker_network.errno.EHOSTUNREACH, 'foo') raise OSError(urlchecker_network.errno.EHOSTUNREACH, "foo")
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
result = isTcpPortOpen(self.db, ip, port, status_id) result = isTcpPortOpen(self.db, ip, port, status_id)
...@@ -219,8 +277,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -219,8 +277,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == ip assert self.db.NetworkChange.get().ip == ip
assert self.db.NetworkChange.get().port == port assert self.db.NetworkChange.get().port == port
assert self.db.NetworkChange.get().transport == 'TCP' assert self.db.NetworkChange.get().transport == "TCP"
assert self.db.NetworkChange.get().state == 'filtered' assert self.db.NetworkChange.get().state == "filtered"
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
assert result == False assert result == False
...@@ -230,8 +288,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -230,8 +288,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise OSError(urlchecker_network.errno.ENETUNREACH, 'foo') raise OSError(urlchecker_network.errno.ENETUNREACH, "foo")
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
result = isTcpPortOpen(self.db, ip, port, status_id) result = isTcpPortOpen(self.db, ip, port, status_id)
...@@ -248,8 +308,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -248,8 +308,8 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert self.db.NetworkChange.select().count() == 1 assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == ip assert self.db.NetworkChange.get().ip == ip
assert self.db.NetworkChange.get().port == port assert self.db.NetworkChange.get().port == port
assert self.db.NetworkChange.get().transport == 'TCP' assert self.db.NetworkChange.get().transport == "TCP"
assert self.db.NetworkChange.get().state == 'unreachable' assert self.db.NetworkChange.get().state == "unreachable"
assert self.db.NetworkChange.get().status_id == status_id assert self.db.NetworkChange.get().status_id == status_id
assert result == False assert result == False
...@@ -259,8 +319,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -259,8 +319,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise OSError() raise OSError()
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
try: try:
...@@ -268,7 +330,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -268,7 +330,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
except OSError: except OSError:
assert self.db.NetworkChange.select().count() == 0 assert self.db.NetworkChange.select().count() == 0
else: else:
raise NotImplementedError('Expected OSError') raise NotImplementedError("Expected OSError")
assert mock_socket.call_count == 1 assert mock_socket.call_count == 1
...@@ -286,8 +348,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -286,8 +348,10 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
status_id = logStatus(self.db, "foo") status_id = logStatus(self.db, "foo")
with mock.patch("urlchecker_network.socket.socket") as mock_socket: with mock.patch("urlchecker_network.socket.socket") as mock_socket:
def sideEffect(*args, **kw): def sideEffect(*args, **kw):
raise Exception() raise Exception()
mock_socket.return_value.connect.side_effect = sideEffect mock_socket.return_value.connect.side_effect = sideEffect
try: try:
...@@ -295,7 +359,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -295,7 +359,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
except Exception: except Exception:
assert self.db.NetworkChange.select().count() == 0 assert self.db.NetworkChange.select().count() == 0
else: else:
raise NotImplementedError('Expected OSError') raise NotImplementedError("Expected OSError")
assert mock_socket.call_count == 1 assert mock_socket.call_count == 1
...@@ -307,6 +371,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase): ...@@ -307,6 +371,7 @@ class UrlCheckerNetworkTestCase(unittest.TestCase):
assert mock_socket.return_value.close.call_count == 1 assert mock_socket.return_value.close.call_count == 1
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(UrlCheckerNetworkTestCase)) suite.addTest(unittest.makeSuite(UrlCheckerNetworkTestCase))
......
...@@ -47,8 +47,11 @@ class LogDB: ...@@ -47,8 +47,11 @@ class LogDB:
transport = peewee.TextField() transport = peewee.TextField()
port = peewee.IntegerField() port = peewee.IntegerField()
state = peewee.TextField() state = peewee.TextField()
class Meta: class Meta:
primary_key = peewee.CompositeKey("status", "ip", "transport", "port") primary_key = peewee.CompositeKey(
"status", "ip", "transport", "port"
)
class DnsChange(BaseModel): class DnsChange(BaseModel):
status = peewee.ForeignKeyField(Status) status = peewee.ForeignKeyField(Status)
...@@ -62,6 +65,7 @@ class LogDB: ...@@ -62,6 +65,7 @@ class LogDB:
ip = peewee.TextField() ip = peewee.TextField()
url = peewee.TextField() url = peewee.TextField()
status_code = peewee.IntegerField() status_code = peewee.IntegerField()
class Meta: class Meta:
primary_key = peewee.CompositeKey("status", "ip", "url") primary_key = peewee.CompositeKey("status", "ip", "url")
......
...@@ -21,12 +21,7 @@ def getUserAgent(version): ...@@ -21,12 +21,7 @@ def getUserAgent(version):
) )
def request( def request(url, headers=None, session=requests, version=0):
url,
headers=None,
session=requests,
version=0
):
if headers is None: if headers is None:
headers = {} headers = {}
...@@ -99,19 +94,16 @@ def checkHttpStatus(db, status_id, url, ip, bot_version): ...@@ -99,19 +94,16 @@ def checkHttpStatus(db, status_id, url, ip, bot_version):
) )
session = requests.Session() session = requests.Session()
session.mount(base_url, ForcedIPHTTPSAdapter(dest_ip=ip)) session.mount(base_url, ForcedIPHTTPSAdapter(dest_ip=ip))
request_kw['session'] = session request_kw["session"] = session
ip_url = url ip_url = url
elif parsed_url.scheme == "http": elif parsed_url.scheme == "http":
# Force IP location # Force IP location
parsed_url = parsed_url._replace(netloc=ip) parsed_url = parsed_url._replace(netloc=ip)
ip_url = parsed_url.geturl() ip_url = parsed_url.geturl()
else: else:
raise NotImplementedError('Unhandled url: %s' % url) raise NotImplementedError("Unhandled url: %s" % url)
response = request( response = request(
ip_url, ip_url, headers={"Host": hostname}, version=bot_version, **request_kw
headers={"Host": hostname},
version=bot_version,
**request_kw
) )
logHttpStatus(db, ip, url, response.status_code, status_id) logHttpStatus(db, ip, url, response.status_code, status_id)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment