Commit 2eaeb1f3 authored by Romain Courteaud's avatar Romain Courteaud

db: add indexes to speed up insertion and report

parent 0d576fb6
...@@ -81,6 +81,7 @@ class LogDB: ...@@ -81,6 +81,7 @@ class LogDB:
primary_key = peewee.CompositeKey( primary_key = peewee.CompositeKey(
"status", "ip", "transport", "port" "status", "ip", "transport", "port"
) )
indexes = ((("ip", "transport", "port", "status"), True),)
class DnsChange(BaseModel): class DnsChange(BaseModel):
status = peewee.ForeignKeyField(Status) status = peewee.ForeignKeyField(Status)
...@@ -93,6 +94,9 @@ class LogDB: ...@@ -93,6 +94,9 @@ class LogDB:
primary_key = peewee.CompositeKey( primary_key = peewee.CompositeKey(
"status", "resolver_ip", "domain", "rdtype" "status", "resolver_ip", "domain", "rdtype"
) )
indexes = (
(("resolver_ip", "domain", "rdtype", "status"), True),
)
class SslChange(BaseModel): class SslChange(BaseModel):
status = peewee.ForeignKeyField(Status) status = peewee.ForeignKeyField(Status)
...@@ -109,6 +113,7 @@ class LogDB: ...@@ -109,6 +113,7 @@ class LogDB:
primary_key = peewee.CompositeKey( primary_key = peewee.CompositeKey(
"status", "ip", "port", "hostname" "status", "ip", "port", "hostname"
) )
indexes = ((("ip", "port", "hostname", "status"), True),)
class HttpCodeChange(BaseModel): class HttpCodeChange(BaseModel):
status = peewee.ForeignKeyField(Status) status = peewee.ForeignKeyField(Status)
...@@ -120,6 +125,7 @@ class LogDB: ...@@ -120,6 +125,7 @@ class LogDB:
class Meta: class Meta:
primary_key = peewee.CompositeKey("status", "ip", "url") primary_key = peewee.CompositeKey("status", "ip", "url")
indexes = ((("ip", "url", "status"), True),)
self.Status = Status self.Status = Status
self.ConfigurationChange = ConfigurationChange self.ConfigurationChange = ConfigurationChange
...@@ -132,7 +138,7 @@ class LogDB: ...@@ -132,7 +138,7 @@ class LogDB:
def createTables(self): def createTables(self):
# http://www.sqlite.org/pragma.html#pragma_user_version # http://www.sqlite.org/pragma.html#pragma_user_version
db_version = self._db.pragma("user_version") db_version = self._db.pragma("user_version")
expected_version = 4 expected_version = 5
if db_version != expected_version: if db_version != expected_version:
with self._db.transaction(): with self._db.transaction():
...@@ -176,6 +182,42 @@ class LogDB: ...@@ -176,6 +182,42 @@ class LogDB:
) )
) )
if (0 < db_version) and (db_version <= 4):
# version 4 without the index to speed up reporting
migration_list.extend(
[
migrator.add_index(
"NetworkChange",
("ip", "transport", "port", "status_id"),
True,
),
migrator.add_index(
"DnsChange",
(
"resolver_ip",
"domain",
"rdtype",
"status_id",
),
True,
),
migrator.add_index(
"HttpCodeChange",
("ip", "url", "status_id"),
True,
),
]
)
if (1 < db_version) and (db_version <= 4):
# version 4 without the index to speed up reporting
migration_list.append(
migrator.add_index(
"SslChange",
("ip", "port", "hostname", "status_id"),
True,
)
)
if migration_list: if migration_list:
migrate(*migration_list) migrate(*migration_list)
......
...@@ -25,7 +25,15 @@ from playhouse.reflection import Introspector ...@@ -25,7 +25,15 @@ from playhouse.reflection import Introspector
ValidationResult = namedtuple( ValidationResult = namedtuple(
"ValidationResult", "ValidationResult",
("valid", "table_exists", "add_fields", "remove_fields", "change_fields"), (
"valid",
"table_exists",
"add_fields",
"remove_fields",
"change_fields",
"add_indexes",
"remove_indexes",
),
) )
...@@ -55,8 +63,29 @@ def validate_schema(model): ...@@ -55,8 +63,29 @@ def validate_schema(model):
): ):
to_change.append((field, db_field)) to_change.append((field, db_field))
is_valid = not any((to_remove, to_add, to_change)) indexes = set([(tuple(x), y) for x, y in model._meta.indexes])
return ValidationResult(is_valid, True, to_add, to_remove, to_change) db_indexes = set(
[
(tuple(x), y)
for x, y in db_model._meta.indexes
if tuple(x) != model._meta.primary_key.field_names
]
)
index_to_remove = [c for c in indexes - db_indexes]
index_to_add = [c for c in db_indexes - indexes]
is_valid = not any(
(to_remove, to_add, to_change, index_to_remove, index_to_add)
)
return ValidationResult(
is_valid,
True,
to_add,
to_remove,
to_change,
index_to_remove,
index_to_add,
)
class SurykatkaDBTestCase(unittest.TestCase): class SurykatkaDBTestCase(unittest.TestCase):
...@@ -66,7 +95,19 @@ class SurykatkaDBTestCase(unittest.TestCase): ...@@ -66,7 +95,19 @@ class SurykatkaDBTestCase(unittest.TestCase):
def test_createTable(self): def test_createTable(self):
assert self.db._db.pragma("user_version") == 0 assert self.db._db.pragma("user_version") == 0
self.db.createTables() self.db.createTables()
assert self.db._db.pragma("user_version") == 4 assert self.db._db.pragma("user_version") == 5
assert validate_schema(self.db.HttpCodeChange).valid, validate_schema(
self.db.HttpCodeChange
)
assert validate_schema(self.db.SslChange).valid, validate_schema(
self.db.SslChange
)
assert validate_schema(self.db.DnsChange).valid, validate_schema(
self.db.DnsChange
)
assert validate_schema(self.db.NetworkChange).valid, validate_schema(
self.db.NetworkChange
)
def test_downgrade(self): def test_downgrade(self):
assert self.db._db.pragma("user_version") == 0 assert self.db._db.pragma("user_version") == 0
...@@ -104,13 +145,39 @@ class SurykatkaDBTestCase(unittest.TestCase): ...@@ -104,13 +145,39 @@ class SurykatkaDBTestCase(unittest.TestCase):
"HttpCodeChange", "http_header_dict" "HttpCodeChange", "http_header_dict"
), ),
) )
migrate(
SqliteMigrator(self.db._db).drop_index(
"HttpCodeChange", "HttpCodeChange_ip_url_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"DnsChange",
"DnsChange_resolver_ip_domain_rdtype_status_id",
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"NetworkChange",
"NetworkChange_ip_transport_port_status_id",
),
)
self.db._db.pragma("user_version", 1) self.db._db.pragma("user_version", 1)
self.db.createTables() self.db.createTables()
assert self.db._db.pragma("user_version") == 4 assert self.db._db.pragma("user_version") == 5
assert validate_schema(self.db.HttpCodeChange).valid, validate_schema(
self.db.HttpCodeChange
)
assert validate_schema(self.db.SslChange).valid, validate_schema( assert validate_schema(self.db.SslChange).valid, validate_schema(
self.db.SslChange self.db.SslChange
) )
assert validate_schema(self.db.DnsChange).valid, validate_schema(
self.db.DnsChange
)
assert validate_schema(self.db.NetworkChange).valid, validate_schema(
self.db.NetworkChange
)
def test_migrationFromVersion2(self): def test_migrationFromVersion2(self):
assert self.db._db.pragma("user_version") == 0 assert self.db._db.pragma("user_version") == 0
...@@ -138,13 +205,44 @@ class SurykatkaDBTestCase(unittest.TestCase): ...@@ -138,13 +205,44 @@ class SurykatkaDBTestCase(unittest.TestCase):
"HttpCodeChange", "http_header_dict" "HttpCodeChange", "http_header_dict"
), ),
) )
migrate(
SqliteMigrator(self.db._db).drop_index(
"HttpCodeChange", "HttpCodeChange_ip_url_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"SslChange", "SslChange_ip_port_hostname_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"DnsChange",
"DnsChange_resolver_ip_domain_rdtype_status_id",
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"NetworkChange",
"NetworkChange_ip_transport_port_status_id",
),
)
self.db._db.pragma("user_version", 2) self.db._db.pragma("user_version", 2)
self.db.createTables() self.db.createTables()
assert self.db._db.pragma("user_version") == 4 assert self.db._db.pragma("user_version") == 5
assert validate_schema(self.db.HttpCodeChange).valid, validate_schema( assert validate_schema(self.db.HttpCodeChange).valid, validate_schema(
self.db.HttpCodeChange self.db.HttpCodeChange
) )
assert validate_schema(self.db.SslChange).valid, validate_schema(
self.db.SslChange
)
assert validate_schema(self.db.DnsChange).valid, validate_schema(
self.db.DnsChange
)
assert validate_schema(self.db.NetworkChange).valid, validate_schema(
self.db.NetworkChange
)
def test_migrationFromVersion3(self): def test_migrationFromVersion3(self):
assert self.db._db.pragma("user_version") == 0 assert self.db._db.pragma("user_version") == 0
...@@ -167,13 +265,99 @@ class SurykatkaDBTestCase(unittest.TestCase): ...@@ -167,13 +265,99 @@ class SurykatkaDBTestCase(unittest.TestCase):
"HttpCodeChange", "http_header_dict" "HttpCodeChange", "http_header_dict"
), ),
) )
migrate(
SqliteMigrator(self.db._db).drop_index(
"HttpCodeChange", "HttpCodeChange_ip_url_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"SslChange", "SslChange_ip_port_hostname_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"DnsChange",
"DnsChange_resolver_ip_domain_rdtype_status_id",
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"NetworkChange",
"NetworkChange_ip_transport_port_status_id",
),
)
self.db._db.pragma("user_version", 3) self.db._db.pragma("user_version", 3)
self.db.createTables() self.db.createTables()
assert self.db._db.pragma("user_version") == 4 assert self.db._db.pragma("user_version") == 5
assert validate_schema(self.db.HttpCodeChange).valid, validate_schema(
self.db.HttpCodeChange
)
assert validate_schema(self.db.SslChange).valid, validate_schema(
self.db.SslChange
)
assert validate_schema(self.db.DnsChange).valid, validate_schema(
self.db.DnsChange
)
assert validate_schema(self.db.NetworkChange).valid, validate_schema(
self.db.NetworkChange
)
def test_migrationFromVersion4(self):
assert self.db._db.pragma("user_version") == 0
# Recreate version 3
with self.db._db.transaction():
self.db._db.create_tables(
[
self.db.Status,
self.db.ConfigurationChange,
self.db.HttpCodeChange,
self.db.NetworkChange,
self.db.PlatformChange,
self.db.DnsChange,
self.db.SslChange,
]
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"HttpCodeChange", "HttpCodeChange_ip_url_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"SslChange", "SslChange_ip_port_hostname_status_id"
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"DnsChange",
"DnsChange_resolver_ip_domain_rdtype_status_id",
),
)
migrate(
SqliteMigrator(self.db._db).drop_index(
"NetworkChange",
"NetworkChange_ip_transport_port_status_id",
),
)
self.db._db.pragma("user_version", 4)
self.db.createTables()
assert self.db._db.pragma("user_version") == 5
assert validate_schema(self.db.HttpCodeChange).valid, validate_schema( assert validate_schema(self.db.HttpCodeChange).valid, validate_schema(
self.db.HttpCodeChange self.db.HttpCodeChange
) )
assert validate_schema(self.db.SslChange).valid, validate_schema(
self.db.SslChange
)
assert validate_schema(self.db.DnsChange).valid, validate_schema(
self.db.DnsChange
)
assert validate_schema(self.db.NetworkChange).valid, validate_schema(
self.db.NetworkChange
)
def suite(): def suite():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment