Commit b22b441b authored by Romain Courteaud

Add timeout parameter

parent 7b9db1df
......@@ -210,10 +210,10 @@ class UrlCheckerDNSTestCase(unittest.TestCase):
# buildResolver
################################################
def test_buildResolver_default(self):
resolver = buildResolver("127.0.0.1")
resolver = buildResolver("127.0.0.1", 4)
assert resolver.nameservers == ["127.0.0.1"]
assert resolver.timeout == 2
assert resolver.lifetime == 2
assert resolver.timeout == 4
assert resolver.lifetime == 4
assert resolver.edns == -1
################################################
......
......@@ -311,13 +311,14 @@ class UrlCheckerHttpTestCase(unittest.TestCase):
"https://example.org/foo?bar=1",
)
assert (
len(mock_request.call_args.kwargs) == 3
len(mock_request.call_args.kwargs) == 4
), mock_request.call_args.kwargs
assert mock_request.call_args.kwargs["headers"] == {
"Host": "example.org"
}
assert mock_request.call_args.kwargs["session"] is not None
assert mock_request.call_args.kwargs["version"] == 2
assert mock_request.call_args.kwargs["timeout"] == 2
assert self.db.HttpCodeChange.select().count() == 1
assert self.db.HttpCodeChange.get().ip == ip
......
......@@ -61,11 +61,12 @@ class WebBot:
def iterateLoop(self):
status_id = logStatus(self._db, "loop")
timeout = int(self.config["TIMEOUT"])
# logPlatform(self._db, __version__, status_id)
# Calculate the resolver list
resolver_ip_list = getReachableResolverList(
self._db, status_id, self.config["NAMESERVER"].split()
self._db, status_id, self.config["NAMESERVER"].split(), timeout
)
if not resolver_ip_list:
return
......@@ -76,7 +77,7 @@ class WebBot:
# Get the list of server to check
# XXX Check DNS expiration
server_ip_dict = getDomainIpDict(
self._db, status_id, resolver_ip_list, domain_list, "A"
self._db, status_id, resolver_ip_list, domain_list, "A", timeout
)
# Check TCP port for the list of IP found
......@@ -86,7 +87,9 @@ class WebBot:
for server_ip in server_ip_list:
# XXX Check SSL certificate expiration
for port, protocol in [(80, "http"), (443, "https")]:
if isTcpPortOpen(self._db, server_ip, port, status_id):
if isTcpPortOpen(
self._db, server_ip, port, status_id, timeout
):
for hostname in server_ip_dict[server_ip]:
url = "%s://%s" % (protocol, hostname)
if url not in url_dict:
......@@ -103,7 +106,9 @@ class WebBot:
# Check HTTP Status
for url in url_dict:
for ip in url_dict[url]:
checkHttpStatus(self._db, status_id, url, ip, __version__)
checkHttpStatus(
self._db, status_id, url, ip, __version__, timeout
)
# XXX Check location header and check new url recursively
# XXX Parse HTML, fetch found link, css, js, image
# XXX Check HTTP Cache
......
......@@ -18,6 +18,7 @@ from urlchecker_bot import create_bot
@click.option("--nameserver", "-n", help="The IP of the DNS server.")
@click.option("--url", "-u", help="The url to check.")
@click.option("--domain", "-d", help="The domain to check.")
@click.option("--timeout", "-t", help="The timeout value.")
@click.option(
"--configuration", "-f", help="The path of the configuration file."
)
......@@ -29,7 +30,9 @@ from urlchecker_bot import create_bot
default="plain",
show_default=True,
)
def runUrlChecker(run, sqlite, nameserver, url, domain, configuration, output):
def runUrlChecker(
run, sqlite, nameserver, url, domain, timeout, configuration, output
):
# click.echo("Running url checker bot")
mapping = {}
......
......@@ -30,6 +30,8 @@ def createConfiguration(
)
if "FORMAT" not in config[CONFIG_SECTION]:
config[CONFIG_SECTION]["FORMAT"] = "json"
if "TIMEOUT" not in config[CONFIG_SECTION]:
config[CONFIG_SECTION]["TIMEOUT"] = "1"
if config[CONFIG_SECTION]["SQLITE"] == ":memory:":
# Do not loop when using temporary DB
......
......@@ -90,20 +90,20 @@ def logDnsQuery(db, status_id, resolver_ip, domain_text, rdtype, answer_list):
return previous_entry.status_id
def buildResolver(resolver_ip):
def buildResolver(resolver_ip, timeout):
resolver = dns.resolver.Resolver(configure=False)
resolver.nameservers.append(resolver_ip)
resolver.timeout = TIMEOUT
resolver.lifetime = TIMEOUT
resolver.timeout = timeout
resolver.lifetime = timeout
resolver.edns = -1
return resolver
def queryDNS(db, status_id, resolver_ip, domain_text, rdtype):
def queryDNS(db, status_id, resolver_ip, domain_text, rdtype, timeout=TIMEOUT):
# only A (and AAAA) has address property
assert rdtype == "A"
resolver = buildResolver(resolver_ip)
resolver = buildResolver(resolver_ip, timeout)
try:
answer_list = [
x.address
......@@ -123,14 +123,16 @@ def queryDNS(db, status_id, resolver_ip, domain_text, rdtype):
return answer_list
def getReachableResolverList(db, status_id, resolver_ip_list):
def getReachableResolverList(db, status_id, resolver_ip_list, timeout=TIMEOUT):
# Create a list of resolver object
result_ip_list = []
# Check the DNS server availability once
# to prevent using it later if it is down
for resolver_ip in resolver_ip_list:
resolver_state = "open"
answer_list = queryDNS(db, status_id, resolver_ip, URL_TO_CHECK, "A")
answer_list = queryDNS(
db, status_id, resolver_ip, URL_TO_CHECK, "A", timeout
)
if len(answer_list) == 0:
# We expect a valid response
......@@ -156,12 +158,14 @@ def expandDomainList(domain_list):
return domain_list
def getDomainIpDict(db, status_id, resolver_ip_list, domain_list, rdtype):
def getDomainIpDict(
db, status_id, resolver_ip_list, domain_list, rdtype, timeout=TIMEOUT
):
server_ip_dict = {}
for domain_text in domain_list:
for resolver_ip in resolver_ip_list:
answer_list = queryDNS(
db, status_id, resolver_ip, domain_text, rdtype
db, status_id, resolver_ip, domain_text, rdtype, timeout
)
for address in answer_list:
if address not in server_ip_dict:
......
......@@ -25,7 +25,7 @@ def getUserAgent(version):
)
def request(url, headers=None, session=requests, version=0):
def request(url, timeout=TIMEOUT, headers=None, session=requests, version=0):
if headers is None:
headers = {}
......@@ -37,7 +37,7 @@ def request(url, headers=None, session=requests, version=0):
kwargs = {}
kwargs["stream"] = False
kwargs["timeout"] = TIMEOUT
kwargs["timeout"] = timeout
kwargs["allow_redirects"] = False
kwargs["verify"] = True
args = ["GET", url]
......@@ -125,10 +125,10 @@ def logHttpStatus(db, ip, url, code, status_id):
return previous_entry.status_id
def checkHttpStatus(db, status_id, url, ip, bot_version):
def checkHttpStatus(db, status_id, url, ip, bot_version, timeout=TIMEOUT):
parsed_url = urlparse(url)
hostname = parsed_url.hostname
request_kw = {}
request_kw = {"timeout": timeout}
# SNI Support
if parsed_url.scheme == "https":
# Provide SNI support
......
......@@ -2,6 +2,7 @@ import socket
import errno
from peewee import fn
TIMEOUT = 2
......@@ -86,10 +87,10 @@ def logNetwork(db, ip, transport, port, state, status_id):
return previous_entry.status_id
def isTcpPortOpen(db, ip, port, status_id):
def isTcpPortOpen(db, ip, port, status_id, timeout=TIMEOUT):
is_open = False
sock = socket.socket()
sock.settimeout(TIMEOUT)
sock.settimeout(timeout)
try:
sock.connect((ip, port))
state = "open"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment