Skip to content

Commit

Permalink
fix(db): Create single task that executes DiagnosticMessage queries
Browse files Browse the repository at this point in the history
This commit introduces an executor task that handles queries coming from
a queue. Compared to the old approach, which easily generated more than
100k tasks per scan, this approach is significantly more memory-efficient.

Additionally, the tasks were previously created in a different event
loop than the one they were awaited in, leaving dangling objects that
persisted until the Python interpreter exited. This commit removes these
memory leaks, which becomes extremely important when scanners are used
as a library.
  • Loading branch information
ferdinandjarisch committed Feb 3, 2025
1 parent 02190ee commit 8934009
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 35 deletions.
11 changes: 10 additions & 1 deletion src/gallia/command/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,13 @@ async def _db_insert_run_meta(self) -> None:
path=self.artifacts_dir,
)

await self.db_handler.disconnect()

async def _db_finish_run_meta(self) -> None:
if self.db_handler is not None and self.db_handler.connection is not None:
if self.db_handler is not None:
if self.db_handler.meta is not None:
try:
await self.db_handler.connect()
await self.db_handler.complete_run_meta(
datetime.now(UTC).astimezone(), self.run_meta.exit_code, self.artifacts_dir
)
Expand Down Expand Up @@ -519,6 +522,9 @@ async def main(self) -> None: ...
async def setup(self) -> None:
from gallia.plugins.plugin import load_transport

if self.db_handler is not None:
await self.db_handler.connect()

if self.config.power_supply is not None:
self.power_supply = await PowerSupply.connect(self.config.power_supply)
if self.config.power_cycle is True:
Expand All @@ -542,5 +548,8 @@ async def setup(self) -> None:
async def teardown(self) -> None:
await self.transport.close()

if self.db_handler is not None:
await self.db_handler.disconnect()

if self.dumpcap:
await self.dumpcap.stop()
9 changes: 8 additions & 1 deletion src/gallia/commands/discover/doip.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def __init__(self, config: DoIPDiscovererConfig):
self.config: DoIPDiscovererConfig = config

# This is an ugly hack to circumvent AsyncScript's shortcomings regarding return codes

def run(self) -> int:
return asyncio.run(self.main2())

Expand All @@ -90,7 +89,11 @@ async def main2(self) -> int:

if self.db_handler is not None:
try:
# We need to manually connect to the DB because we are an AsyncScript that
# does not do so automatically
await self.db_handler.connect()
await self.db_handler.insert_discovery_run("doip")
await self.db_handler.disconnect()
except Exception as e:
logger.warning(f"Could not write the discovery run to the database: {e!r}")

Expand Down Expand Up @@ -420,7 +423,11 @@ async def task_read_diagnostic_messages(
with self.artifacts_dir.joinpath("4_responsive_targets.txt").open("a") as f:
f.write(f"{current_target}\n")
if self.db_handler is not None:
# We need to manually connect to the DB because we are an AsyncScript that
# does not do so automatically
await self.db_handler.connect()
await self.db_handler.insert_discovery_result(current_target)
await self.db_handler.disconnect()

if (
abs(source_address - conn.target_addr) > 10
Expand Down
89 changes: 56 additions & 33 deletions src/gallia/db/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
UDSResponse,
)
from gallia.services.uds.core.utils import bytes_repr as bytes_repr_
from gallia.services.uds.core.utils import g_repr
from gallia.utils import handle_task_error, set_task_handler_ctx_variable


Expand Down Expand Up @@ -143,9 +142,13 @@ def __init__(self, database: Path):
self.target: str | None = None
self.discovery_run: int | None = None
self.meta: int | None = None
self._executor_task: asyncio.Task | None = None
self._execute_queue: asyncio.Queue | None = None

async def connect(self) -> None:
assert self.connection is None, "Already connected to the database"
if self.connection is not None:
logger.warning("Already connected to the database")
return

self.path.parent.mkdir(exist_ok=True, parents=True)
self.connection = await aiosqlite.connect(self.path)
Expand All @@ -161,20 +164,64 @@ async def connect(self) -> None:
await self.connection.executescript(DB_SCHEMA)
await self.check_version()

# This queue is meant to be used for usage-heavy executes that are not time-sensitive, e.g. UDS messages
self._execute_queue = asyncio.Queue()
self._executor_task = asyncio.create_task(self._executor_func())
self._executor_task.add_done_callback(
handle_task_error,
context=set_task_handler_ctx_variable(__name__, "DbHandler"),
)

async def _executor_func(self) -> None:
    """Consume ``(query, parameters)`` pairs from the execute queue and run them.

    Long-lived worker started by ``connect()``. Funneling the non
    time-sensitive queries (e.g. UDS messages) through a single task keeps
    the number of live asyncio tasks constant instead of spawning one task
    per query. Runs until cancelled by ``disconnect()``.

    Failed queries (``aiosqlite.OperationalError``, typically a locked
    database) are re-queued and retried after a short backoff so the
    worker does not busy-spin the event loop while the database is busy.
    """
    assert self.connection is not None, "Not connected to the database"
    assert self._execute_queue is not None, "Queue was not started"

    try:
        while True:
            (query, query_parameter) = await self._execute_queue.get()

            try:
                await self.connection.execute(query, query_parameter)
                await self.connection.commit()
            except aiosqlite.OperationalError:
                logger.warning(
                    f"Could not log message for {query_parameter[5]} to database. Retrying ..."
                )
                # Back off briefly before re-queueing; otherwise a
                # persistently locked database turns this loop into a
                # hot spin on the event loop.
                await asyncio.sleep(0.1)
                await self._execute_queue.put((query, query_parameter))
            finally:
                # Inform the queue that the query was fully processed to track progress;
                # keeps the unfinished-task count balanced so disconnect()'s join() works.
                self._execute_queue.task_done()

    except asyncio.CancelledError:
        logger.debug("Database worker cancelled")
    except asyncio.IncompleteReadError as e:
        logger.debug(f"Database worker received EOF: {e}")
    except Exception as e:
        logger.critical(f"Database worker died: {e!r}")

async def disconnect(self) -> None:
    """Flush all pending queries, stop the executor task, and close the database.

    Mirrors ``connect()``'s graceful handling of redundant calls: invoking
    this while not connected only logs a warning instead of raising, so
    paired connect()/disconnect() calls from different code paths (e.g.
    setup/teardown plus manual AsyncScript usage) stay safe. Bare asserts
    were avoided here on purpose — they would be stripped under ``python -O``.
    """
    if (
        self.connection is None
        or self._execute_queue is None
        or self._executor_task is None
    ):
        logger.warning("Not connected to the database")
        return

    logger.info("Syncing database…")
    try:
        # Wait until every queued query has been written to the database,
        # then cancel the (otherwise endless) executor task.
        await self._execute_queue.join()
        self._executor_task.cancel()
        await self._executor_task
    except Exception as e:
        logger.error(f"Could not properly clean up the database task: {e!r}")
    finally:
        self._execute_queue = None
        self._executor_task = None

    try:
        # Final commit in case the worker was cancelled between execute and commit.
        await self.connection.commit()
    finally:
        await self.connection.close()
        self.connection = None
        logger.info("Database closed")

async def check_version(self) -> None:
assert self.connection is not None, "Not connected to the database"
Expand Down Expand Up @@ -298,6 +345,7 @@ async def insert_scan_result( # noqa: PLR0913
commit: bool = True,
) -> None:
assert self.connection is not None, "Not connected to the database"
assert self._execute_queue is not None, "Queue not yet created"
assert self.scan_run is not None, "Scan run not yet created"

request_attributes: dict[str, Any] = {
Expand Down Expand Up @@ -366,32 +414,7 @@ async def insert_scan_result( # noqa: PLR0913
log_mode.name,
)

async def execute() -> None:
assert self.connection is not None, "Not connected to the database"

done = False

while not done:
try:
await self.connection.execute(query, query_parameter)
done = True
except aiosqlite.OperationalError:
logger.warning(
f"Could not log message for {query_parameter[5]} to database. Retrying ..."
)
except asyncio.CancelledError:
logger.warning("Database query was cancelled.")
done = True

if commit:
await self.connection.commit()

task = asyncio.create_task(execute())
task.add_done_callback(
handle_task_error,
context=set_task_handler_ctx_variable(__name__, "DbHandler"),
)
self.tasks.append(task)
await self._execute_queue.put((query, query_parameter))

async def insert_session_transition(self, destination: int, steps: list[int]) -> None:
assert self.connection is not None, "Not connected to the database"
Expand Down

0 comments on commit 8934009

Please sign in to comment.