From ba7a3d40a592e2b8f23c95891640e0cd95b6aa71 Mon Sep 17 00:00:00 2001
From: Michael Weiser
Date: Thu, 25 Nov 2021 14:38:50 +0000
Subject: [PATCH] db: Add initial asyncio support

In order to alleviate the serialisation caused by the single-threaded
REST API frontend server, we switch the respective database functions
to asyncio. This allows the server to at least accept more requests
while database operations are in progress.

This works very well with MariaDB. Postgres shows slow responses of up
to about 15 seconds on initial operations, caused by an autovacuum
operation triggered by an introspection feature of asyncpg
(https://github.com/MagicStack/asyncpg/issues/530). The workaround is
to disable the JIT of newer PostgreSQL versions server-side for the
time being. sqlite flat-out runs into "database is locked" errors. More
research is required here.

We rewrite the database URL to the specific asyncio dialects for now.
The plan is to switch to asyncio completely, so that eventually the
URLs can be left untouched.

The SQLAlchemy dependency is raised to 1.4.24 because it is the first
version to support asyncmy, the new asyncio MySQL driver.

Move sleeping for retries out of the critical section protected by
locking to allow for even more parallelism.

Remove the lock on analysis journal retrieval since reads should never
conflict with each other.

We need to keep the locking in analysis_add() even with asyncio because
multiple async calls to the routine may be in progress at various
stages of processing and may conflict with and possibly deadlock each
other, particularly when using sqlite, which throws 'database is
locked' errors after a timeout.

Having a threading Lock and an asyncio Lock protect adding and updating
of analyses from threading and asyncio contexts respectively is
unlikely to work as required. The only hope is to switch analysis
update to asyncio as well.

Start moving to the SQLAlchemy 2.0 API using session.execute() with
select(), delete() and update() statements. For the asyncio API this is
required since session.query() is not supported there. We switch some
non-asyncio users as well while we're at it. This also allows
statements to be reused across retries.

Start using session context managers to get rid of explicit session
closing.

The testsuite is updated to match.
---
 peekaboo/db.py     | 490 +++++++++++++++++++++++++--------------------
 peekaboo/server.py |   4 +-
 requirements.txt   |   2 +-
 tests/test.py      |  27 ++-
 4 files changed, 295 insertions(+), 228 deletions(-)
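
Note (kept below the "---" marker so it stays out of the commit
message): the sketch below condenses the retry pattern this patch
converges on into a minimal, self-contained script: a 2.0-style
statement built once and reused across attempts, a per-attempt session
from a context manager instead of an explicit close(), and the backoff
sleep placed outside any lock or session so other coroutines can make
progress in the meantime. SampleInfo and analysis_retrieve are
simplified stand-ins for their peekaboo.db counterparts, the backoff
policy is made up, and running it assumes the aiosqlite driver is
installed.

    import asyncio
    import random

    import sqlalchemy.exc
    import sqlalchemy.ext.asyncio
    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()


    class SampleInfo(Base):
        """ Reduced stand-in for the analysis journal table. """
        __tablename__ = 'sample_info'
        id = Column(Integer, primary_key=True)
        result = Column(String(255))


    async def analysis_retrieve(session_factory, job_id, retries=5):
        # built once and reused across all retry attempts
        statement = select(SampleInfo.result).filter_by(id=job_id)

        for attempt in range(1, retries + 1):
            # the context manager replaces explicit session.close()
            async with session_factory() as session:
                try:
                    proxy = await session.execute(statement)
                    return proxy.first()
                except sqlalchemy.exc.OperationalError:
                    await session.rollback()
                    if attempt >= retries:
                        raise

            # back off outside the session (and outside any lock) so
            # other coroutines can run while we wait
            await asyncio.sleep(random.randint(5, 10) / 1000)


    async def main():
        # same dialect rewrite the patch performs in __init__:
        # sqlite:/// becomes sqlite+aiosqlite:///
        engine = sqlalchemy.ext.asyncio.create_async_engine(
            'sqlite+aiosqlite:///:memory:')
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        session_factory = sessionmaker(
            bind=engine, class_=sqlalchemy.ext.asyncio.AsyncSession)
        print(await analysis_retrieve(session_factory, 1))


    asyncio.run(main())

analysis_add() and analysis_update() below follow the same pattern with
the asyncio or threading lock wrapped around the session block.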
""" +import asyncio import random +import re import time import threading import logging from datetime import datetime, timedelta from sqlalchemy import Column, Integer, String, Text, DateTime, \ Enum, Index +import sqlalchemy.sql.expression +import sqlalchemy.ext.asyncio from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.engine import create_engine from sqlalchemy.orm import sessionmaker, scoped_session @@ -154,10 +158,23 @@ def __init__(self, db_url, instance_id=0, """ logging.getLogger('sqlalchemy.engine').setLevel(log_level) - self.__engine = create_engine(db_url) + self.__engine = create_engine(db_url, future=True) session_factory = sessionmaker(bind=self.__engine) self.__session = scoped_session(session_factory) self.__lock = threading.RLock() + self.__async_lock = asyncio.Lock() + + async_db_url = re.sub( + r'sqlite(\+[a-z]+)?:///', 'sqlite+aiosqlite:///', re.sub( + r'mysql(\+[a-z]+)?://', 'mysql+asyncmy://', re.sub( + r'postgresql(\+[a-z]+)?://', 'postgresql+asyncpg://', + db_url))) + async_engine = sqlalchemy.ext.asyncio.create_async_engine(async_db_url) + self.__async_session_factory = sessionmaker( + bind=async_engine, + class_=sqlalchemy.ext.asyncio.AsyncSession) + # no scoping necessary as we're not using asyncio across threads + self.instance_id = instance_id self.stale_in_flight_threshold = stale_in_flight_threshold self.retries = 5 @@ -169,23 +186,23 @@ def __init__(self, db_url, instance_id=0, self.deadlock_backoff_base = 10 self.connect_backoff_base = 2000 - with self.__lock: - attempt = 1 - while attempt <= self.retries: + attempt = 1 + delay = 0 + while attempt <= self.retries: + with self.__lock: try: Base.metadata.create_all(self.__engine) + break except (OperationalError, DBAPIError, SQLAlchemyError) as error: - attempt = self.was_transient_error( + attempt, delay = self.was_transient_error( error, attempt, 'create metadata') - if attempt > 0: - continue - raise PeekabooDatabaseError( - 'Failed to create schema in database: %s' % - error) + if attempt < 0: + raise PeekabooDatabaseError( + 'Failed to create schema in database: %s' % error) - break + time.sleep(delay) def was_transient_error(self, error, attempt, action): """ Decide if an exception signals a transient error condition and @@ -198,12 +215,13 @@ def was_transient_error(self, error, attempt, action): """ # will not be retried anyway, so no use checking and sleeping if attempt >= self.retries: - return -1 + return -1, 0 # only DBAPIError has connection_invalidated + if getattr(error, 'connection_invalidated', False): logger.debug('Connection invalidated %s. Retrying.', action) - return attempt + 1 + return attempt + 1, 0 # Access the original DBAPI exception anonymously. # We intentionally do some crude duck-typing here to avoid @@ -212,7 +230,7 @@ def was_transient_error(self, error, attempt, action): # identically numbered error of another RDBMS. 
         if (getattr(error, 'orig', None) is None or
                 getattr(error.orig, 'args', None) is None):
-            return -1
+            return -1, 0
 
         args = error.orig.args
 
@@ -224,8 +242,7 @@ def was_transient_error(self, error, attempt, action):
             backoff = random.randint(maxmsecs/2, maxmsecs)
             logger.debug('Connection failed %s, backing off for %d '
                          'milliseconds before retrying', action, backoff)
-            time.sleep(backoff / 1000)
-            return attempt + 1
+            return attempt + 1, backoff / 1000
 
         # (MySQLdb._exceptions.OperationalError) (1213, 'Deadlock
         # found when trying to get lock; try restarting transaction')
@@ -235,12 +252,11 @@ def was_transient_error(self, error, attempt, action):
             backoff = random.randint(maxmsecs/2, maxmsecs)
             logger.debug('Database deadlock detected %s, backing off for %d '
                          'milliseconds before retrying.', action, backoff)
-            time.sleep(backoff / 1000)
-            return attempt + 1
+            return attempt + 1, backoff / 1000
 
-        return -1
+        return -1, 0
 
-    def analysis_add(self, sample):
+    async def analysis_add(self, sample):
         """
         Add an analysis task to the analysis journal in the database.
 
@@ -257,33 +273,31 @@ def analysis_add(self, sample):
             reason=sample.reason)
 
         job_id = None
-        with self.__lock:
-            attempt = 1
-            while attempt <= self.retries:
-                session = self.__session()
-                session.add(sample_info)
-                try:
-                    # flush to retrieve the automatically assigned primary key
-                    # value
-                    session.flush()
-                    job_id = sample_info.id
-                    session.commit()
-                except (OperationalError, DBAPIError,
-                        SQLAlchemyError) as error:
-                    session.rollback()
-
-                    attempt = self.was_transient_error(
-                        error, attempt, 'saving analysis result')
-                    if attempt > 0:
-                        continue
-
-                    raise PeekabooDatabaseError(
-                        'Failed to add analysis task to the database: %s' %
-                        error)
-                finally:
-                    session.close()
-
-                break
+        attempt = 1
+        delay = 0
+        while attempt <= self.retries:
+            async with self.__async_lock:
+                async with self.__async_session_factory() as session:
+                    session.add(sample_info)
+                    try:
+                        # flush to retrieve the automatically assigned primary
+                        # key value
+                        await session.flush()
+                        job_id = sample_info.id
+                        await session.commit()
+                        break
+                    except (OperationalError, DBAPIError,
+                            SQLAlchemyError) as error:
+                        await session.rollback()
+
+                        attempt, delay = self.was_transient_error(
+                            error, attempt, 'adding analysis')
+
+                        if attempt < 0:
+                            raise PeekabooDatabaseError(
+                                'Failed to add analysis task to the database: %s' % error)
+
+            await asyncio.sleep(delay)
 
         sample.update_id(job_id)
         return job_id
@@ -294,34 +308,34 @@ def analysis_update(self, sample):
 
         @param sample: The sample object for this analysis task.
""" - with self.__lock: - attempt = 1 - while attempt <= self.retries: - session = self.__session() - analysis = session.query(SampleInfo).filter_by( - id=sample.id).first() - analysis.state = sample.state - analysis.result = sample.result - analysis.reason = sample.reason + statement = sqlalchemy.sql.expression.update(SampleInfo).where( + SampleInfo.id == sample.id).values( + state=sample.state, + result=sample.result, + reason=sample.reason) - try: - session.commit() - except (OperationalError, DBAPIError, - SQLAlchemyError) as error: - session.rollback() - - attempt = self.was_transient_error( - error, attempt, 'saving analysis result') - if attempt > 0: - continue - - raise PeekabooDatabaseError( - 'Failed to add analysis task to the database: %s' % - error) - finally: - session.close() - - break + attempt = 1 + delay = 0 + while attempt <= self.retries: + with self.__lock: + with self.__session() as session: + try: + session.execute(statement) + session.commit() + break + except (OperationalError, DBAPIError, + SQLAlchemyError) as error: + session.rollback() + + attempt, delay = self.was_transient_error( + error, attempt, 'updating analysis') + + if attempt < 0: + raise PeekabooDatabaseError( + 'Failed to update analysis task in the database: %s' % + error) + + time.sleep(delay) def analysis_journal_fetch_journal(self, sample): """ @@ -332,19 +346,40 @@ def analysis_journal_fetch_journal(self, sample): @return: A sorted list of (analysis_time, result, reason) of the requested sample. """ - with self.__lock: - session = self.__session() - sample_journal = session.query( - SampleInfo.analysis_time, SampleInfo.result, SampleInfo.reason - ).filter(SampleInfo.id != sample.id).filter_by( + statement = sqlalchemy.sql.expression.select( + SampleInfo.analysis_time, SampleInfo.result, + SampleInfo.reason).where( + SampleInfo.id != sample.id).filter_by( state=JobState.FINISHED, sha256sum=sample.sha256sum, - file_extension=sample.file_extension - ).order_by(SampleInfo.analysis_time).all() - session.close() + file_extension=sample.file_extension).order_by( + SampleInfo.analysis_time) + + sample_journal = None + attempt = 1 + delay = 0 + while attempt <= self.retries: + with self.__session() as session: + try: + sample_journal = session.execute(statement).all() + break + except (OperationalError, DBAPIError, + SQLAlchemyError) as error: + session.rollback() + + attempt, delay = self.was_transient_error( + error, attempt, 'fetching analysis journal') + + if attempt < 0: + raise PeekabooDatabaseError( + 'Failed to fetch analysis journal from the database: %s' % + error) + + time.sleep(delay) + return sample_journal - def analysis_retrieve(self, job_id): + async def analysis_retrieve(self, job_id): """ Fetch information stored in the database about a given sample object. 
 
@@ -352,14 +387,34 @@ def analysis_retrieve(self, job_id):
         @type job_id: int
         @return: reason and result for the given analysis task
         """
-        with self.__lock:
-            session = self.__session()
-            sample_result = session.query(
-                SampleInfo.reason, SampleInfo.result
-            ).filter_by(id=job_id, state=JobState.FINISHED).first()
-            session.close()
+        statement = sqlalchemy.sql.expression.select(
+            SampleInfo.reason, SampleInfo.result).filter_by(
+                id=job_id, state=JobState.FINISHED)
 
-        return sample_result
+        result = None
+        attempt = 1
+        delay = 0
+        while attempt <= self.retries:
+            async with self.__async_session_factory() as session:
+                try:
+                    proxy = await session.execute(statement)
+                    result = proxy.first()
+                    break
+                except (OperationalError, DBAPIError,
+                        SQLAlchemyError) as error:
+                    await session.rollback()
+
+                    attempt, delay = self.was_transient_error(
+                        error, attempt, 'retrieving analysis result')
+
+                    if attempt < 0:
+                        raise PeekabooDatabaseError(
+                            'Failed to retrieve analysis from the database: %s' %
+                            error)
+
+            await asyncio.sleep(delay)
+
+        return result
 
     def mark_sample_in_flight(self, sample, instance_id=None, start_time=None):
         """
@@ -388,40 +443,37 @@ def mark_sample_in_flight(self, sample, instance_id=None, start_time=None):
             instance_id=instance_id, start_time=start_time)
 
         attempt = 1
-        locked = False
+        delay = 0
         while attempt <= self.retries:
             # a new session needs to be constructed on each attempt
-            session = self.__session()
-
-            # try to mark this sample as in flight in an atomic insert
-            # operation (modulo possible deadlocks with various RDBMS)
-            session.add(in_flight_marker)
-
-            try:
-                session.commit()
-                locked = True
-                logger.debug('Marked sample %s as in flight', sha256sum)
-            # duplicate primary key == entry already exists
-            except IntegrityError:
-                session.rollback()
-                logger.debug('Sample %s is already in flight on another '
-                             'instance', sha256sum)
-            except (OperationalError, DBAPIError,
-                    SQLAlchemyError) as error:
-                session.rollback()
-
-                attempt = self.was_transient_error(
-                    error, attempt, 'marking sample %s as in flight' %
-                    sha256sum)
-                if attempt > 0:
-                    continue
-
-                raise PeekabooDatabaseError(
-                    'Unable to mark sample as in flight: %s' % error)
-            finally:
-                session.close()
-
-            return locked
+            with self.__session() as session:
+                # try to mark this sample as in flight in an atomic insert
+                # operation (modulo possible deadlocks with various RDBMS)
+                session.add(in_flight_marker)
+
+                try:
+                    session.commit()
+                    logger.debug('Marked sample %s as in flight', sha256sum)
+                    return True
+                # duplicate primary key == entry already exists
+                except IntegrityError:
+                    session.rollback()
+                    logger.debug('Sample %s is already in flight on another '
+                                 'instance', sha256sum)
+                    return False
+                except (OperationalError, DBAPIError,
+                        SQLAlchemyError) as error:
+                    session.rollback()
+
+                    attempt, delay = self.was_transient_error(
+                        error, attempt, 'marking sample %s as in flight' %
+                        sha256sum)
+
+                    if attempt < 0:
+                        raise PeekabooDatabaseError(
+                            'Unable to mark sample as in flight: %s' % error)
+
+            time.sleep(delay)
 
         return False
 
@@ -443,36 +495,35 @@ def clear_sample_in_flight(self, sample, instance_id=None):
             instance_id = self.instance_id
 
         sha256sum = sample.sha256sum
+        statement = sqlalchemy.sql.expression.delete(
+            InFlightSample).where(
+                InFlightSample.sha256sum == sha256sum).where(
+                    InFlightSample.instance_id == instance_id)
 
         attempt = 1
+        cleared = 0
         while attempt <= self.retries:
-            session = self.__session()
-
-            # clear in-flight marker from database
-            query = session.query(InFlightSample).filter(
-                InFlightSample.sha256sum == sha256sum).filter(
-                InFlightSample.instance_id == instance_id)
-            try:
-                # delete() is not queued and goes to the DB before commit()
-                cleared = query.delete()
-                session.commit()
-            except (OperationalError, DBAPIError,
-                    SQLAlchemyError) as error:
-                session.rollback()
+            with self.__session() as session:
+                try:
+                    # clear in-flight marker from database
+                    marker = session.execute(statement)
+                    session.commit()
+                    cleared = marker.rowcount
+                    break
+                except (OperationalError, DBAPIError,
+                        SQLAlchemyError) as error:
+                    session.rollback()
 
-                attempt = self.was_transient_error(
-                    error, attempt, 'clearing in-flight status of sample %s' %
-                    sha256sum)
-                if attempt > 0:
-                    continue
+                    attempt, delay = self.was_transient_error(
+                        error, attempt, 'clearing in-flight status of '
+                        'sample %s' % sha256sum)
 
-                raise PeekabooDatabaseError('Unable to clear in-flight status '
-                                            'of sample: %s' % error)
-            finally:
-                session.close()
+                    if attempt < 0:
+                        raise PeekabooDatabaseError(
+                            'Unable to clear in-flight status of sample: %s' %
+                            error)
 
-            break
+            time.sleep(delay)
 
         if cleared == 0:
             raise PeekabooDatabaseError('Unexpected inconsistency: Sample %s '
@@ -507,39 +558,39 @@ def clear_in_flight_samples(self, instance_id=None):
         if instance_id is None:
             instance_id = self.instance_id
 
+        if instance_id < 0:
+            # delete all locks
+            statement = sqlalchemy.sql.expression.delete(InFlightSample)
+            logger.debug('Clearing database of all in-flight samples.')
+        else:
+            # delete only the locks of a specific instance
+            statement = sqlalchemy.sql.expression.delete(
+                InFlightSample).where(
+                    InFlightSample.instance_id == instance_id)
+            logger.debug('Clearing database of all in-flight samples of '
+                         'instance %d.', instance_id)
+
         attempt = 1
         while attempt <= self.retries:
-            session = self.__session()
-
-            if instance_id < 0:
-                # delete all locks
-                query = session.query(InFlightSample)
-                logger.debug('Clearing database of all in-flight samples.')
-            else:
-                # delete only the locks of a specific instance
-                query = session.query(InFlightSample).filter(
-                    InFlightSample.instance_id == instance_id)
-                logger.debug('Clearing database of all in-flight samples of '
-                             'instance %d.', instance_id)
-            try:
-                # delete() is not queued and goes to the DB before commit()
-                query.delete()
-                session.commit()
-            except (OperationalError, DBAPIError,
-                    SQLAlchemyError) as error:
-                session.rollback()
-
-                attempt = self.was_transient_error(
-                    error, attempt, 'clearing database of in-flight samples')
-                if attempt > 0:
-                    continue
-
-                raise PeekabooDatabaseError('Unable to clear the database of '
-                                            'in-flight samples: %s' % error)
-            finally:
-                session.close()
-
-            break
+            with self.__session() as session:
+                try:
+                    session.execute(statement)
+                    session.commit()
+                    break
+                except (OperationalError, DBAPIError,
+                        SQLAlchemyError) as error:
+                    session.rollback()
+
+                    attempt, delay = self.was_transient_error(
+                        error, attempt,
+                        'clearing database of in-flight samples')
+
+                    if attempt < 0:
+                        raise PeekabooDatabaseError(
+                            'Unable to clear the database of in-flight '
+                            'samples: %s' % error)
+
+            time.sleep(delay)
 
     def clear_stale_in_flight_samples(self):
         """
@@ -555,49 +606,54 @@ def clear_stale_in_flight_samples(self):
             'Clearing database of all stale in-flight samples '
             '(%d seconds)', self.stale_in_flight_threshold)
 
-        attempt = 1
-        while attempt <= self.retries:
-            session = self.__session()
-
+        def clear_statement(statement_class):
             # delete only the locks of a specific instance
-            query = session.query(InFlightSample).filter(
+            return statement_class(InFlightSample).where(
                 InFlightSample.start_time <= datetime.utcnow() - timedelta(
                     seconds=self.stale_in_flight_threshold))
 
-        try:
-            # the loop triggers the query, so only do it if debugging is
-            # enabled
-            if logger.isEnabledFor(logging.DEBUG):
-                # obviously there's a race between logging and actual
-                # delete here, use with caution, compare with actual number
-                # of markers cleared below before relying on it for
-                # debugging
-                for stale in query:
-                    logger.debug(
-                        'Stale in-flight marker to clear: %s', stale)
-
-            # delete() is not queued and goes to the DB before commit()
-            cleared = query.delete()
-            session.commit()
-            if cleared > 0:
-                logger.warning(
-                    '%d stale in-flight samples cleared.', cleared)
-        except (OperationalError, DBAPIError,
-                SQLAlchemyError) as error:
-            session.rollback()
-
-            attempt = self.was_transient_error(
-                error, attempt,
-                'clearing the database of stale in-flight samples')
-            if attempt > 0:
-                continue
-
-            raise PeekabooDatabaseError(
-                'Unable to clear the database of stale in-flight '
-                'samples: %s' % error)
-        finally:
-            session.close()
-
-        break
+
+        delete_statement = clear_statement(sqlalchemy.sql.expression.delete)
+        select_statement = clear_statement(sqlalchemy.sql.expression.select)
+
+        attempt = 1
+        cleared = 0
+        while attempt <= self.retries:
+            with self.__session() as session:
+                try:
+                    # only do the query if debugging is enabled
+                    if logger.isEnabledFor(logging.DEBUG):
+                        # obviously there's a race between logging and actual
+                        # delete here, use with caution, compare with actual
+                        # number of markers cleared below before relying on it
+                        # for debugging
+                        markers = session.execute(select_statement)
+                        for stale in markers:
+                            logger.debug(
+                                'Stale in-flight marker to clear: %s', stale)
+
+                    markers = session.execute(delete_statement)
+                    session.commit()
+
+                    cleared = markers.rowcount
+                    if cleared > 0:
+                        logger.warning(
+                            '%d stale in-flight samples cleared.', cleared)
+
+                    break
+                except (OperationalError, DBAPIError,
+                        SQLAlchemyError) as error:
+                    session.rollback()
+
+                    attempt, delay = self.was_transient_error(
+                        error, attempt,
+                        'clearing the database of stale in-flight samples')
+
+                    if attempt < 0:
+                        raise PeekabooDatabaseError(
+                            'Unable to clear the database of stale in-flight '
+                            'samples: %s' % error)
+
+            time.sleep(delay)
 
         return cleared > 0
diff --git a/peekaboo/server.py b/peekaboo/server.py
index 74e3f38a..3c13cd7c 100644
--- a/peekaboo/server.py
+++ b/peekaboo/server.py
@@ -110,7 +110,7 @@ async def scan(self, request):
             sample_file.type, content_disposition)
 
         try:
-            self.db_con.analysis_add(sample)
+            await self.db_con.analysis_add(sample)
         except PeekabooDatabaseError as dberr:
             logger.error('Failed to add analysis to database: %s', dberr)
             return sanic.response.json(
@@ -138,7 +138,7 @@ async def report(self, _, job_id):
                 {'message': 'job ID missing from request'}, 400)
 
         try:
-            job_info = self.db_con.analysis_retrieve(job_id)
+            job_info = await self.db_con.analysis_retrieve(job_id)
         except PeekabooDatabaseError as dberr:
             logger.error('Failed to retrieve analysis result from '
                          'database: %s', dberr)
diff --git a/requirements.txt b/requirements.txt
index 3a7696aa..1c7f61b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-sqlalchemy>=1.1.0
+sqlalchemy[asyncio]>=1.4.24
 python-magic>=0.4.17
 oletools>=0.54
 sdnotify>=0.3.1
diff --git a/tests/test.py b/tests/test.py
index e5b12c60..358fea92 100755
--- a/tests/test.py
+++ b/tests/test.py
@@ -26,6 +26,7 @@
 """
 The testsuite.
""" +import asyncio import gettext import sys import os @@ -391,10 +392,12 @@ def setUpClass(cls): 'This is just a test case.', further_analysis=False) cls.sample.add_rule_result(result) + cls.loop = asyncio.get_event_loop() def test_1_analysis_add(self): """ Test adding a new analysis. """ - self.db_con.analysis_add(self.sample) + self.loop.run_until_complete( + self.db_con.analysis_add(self.sample)) # sample now contains a job ID def test_2_analysis_update(self): @@ -405,7 +408,8 @@ def test_2_analysis_update(self): def test_3_analysis_journal_fetch_journal(self): """ Test retrieval of analysis results. """ - self.db_con.analysis_add(self.sample) + self.loop.run_until_complete( + self.db_con.analysis_add(self.sample)) # sample now contains another, new job ID # mark sample done so journal and result retrieval tests can work self.sample.mark_done() @@ -425,9 +429,11 @@ def test_3_analysis_journal_fetch_journal(self): def test_4_analysis_retrieve(self): """ Test retrieval of analysis results. """ - self.db_con.analysis_add(self.sample) + self.loop.run_until_complete( + self.db_con.analysis_add(self.sample)) # sample now contains a job ID - reason, result = self.db_con.analysis_retrieve(self.sample.id) + reason, result = self.loop.run_until_complete( + self.db_con.analysis_retrieve(self.sample.id)) self.assertEqual(result, Result.failed) self.assertEqual(reason, 'This is just a test case.') @@ -961,6 +967,7 @@ def setUpClass(cls): cls.tests_data_dir = os.path.join(TESTSDIR, "test-data") cls.office_data_dir = os.path.join(cls.tests_data_dir, 'office') + cls.loop = asyncio.get_event_loop() def test_config_known(self): # pylint: disable=no-self-use """ Test the known rule configuration. """ @@ -1286,7 +1293,8 @@ def test_rule_expression_knowntools(self): 'Unittest', Result.failed, 'This is just a test case.', further_analysis=False) sample.add_rule_result(failed_result) - db_con.analysis_add(sample) + self.loop.run_until_complete( + db_con.analysis_add(sample)) sample.mark_done() db_con.analysis_update(sample) @@ -1302,7 +1310,8 @@ def test_rule_expression_knowntools(self): self.assertEqual(result.result, Result.ignored) sample.add_rule_result(result) - db_con.analysis_add(sample) + self.loop.run_until_complete( + db_con.analysis_add(sample)) sample.mark_done() db_con.analysis_update(sample) @@ -1315,7 +1324,8 @@ def test_rule_expression_knowntools(self): self.assertEqual(result.result, Result.ignored) sample.add_rule_result(result) - db_con.analysis_add(sample) + self.loop.run_until_complete( + db_con.analysis_add(sample)) sample.mark_done() db_con.analysis_update(sample) @@ -1324,7 +1334,8 @@ def test_rule_expression_knowntools(self): self.assertEqual(result.result, Result.ignored) sample.add_rule_result(failed_result) - db_con.analysis_add(sample) + self.loop.run_until_complete( + db_con.analysis_add(sample)) sample.mark_done() db_con.analysis_update(sample)