Commit b2c01784 authored by Romain Courteaud's avatar Romain Courteaud

MORE TEST

parent 4e315e81
import unittest import unittest
from urlchecker_db import LogDB from urlchecker_db import LogDB
import peewee import peewee
from urlchecker_dns import expandDomainList, logDnsQuery import urlchecker_dns
from urlchecker_dns import (
expandDomainList,
logDnsQuery,
buildResolver,
queryDNS,
)
from urlchecker_status import logStatus from urlchecker_status import logStatus
import mock
class MockAnswer(object):
def __init__(self, address):
self.address = address
class UrlCheckerDNSTestCase(unittest.TestCase): class UrlCheckerDNSTestCase(unittest.TestCase):
...@@ -192,6 +204,176 @@ class UrlCheckerDNSTestCase(unittest.TestCase): ...@@ -192,6 +204,176 @@ class UrlCheckerDNSTestCase(unittest.TestCase):
assert self.db.DnsChange.select().count() == 8 assert self.db.DnsChange.select().count() == 8
################################################
# buildResolver
################################################
def test_buildResolver_default(self):
resolver = buildResolver("127.0.0.1")
assert resolver.nameservers == ["127.0.0.1"]
assert resolver.timeout == 2
assert resolver.lifetime == 2
assert resolver.edns == -1
################################################
# queryDNS
################################################
def test_queryDNS_default(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
mock_query.return_value = [
MockAnswer("4.3.2.1"),
MockAnswer("1.2.3.4"),
]
result = queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == "1.2.3.4, 4.3.2.1"
assert self.db.DnsChange.get().status_id == status_id
assert result == ["1.2.3.4", "4.3.2.1"]
def test_queryDNS_rejectRdtype(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "B"
status_id = logStatus(self.db, "foo")
try:
queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
except AssertionError:
assert self.db.DnsChange.select().count() == 0
else:
raise NotImplementedError("Expected AssertionError")
def test_queryDNS_nxdomain(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
def sideEffect(*args, **kw):
raise urlchecker_dns.dns.resolver.NXDOMAIN()
mock_query.side_effect = sideEffect
result = queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == ""
assert self.db.DnsChange.get().status_id == status_id
assert result == []
def test_queryDNS_NoAnswer(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
def sideEffect(*args, **kw):
raise urlchecker_dns.dns.resolver.NoAnswer()
mock_query.side_effect = sideEffect
result = queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == ""
assert self.db.DnsChange.get().status_id == status_id
assert result == []
def test_queryDNS_timeout(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
def sideEffect(*args, **kw):
raise urlchecker_dns.dns.exception.Timeout()
mock_query.side_effect = sideEffect
result = queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == ""
assert self.db.DnsChange.get().status_id == status_id
assert result == []
def test_queryDNS_nonameservers(self):
resolver_ip = "127.0.0.1"
domain = "example.org"
rdtype = "A"
status_id = logStatus(self.db, "foo")
with mock.patch(
"urlchecker_dns.dns.resolver.Resolver.query"
) as mock_query:
def sideEffect(*args, **kw):
raise urlchecker_dns.dns.resolver.NoNameservers()
mock_query.side_effect = sideEffect
result = queryDNS(self.db, status_id, resolver_ip, domain, rdtype)
assert mock_query.call_count == 1
mock_query.assert_called_with(
domain, rdtype, raise_on_no_answer=False
)
assert self.db.DnsChange.select().count() == 1
assert self.db.DnsChange.get().resolver_ip == resolver_ip
assert self.db.DnsChange.get().domain == domain
assert self.db.DnsChange.get().rdtype == rdtype
assert self.db.DnsChange.get().response == ""
assert self.db.DnsChange.get().status_id == status_id
assert result == []
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
......
from dns.resolver import get_default_resolver from dns.resolver import get_default_resolver
import dns.resolver import dns
import dns.name
from urlchecker_network import logNetwork from urlchecker_network import logNetwork
URL_TO_CHECK = "example.org" URL_TO_CHECK = "example.org"
...@@ -39,10 +38,20 @@ def logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list): ...@@ -39,10 +38,20 @@ def logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list):
return previous_entry.status_id return previous_entry.status_id
def queryDNS(db, status_id, resolver_ip, resolver, domain_text, rdtype): def buildResolver(resolver_ip):
resolver = dns.resolver.Resolver(configure=False)
resolver.nameservers.append(resolver_ip)
resolver.timeout = TIMEOUT
resolver.lifetime = TIMEOUT
resolver.edns = -1
return resolver
def queryDNS(db, status_id, resolver_ip, domain_text, rdtype):
# only A (and AAAA) has address property # only A (and AAAA) has address property
assert rdtype == "A" assert rdtype == "A"
resolver = buildResolver(resolver_ip)
try: try:
answer_list = [ answer_list = [
x.address x.address
...@@ -59,7 +68,6 @@ def queryDNS(db, status_id, resolver_ip, resolver, domain_text, rdtype): ...@@ -59,7 +68,6 @@ def queryDNS(db, status_id, resolver_ip, resolver, domain_text, rdtype):
answer_list = [] answer_list = []
logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list) logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list)
return answer_list return answer_list
...@@ -70,19 +78,14 @@ def getResolverDict(db, status_id, resolver_ip_list): ...@@ -70,19 +78,14 @@ def getResolverDict(db, status_id, resolver_ip_list):
resolver_ip_list = get_default_resolver().nameservers resolver_ip_list = get_default_resolver().nameservers
resolver_dict = {} resolver_dict = {}
for resolver_ip in resolver_ip_list: for resolver_ip in resolver_ip_list:
resolver = dns.resolver.Resolver(configure=False) resolver_dict[resolver_ip] = buildResolver(resolver_ip)
resolver.nameservers.append(resolver_ip)
resolver.timeout = TIMEOUT
resolver.lifetime = TIMEOUT
resolver.edns = -1
resolver_dict[resolver_ip] = resolver
# Check the DNS server availability once # Check the DNS server availability once
# to prevent using it later if it is down # to prevent using it later if it is down
resolver_tuple_list = [x for x in resolver_dict.items()] resolver_tuple_list = [x for x in resolver_dict.items()]
for ip, resolver in resolver_tuple_list: for ip, resolver in resolver_tuple_list:
resolver_state = "open" resolver_state = "open"
answer_list = queryDNS(db, status_id, ip, resolver, URL_TO_CHECK, "A") answer_list = queryDNS(db, status_id, ip, URL_TO_CHECK, "A")
if len(answer_list) == 0: if len(answer_list) == 0:
# We expect a valid response # We expect a valid response
...@@ -111,7 +114,7 @@ def getServerIpDict(db, status_id, resolver_dict, domain_list, rdtype): ...@@ -111,7 +114,7 @@ def getServerIpDict(db, status_id, resolver_dict, domain_list, rdtype):
for domain_text in domain_list: for domain_text in domain_list:
for resolver_ip, resolver in resolver_dict.items(): for resolver_ip, resolver in resolver_dict.items():
answer_list = queryDNS( answer_list = queryDNS(
db, status_id, resolver_ip, resolver, domain_text, rdtype db, status_id, resolver_ip, domain_text, rdtype
) )
for address in answer_list: for address in answer_list:
if address not in server_ip_dict: if address not in server_ip_dict:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment