Commit da87eced authored by Romain Courteaud's avatar Romain Courteaud

MORE TESTS

parent b2c01784
...@@ -7,6 +7,7 @@ from urlchecker_dns import ( ...@@ -7,6 +7,7 @@ from urlchecker_dns import (
logDnsQuery, logDnsQuery,
buildResolver, buildResolver,
queryDNS, queryDNS,
getReachableResolverList,
) )
from urlchecker_status import logStatus from urlchecker_status import logStatus
import mock import mock
...@@ -374,6 +375,130 @@ class UrlCheckerDNSTestCase(unittest.TestCase): ...@@ -374,6 +375,130 @@ class UrlCheckerDNSTestCase(unittest.TestCase):
assert self.db.DnsChange.get().status_id == status_id assert self.db.DnsChange.get().status_id == status_id
assert result == [] assert result == []
################################################
# getReachableResolverList
################################################
def test_getReachableResolverList_open(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
mock_query.return_value = [
MockAnswer("4.3.2.1"),
MockAnswer("1.2.3.4"),
]
result = getReachableResolverList(
self.db, status_id, [resolver_ip]
)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == "1.2.3.4, 4.3.2.1"
assert self.db.DnsChange.get().status_id == status_id
assert result == ["127.0.0.1"]
assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == resolver_ip
assert self.db.NetworkChange.get().port == 53
assert self.db.NetworkChange.get().transport == "UDP"
assert self.db.NetworkChange.get().state == "open"
assert self.db.NetworkChange.get().status_id == status_id
def test_getReachableResolverList_closed(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
def sideEffect(*args, **kw):
raise urlchecker_dns.dns.exception.Timeout()
mock_query.side_effect = sideEffect
result = getReachableResolverList(
self.db, status_id, [resolver_ip]
)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == ""
assert self.db.DnsChange.get().status_id == status_id
assert result == []
assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == resolver_ip
assert self.db.NetworkChange.get().port == 53
assert self.db.NetworkChange.get().transport == "UDP"
assert self.db.NetworkChange.get().state == "closed"
assert self.db.NetworkChange.get().status_id == status_id
def test_getReachableResolverList_noiplist(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
resolver = urlchecker_dns.dns.resolver.Resolver(configure=False)
resolver.nameservers.append(resolver_ip)
with mock.patch(
"urlchecker_dns.get_default_resolver"
) as mock_get_default_resolver:
mock_get_default_resolver.return_value = resolver
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
mock_query.return_value = [
MockAnswer("4.3.2.1"),
MockAnswer("1.2.3.4"),
]
result = getReachableResolverList(self.db, status_id, [])
assert mock_get_default_resolver.call_count == 1
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == "1.2.3.4, 4.3.2.1"
assert self.db.DnsChange.get().status_id == status_id
assert result == ["127.0.0.1"]
assert self.db.NetworkChange.select().count() == 1
assert self.db.NetworkChange.get().ip == resolver_ip
assert self.db.NetworkChange.get().port == 53
assert self.db.NetworkChange.get().transport == "UDP"
assert self.db.NetworkChange.get().state == "open"
assert self.db.NetworkChange.get().status_id == status_id
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
......
...@@ -3,7 +3,11 @@ from urlchecker_db import LogDB ...@@ -3,7 +3,11 @@ from urlchecker_db import LogDB
from urlchecker_configuration import createConfiguration, logConfiguration from urlchecker_configuration import createConfiguration, logConfiguration
from urlchecker_platform import logPlatform from urlchecker_platform import logPlatform
from urlchecker_status import logStatus from urlchecker_status import logStatus
from urlchecker_dns import getResolverDict, expandDomainList, getServerIpDict from urlchecker_dns import (
getReachableResolverList,
expandDomainList,
getServerIpDict,
)
from urlchecker_http import getUrlHostname, checkHttpStatus from urlchecker_http import getUrlHostname, checkHttpStatus
from urlchecker_network import isTcpPortOpen from urlchecker_network import isTcpPortOpen
...@@ -28,10 +32,10 @@ class WebBot: ...@@ -28,10 +32,10 @@ class WebBot:
logPlatform(self._db, __version__, status_id) logPlatform(self._db, __version__, status_id)
# Calculate the resolver list # Calculate the resolver list
resolver_dict = getResolverDict( resolver_ip_list = getReachableResolverList(
self._db, status_id, self.config["DNS"].split() self._db, status_id, self.config["DNS"].split()
) )
if not resolver_dict: if not resolver_ip_list:
return return
# Calculate the full list of domain to check # Calculate the full list of domain to check
...@@ -49,7 +53,7 @@ class WebBot: ...@@ -49,7 +53,7 @@ class WebBot:
# Get the list of server to check # Get the list of server to check
# XXX Check DNS expiration # XXX Check DNS expiration
server_ip_dict = getServerIpDict( server_ip_dict = getServerIpDict(
self._db, status_id, resolver_dict, domain_list, "A" self._db, status_id, resolver_ip_list, domain_list, "A"
) )
# Check TCP port for the list of IP found # Check TCP port for the list of IP found
......
...@@ -71,30 +71,28 @@ def queryDNS(db, status_id, resolver_ip, domain_text, rdtype): ...@@ -71,30 +71,28 @@ def queryDNS(db, status_id, resolver_ip, domain_text, rdtype):
return answer_list return answer_list
def getResolverDict(db, status_id, resolver_ip_list): def getReachableResolverList(db, status_id, resolver_ip_list):
# Create a list of resolver object # Create a list of resolver object
if len(resolver_ip_list) == 0: if len(resolver_ip_list) == 0:
resolver_ip_list = get_default_resolver().nameservers resolver_ip_list = get_default_resolver().nameservers
resolver_dict = {}
for resolver_ip in resolver_ip_list:
resolver_dict[resolver_ip] = buildResolver(resolver_ip)
result_ip_list = []
# Check the DNS server availability once # Check the DNS server availability once
# to prevent using it later if it is down # to prevent using it later if it is down
resolver_tuple_list = [x for x in resolver_dict.items()] for resolver_ip in resolver_ip_list:
for ip, resolver in resolver_tuple_list:
resolver_state = "open" resolver_state = "open"
answer_list = queryDNS(db, status_id, ip, URL_TO_CHECK, "A") answer_list = queryDNS(db, status_id, resolver_ip, URL_TO_CHECK, "A")
if len(answer_list) == 0: if len(answer_list) == 0:
# We expect a valid response # We expect a valid response
# Drop the DNS server... # Drop the DNS server...
resolver_dict.pop(ip)
resolver_state = "closed" resolver_state = "closed"
logNetwork(db, ip, "UDP", 53, resolver_state, status_id) else:
resolver_state = "open"
result_ip_list.append(resolver_ip)
logNetwork(db, resolver_ip, "UDP", 53, resolver_state, status_id)
return resolver_dict return result_ip_list
def expandDomainList(domain_list): def expandDomainList(domain_list):
...@@ -109,10 +107,10 @@ def expandDomainList(domain_list): ...@@ -109,10 +107,10 @@ def expandDomainList(domain_list):
return domain_list return domain_list
def getServerIpDict(db, status_id, resolver_dict, domain_list, rdtype): def getServerIpDict(db, status_id, resolver_ip_list, domain_list, rdtype):
server_ip_dict = {} server_ip_dict = {}
for domain_text in domain_list: for domain_text in domain_list:
for resolver_ip, resolver in resolver_dict.items(): for resolver_ip in resolver_ip_list:
answer_list = queryDNS( answer_list = queryDNS(
db, status_id, resolver_ip, domain_text, rdtype db, status_id, resolver_ip, domain_text, rdtype
) )
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment