diff --git a/src/tribler/core/components/bandwidth_accounting/db/database.py b/src/tribler/core/components/bandwidth_accounting/db/database.py index 6d7512ba6a1..6594b95c9b7 100644 --- a/src/tribler/core/components/bandwidth_accounting/db/database.py +++ b/src/tribler/core/components/bandwidth_accounting/db/database.py @@ -5,7 +5,8 @@ from tribler.core.components.bandwidth_accounting.db import history, misc, transaction as db_transaction from tribler.core.components.bandwidth_accounting.db.transaction import BandwidthTransactionData -from tribler.core.utilities.pony_utils import TriblerDatabase, handle_db_if_corrupted +from tribler.core.utilities.db_corruption_handling.base import handle_db_if_corrupted +from tribler.core.utilities.pony_utils import TriblerDatabase from tribler.core.utilities.utilities import MEMORY_DB @@ -50,6 +51,7 @@ def sqlite_sync_pragmas(_, connection): create_db = True db_path_string = ":memory:" else: + # We need to handle the database corruption case before determining the state of the create_db flag. handle_db_if_corrupted(db_path) create_db = not db_path.is_file() db_path_string = str(db_path) diff --git a/src/tribler/core/components/component.py b/src/tribler/core/components/component.py index 3acc8094d1e..513fa737526 100644 --- a/src/tribler/core/components/component.py +++ b/src/tribler/core/components/component.py @@ -9,8 +9,8 @@ from tribler.core.components.exceptions import ComponentStartupException, MissedDependency, NoneComponent from tribler.core.components.reporter.exception_handler import default_core_exception_handler from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted from tribler.core.utilities.process_manager import get_global_process_manager if TYPE_CHECKING: @@ -52,8 +52,20 @@ async def start(self): self.started_event.set() if isinstance(e, DatabaseIsCorrupted): + # When the database corruption is detected, we should stop the process immediately. + # Tribler GUI will restart the process and the database will be recreated. + + # Usually we wrap an exception into ComponentStartupException, and allow + # CoreExceptionHandler.unhandled_error_observer to handle it after all components are started, + # but in this case we don't do it. The reason is that handling ComponentStartupException + # starts the shutting down of Tribler, and due to some obscure reasons it is not possible to + # raise any exception, even SystemExit, from CoreExceptionHandler.unhandled_error_observer when + # Tribler is shutting down. It looks like in this case unhandled_error_observer is called from + # Task.__del__ method and all exceptions that are raised from __del__ are ignored. + # See https://bugs.python.org/issue25489 for similar case. process_manager = get_global_process_manager() process_manager.sys_exit(EXITCODE_DATABASE_IS_CORRUPTED, e) + return # Added for clarity; actually, the code raised SystemExit on the previous line if self.session.failfast: raise e diff --git a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py index a329e8a3783..529b6048e62 100644 --- a/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py +++ b/src/tribler/core/components/gigachannel_manager/gigachannel_manager.py @@ -12,8 +12,9 @@ from tribler.core.components.metadata_store.db.orm_bindings.channel_node import COMMITTED from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT from tribler.core.components.metadata_store.db.store import MetadataStore +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.notifier import Notifier -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, run_threaded +from tribler.core.utilities.pony_utils import run_threaded from tribler.core.utilities.simpledefs import DownloadStatus from tribler.core.utilities.unicode import hexlify diff --git a/src/tribler/core/components/metadata_store/db/store.py b/src/tribler/core/components/metadata_store/db/store.py index e89306c5726..b8addf25816 100644 --- a/src/tribler/core/components/metadata_store/db/store.py +++ b/src/tribler/core/components/metadata_store/db/store.py @@ -46,11 +46,10 @@ from tribler.core.components.metadata_store.remote_query_community.payload_checker import process_payload from tribler.core.components.torrent_checker.torrent_checker.dataclasses import HealthInfo from tribler.core.exceptions import InvalidSignatureException +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted from tribler.core.utilities.notifier import Notifier from tribler.core.utilities.path_util import Path -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, TriblerDatabase, get_max, get_or_create, \ - handle_db_if_corrupted, \ - run_threaded +from tribler.core.utilities.pony_utils import TriblerDatabase, get_max, get_or_create, run_threaded from tribler.core.utilities.search_utils import torrent_rank from tribler.core.utilities.unicode import hexlify from tribler.core.utilities.utilities import MEMORY_DB @@ -220,6 +219,7 @@ def on_connect(_, connection): create_db = True db_path_string = ":memory:" else: + # We need to handle the database corruption case before determining the state of the create_db flag. handle_db_if_corrupted(db_filename) create_db = not db_filename.is_file() db_path_string = str(db_filename) diff --git a/src/tribler/core/components/reporter/exception_handler.py b/src/tribler/core/components/reporter/exception_handler.py index 8dc7700cd32..018dcc67f3d 100644 --- a/src/tribler/core/components/reporter/exception_handler.py +++ b/src/tribler/core/components/reporter/exception_handler.py @@ -10,8 +10,8 @@ from tribler.core.components.exceptions import ComponentStartupException from tribler.core.components.reporter.reported_error import ReportedError from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.exit_codes import EXITCODE_DATABASE_IS_CORRUPTED -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted from tribler.core.utilities.process_manager import get_global_process_manager # There are some errors that we are ignoring. diff --git a/src/tribler/core/components/reporter/tests/test_exception_handler.py b/src/tribler/core/components/reporter/tests/test_exception_handler.py index 41bf79779b7..3b2ecb6cf38 100644 --- a/src/tribler/core/components/reporter/tests/test_exception_handler.py +++ b/src/tribler/core/components/reporter/tests/test_exception_handler.py @@ -7,6 +7,7 @@ from tribler.core.components.reporter.exception_handler import CoreExceptionHandler from tribler.core.sentry_reporter import sentry_reporter from tribler.core.sentry_reporter.sentry_reporter import SentryReporter +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted # pylint: disable=protected-access, redefined-outer-name @@ -85,6 +86,17 @@ def test_unhandled_error_observer_exception(exception_handler): assert reported_error.should_stop +@patch('tribler.core.components.reporter.exception_handler.get_global_process_manager') +def test_unhandled_error_observer_database_corrupted(get_global_process_manager, exception_handler): + # test that database corruption exception reported to the GUI + exception = DatabaseIsCorrupted('db_path_string') + exception_handler.report_callback = MagicMock() + exception_handler.unhandled_error_observer(None, {'exception': exception}) + + get_global_process_manager().sys_exit.assert_called_once_with(99, exception) + exception_handler.report_callback.assert_not_called() + + def test_unhandled_error_observer_only_message(exception_handler): # test that unhandled exception, represented by message, reported to the GUI context = {'message': 'Any'} diff --git a/src/tribler/core/components/tests/test_base_component.py b/src/tribler/core/components/tests/test_base_component.py index 49cfd7a0740..4f3bf0683b2 100644 --- a/src/tribler/core/components/tests/test_base_component.py +++ b/src/tribler/core/components/tests/test_base_component.py @@ -1,9 +1,12 @@ +from unittest.mock import patch + import pytest from tribler.core.components.component import Component from tribler.core.components.exceptions import MissedDependency, MultipleComponentsFound, NoneComponent from tribler.core.components.session import Session from tribler.core.config.tribler_config import TriblerConfig +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted class ComponentTestException(Exception): @@ -46,6 +49,20 @@ class TestComponentB(TestComponent): assert component.stopped +@patch('tribler.core.components.component.get_global_process_manager') +async def test_session_start_database_corruption_detected(get_global_process_manager): + exception = DatabaseIsCorrupted('db_path_string') + + class TestComponent(Component): + async def run(self): + raise exception + + component = TestComponent() + + await component.start() + get_global_process_manager().sys_exit.assert_called_once_with(99, exception) + + class ComponentA(Component): pass diff --git a/src/tribler/core/upgrade/db8_to_db10.py b/src/tribler/core/upgrade/db8_to_db10.py index 9151a4075cd..68534b56456 100644 --- a/src/tribler/core/upgrade/db8_to_db10.py +++ b/src/tribler/core/upgrade/db8_to_db10.py @@ -1,14 +1,13 @@ import contextlib import datetime import logging -import sqlite3 from collections import deque from time import time as now from pony.orm import db_session from tribler.core.components.metadata_store.db.store import MetadataStore -from tribler.core.utilities.pony_utils import marking_corrupted_db +from tribler.core.utilities.db_corruption_handling import sqlite_replacement TABLE_NAMES = ( "ChannelNode", "TorrentState", "TorrentState_TrackerState", "ChannelPeer", "ChannelVote", "TrackerState", "Vsids") @@ -127,12 +126,12 @@ def convert_command(offset, batch_size): def do_migration(self): result = None # estimated duration in seconds of ChannelNode table copying time try: - with marking_corrupted_db(self.old_db_path): - old_table_columns = {} - for table_name in TABLE_NAMES: - old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name) + old_table_columns = {} + for table_name in TABLE_NAMES: + old_table_columns[table_name] = get_table_columns(self.old_db_path, table_name) - with contextlib.closing(sqlite3.connect(self.new_db_path)) as connection, connection: + with contextlib.closing(sqlite_replacement.connect(self.new_db_path)) as connection: + with connection: cursor = connection.cursor() cursor.execute("PRAGMA journal_mode = OFF;") cursor.execute("PRAGMA synchronous = OFF;") @@ -235,7 +234,7 @@ def calc_progress(duration_now, duration_half=60.0): def get_table_columns(db_path, table_name): - with contextlib.closing(sqlite3.connect(db_path)) as connection, connection: + with contextlib.closing(sqlite_replacement.connect(db_path)) as connection, connection: cursor = connection.cursor() cursor.execute(f'SELECT * FROM {table_name} LIMIT 1') names = [description[0] for description in cursor.description] diff --git a/src/tribler/core/upgrade/tests/test_upgrader.py b/src/tribler/core/upgrade/tests/test_upgrader.py index 66dc7081a42..91d7bb6ab66 100644 --- a/src/tribler/core/upgrade/tests/test_upgrader.py +++ b/src/tribler/core/upgrade/tests/test_upgrader.py @@ -18,7 +18,7 @@ from tribler.core.upgrade.upgrade import TriblerUpgrader, catch_db_is_corrupted_exception, \ cleanup_noncompliant_channel_torrents from tribler.core.utilities.configparser import CallbackConfigParser -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.utilities import random_infohash diff --git a/src/tribler/core/upgrade/upgrade.py b/src/tribler/core/upgrade/upgrade.py index 6912b42adc5..0a8d4e9f478 100644 --- a/src/tribler/core/upgrade/upgrade.py +++ b/src/tribler/core/upgrade/upgrade.py @@ -23,8 +23,9 @@ from tribler.core.upgrade.tags_to_knowledge.migration import MigrationTagsToKnowledge from tribler.core.upgrade.tags_to_knowledge.tags_db import TagDatabase from tribler.core.utilities.configparser import CallbackConfigParser +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.core.utilities.path_util import Path -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, get_db_version +from tribler.core.utilities.pony_utils import get_db_version from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR @@ -73,17 +74,17 @@ def cleanup_noncompliant_channel_torrents(state_dir): def catch_db_is_corrupted_exception(upgrader_method): # This decorator applied for TriblerUpgrader methods. It suppresses and remembers the DatabaseIsCorrupted exception. # As a result, if one upgrade method raises an exception, the following upgrade methods are still executed. - - # The reason for this is the following: it is possible that one upgrade methods upgrades a database A, while - # the next upgrade method upgrades a database B. If a corruption detected in the database A, the database B still - # need to be upgraded. So we want to temporarily suppress DatabaseIsCorrupted exception until all upgrades are - # executed. - - # If an upgrade found the database to be corrupted, the database is marked as corrupted. Then, the next upgrade - # will rename the corrupted database file (this is handled by the get_db_version call) and immediately return - # because there is no database to upgrade. So, if one upgrade function detects the database corruption, all the - # following upgrade functions for this specific database will skip the actual upgrade. As a result, a new - # database with the current DB version will be created on the Tribler Core start. + # + # The reason for this is the following: it is possible that one upgrade method upgrades database A + # while the following upgrade method upgrades database B. If a corruption is detected in the database A, + # the database B still needs to be upgraded. So, we want to temporarily suppress the DatabaseIsCorrupted exception + # until all upgrades are executed. + # + # If an upgrade finds the database to be corrupted, the database is marked as corrupted. Then, the next upgrade + # will rename the corrupted database file (the get_db_version call handles this) and immediately return because + # there is no database to upgrade. So, if one upgrade function detects database corruption, all the following + # upgrade functions for this specific database will skip the actual upgrade. As a result, a new database with + # the current DB version will be created on the Tribler Core start. @wraps(upgrader_method) def new_method(*args, **kwargs): diff --git a/src/tribler/core/utilities/db_corruption_handling/__init__.py b/src/tribler/core/utilities/db_corruption_handling/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/utilities/db_corruption_handling/base.py b/src/tribler/core/utilities/db_corruption_handling/base.py new file mode 100644 index 00000000000..d4016bba6ea --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/base.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Union + +logger = logging.getLogger('db_corruption_handling') + + +class DatabaseIsCorrupted(Exception): + pass + + +@contextmanager +def handling_malformed_db_error(db_filepath: Path): + # Used in all methods of Connection and Cursor classes where the database corruption error can occur + try: + yield + except Exception as e: + if _is_malformed_db_exception(e): + _mark_db_as_corrupted(db_filepath) + raise DatabaseIsCorrupted(str(db_filepath)) from e + raise + + +def handle_db_if_corrupted(db_filename: Union[str, Path]): + # Checks if the database is marked as corrupted and handles it by removing the database file and the marker file + db_path = Path(db_filename) + marker_path = get_corrupted_db_marker_path(db_path) + if marker_path.exists(): + _handle_corrupted_db(db_path) + + +def get_corrupted_db_marker_path(db_filepath: Path) -> Path: + return Path(str(db_filepath) + '.is_corrupted') + + +def _is_malformed_db_exception(exception): + return isinstance(exception, sqlite3.DatabaseError) and 'malformed' in str(exception) + + +def _mark_db_as_corrupted(db_filepath: Path): + # Creates a new `*.is_corrupted` marker file alongside the database file + marker_path = get_corrupted_db_marker_path(db_filepath) + marker_path.touch() + + +def _handle_corrupted_db(db_path: Path): + # Removes the database file and the marker file + if db_path.exists(): + logger.warning(f'Database file was marked as corrupted, removing it: {db_path}') + db_path.unlink() + + marker_path = get_corrupted_db_marker_path(db_path) + if marker_path.exists(): + logger.warning(f'Removing the corrupted database marker: {marker_path}') + marker_path.unlink() diff --git a/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py b/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py new file mode 100644 index 00000000000..12d530f32da --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/sqlite_replacement.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import sqlite3 +import sys +from pathlib import Path +from sqlite3 import DataError, DatabaseError, Error, IntegrityError, InterfaceError, InternalError, NotSupportedError, \ + OperationalError, ProgrammingError, Warning, sqlite_version_info # pylint: disable=unused-import, redefined-builtin + +from tribler.core.utilities.db_corruption_handling.base import handling_malformed_db_error + + +# This module serves as a replacement to the sqlite3 module and handles the case when the database is corrupted. +# It provides the `connect` function that should be used instead of `sqlite3.connect` and the `Cursor` and `Connection` +# classes that replaces `sqlite3.Cursor` and `sqlite3.Connection` classes respectively. If the `connect` function or +# any Connectoin or Cursor method is called and the database is corrupted, the database is marked as corrupted and +# the DatabaseIsCorrupted exception is raised. It should be handled by terminating the Tribler Core with the exit code +# EXITCODE_DATABASE_IS_CORRUPTED (99). After the Core restarts, the `handle_db_if_corrupted` function checks the +# presense of the database corruption marker and handles it by removing the database file and the corruption marker. +# After that, the database is recreated upon the next attempt to connect to it. + + +def connect(db_filename: str, **kwargs) -> sqlite3.Connection: + # Replaces the sqlite3.connect function + kwargs['factory'] = Connection + with handling_malformed_db_error(Path(db_filename)): + return sqlite3.connect(db_filename, **kwargs) + + +def _add_method_wrapper_that_handles_malformed_db_exception(cls, method_name: str): + # Creates a wrapper for the given method that handles the case when the database is corrupted + + def wrapper(self, *args, **kwargs): + with handling_malformed_db_error(self._db_filepath): # pylint: disable=protected-access + return getattr(super(cls, self), method_name)(*args, **kwargs) + + wrapper.__name__ = method_name + wrapper.is_wrapped = True # for testing purposes + setattr(cls, method_name, wrapper) + + +class Cursor(sqlite3.Cursor): + # Handles the case when the database is corrupted in all relevant methods. + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._db_filepath = self.connection._db_filepath + + +for method_name_ in ['execute', 'executemany', 'executescript', 'fetchall', 'fetchmany', 'fetchone', '__next__']: + _add_method_wrapper_that_handles_malformed_db_exception(Cursor, method_name_) + + + +class ConnectionBase(sqlite3.Connection): + # This class simplifies testing of the Connection class by allowing mocking of base class methods. + # Direct mocking of sqlite3.Connection methods is not possible because they are C functions. + + if sys.version_info < (3, 11): + def blobopen(self, *args, **kwargs) -> Blob: + raise NotImplementedError + + +class Connection(ConnectionBase): + # Handles the case when the database is corrupted in all relevant methods. + def __init__(self, db_filepath: str, *args, **kwargs): + super().__init__(db_filepath, *args, **kwargs) + self._db_filepath = Path(db_filepath) + + def cursor(self, factory=None) -> Cursor: + return super().cursor(factory or Cursor) + + def iterdump(self): + # Not implemented because it is not used in Tribler. + # Can be added later with an iterator class that handles the malformed db error during the iteration + raise NotImplementedError + + def blobopen(self, *args, **kwargs) -> Blob: # Works for Python >= 3.11 + with handling_malformed_db_error(self._db_filepath): + blob = super().blobopen(*args, **kwargs) + return Blob(blob, self._db_filepath) + + +for method_name_ in ['commit', 'execute', 'executemany', 'executescript', 'backup', '__enter__', '__exit__', + 'serialize', 'deserialize']: + _add_method_wrapper_that_handles_malformed_db_exception(Connection, method_name_) + + +class Blob: # For Python >= 3.11. Added now, so we do not forgot to add it later when upgrading to 3.11. + def __init__(self, blob, db_filepath: Path): + self._blob = blob + self._db_filepath = db_filepath + + +for method_name_ in ['close', 'read', 'write', 'seek', '__len__', '__enter__', '__exit__', '__getitem__', + '__setitem__']: + _add_method_wrapper_that_handles_malformed_db_exception(Blob, method_name_) diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/__init__.py b/src/tribler/core/utilities/db_corruption_handling/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py b/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py new file mode 100644 index 00000000000..ea2f0bfa8b1 --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest + +from tribler.core.utilities.db_corruption_handling.sqlite_replacement import connect + + +@pytest.fixture(name='db_filepath') +def db_filepath_fixture(tmp_path): + return tmp_path / 'test.db' + + +@pytest.fixture(name='connection') +def connection_fixture(db_filepath): + connection = connect(str(db_filepath)) + yield connection + connection.close() diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py b/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py new file mode 100644 index 00000000000..79448f2da0d --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/test_base.py @@ -0,0 +1,57 @@ +import sqlite3 +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + + +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted, handle_db_if_corrupted, \ + handling_malformed_db_error + +malformed_error = sqlite3.DatabaseError('database disk image is malformed') + + +def test_handling_malformed_db_error__no_error(db_filepath): + # If no error is raised, the database should not be marked as corrupted + with handling_malformed_db_error(db_filepath): + pass + + assert not Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handling_malformed_db_error__malformed_error(db_filepath): + # Malformed database errors should be handled by marking the database as corrupted + with pytest.raises(DatabaseIsCorrupted): + with handling_malformed_db_error(db_filepath): + raise malformed_error + + assert Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handling_malformed_db_error__other_error(db_filepath): + # Other errors should not be handled like malformed database errors + class TestError(Exception): + pass + + with pytest.raises(TestError): + with handling_malformed_db_error(db_filepath): + raise TestError() + + assert not Path(str(db_filepath) + '.is_corrupted').exists() + + +def test_handle_db_if_corrupted__corrupted(db_filepath: Path): + # If the corruption marker is found, the corrupted database file is removed + marker_path = Path(str(db_filepath) + '.is_corrupted') + marker_path.touch() + + handle_db_if_corrupted(db_filepath) + assert not db_filepath.exists() + assert not marker_path.exists() + + +@patch('tribler.core.utilities.db_corruption_handling.base._handle_corrupted_db') +def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_filepath: Path): + # If the corruption marker is not found, the handling of the database is not performed + handle_db_if_corrupted(db_filepath) + handle_corrupted_db.assert_not_called() diff --git a/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py b/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py new file mode 100644 index 00000000000..3b36b067a09 --- /dev/null +++ b/src/tribler/core/utilities/db_corruption_handling/tests/test_sqlite_replacement.py @@ -0,0 +1,77 @@ +import sqlite3 +from unittest.mock import Mock, patch + +import pytest + +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted +from tribler.core.utilities.db_corruption_handling.sqlite_replacement import Blob, Connection, \ + Cursor, _add_method_wrapper_that_handles_malformed_db_exception, connect + + +# pylint: disable=protected-access + + +malformed_error = sqlite3.DatabaseError('database disk image is malformed') + + +def test_connect(db_filepath): + connection = connect(str(db_filepath)) + assert isinstance(connection, Connection) + connection.close() + + +def test_make_method_that_handles_malformed_db_exception(db_filepath): + # Tests that the _make_method_that_handles_malformed_db_exception function creates a method that handles + # the malformed database exception + + class BaseClass: + method1 = Mock(return_value=Mock()) + + class TestClass(BaseClass): + _db_filepath = db_filepath + + _add_method_wrapper_that_handles_malformed_db_exception(TestClass, 'method1') + + # The method should be successfully wrapped + assert TestClass.method1.is_wrapped + assert TestClass.method1.__name__ == 'method1' + + test_instance = TestClass() + result = test_instance.method1(1, 2, x=3, y=4) + + # *args and **kwargs should be passed to the original method, and the result should be returned + BaseClass.method1.assert_called_once_with(1, 2, x=3, y=4) + assert result is BaseClass.method1.return_value + + # When the base method raises a malformed database exception, the DatabaseIsCorrupted exception should be raised + BaseClass.method1.side_effect = malformed_error + with pytest.raises(DatabaseIsCorrupted): + test_instance.method1(1, 2, x=3, y=4) + + +def test_connection_cursor(connection): + cursor = connection.cursor() + assert isinstance(cursor, Cursor) + + +def test_connection_iterdump(connection): + with pytest.raises(NotImplementedError): + connection.iterdump() + + +@patch('tribler.core.utilities.db_corruption_handling.sqlite_replacement.ConnectionBase.blobopen', + Mock(side_effect=malformed_error)) +def test_connection_blobopen__exception(connection): + with pytest.raises(DatabaseIsCorrupted): + connection.blobopen() + + +@patch('tribler.core.utilities.db_corruption_handling.sqlite_replacement.ConnectionBase.blobopen') +def test_connection_blobopen__no_exception(blobopen, connection): + blobopen.return_value = Mock() + result = connection.blobopen() + + blobopen.assert_called_once() + assert isinstance(result, Blob) + assert result._blob is blobopen.return_value + assert result._db_filepath == connection._db_filepath diff --git a/src/tribler/core/utilities/pony_utils.py b/src/tribler/core/utilities/pony_utils.py index 1683313efe5..3d4c7d30fd8 100644 --- a/src/tribler/core/utilities/pony_utils.py +++ b/src/tribler/core/utilities/pony_utils.py @@ -2,7 +2,6 @@ import contextlib import logging -import sqlite3 import sys import threading import time @@ -13,17 +12,25 @@ from operator import attrgetter from pathlib import Path from types import FrameType -from typing import Callable, Dict, Iterable, Optional, Type, Union +from typing import Callable, Dict, Iterable, Optional, Type from weakref import WeakSet -from contextlib import contextmanager - from pony import orm from pony.orm import core from pony.orm.core import Database, select from pony.orm.dbproviders import sqlite -from pony.orm.dbproviders.sqlite import SQLitePool -from pony.utils import absolutize_path, cut_traceback, cut_traceback_depth, localbase +from pony.utils import cut_traceback, localbase +from tribler.core.utilities.db_corruption_handling import sqlite_replacement +from tribler.core.utilities.db_corruption_handling.base import handle_db_if_corrupted + +# Inject sqlite replacement to PonyORM sqlite database provider to use augmented version of Connection and Cursor +# classes that handle database corruption errors. All connection and cursor methods, such as execute and fetchone, +# raise DatabaseIsCorrupted exception if the database is corrupted. Also, the marker file with ".is_corrupted" +# extension is created alongside the corrupted database file. As a result of exception, the Tribler Core immediately +# stops with the error code 99. Tribler GUI handles this error code by showing the message to the user and automatically +# restarting the Core. After the Core is restarted, the database is re-created from scratch. +sqlite.sqlite = sqlite_replacement + SLOW_DB_SESSION_DURATION_THRESHOLD = 1.0 @@ -34,11 +41,7 @@ StatDict = Dict[Optional[str], core.QueryStat] -class DatabaseIsCorrupted(Exception): - pass - - -def table_exists(cursor: sqlite3.Cursor, table_name: str) -> bool: +def table_exists(cursor: sqlite_replacement.Cursor, table_name: str) -> bool: cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) return cursor.fetchone() is not None @@ -48,14 +51,13 @@ def get_db_version(db_path, default: int = None) -> int: version = None if db_path.exists(): - with marking_corrupted_db(db_path): - with contextlib.closing(sqlite3.connect(db_path)) as connection: - with connection: - cursor = connection.cursor() - if table_exists(cursor, 'MiscData'): - cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'") - row = cursor.fetchone() - version = int(row[0]) if row else None + with contextlib.closing(sqlite_replacement.connect(db_path)) as connection: + with connection: + cursor = connection.cursor() + if table_exists(cursor, 'MiscData'): + cursor.execute("SELECT value FROM MiscData WHERE name == 'db_version'") + row = cursor.fetchone() + version = int(row[0]) if row else None if version is not None: return version @@ -66,53 +68,6 @@ def get_db_version(db_path, default: int = None) -> int: raise RuntimeError(f'The version value is not found in database {db_path}') -def handle_db_if_corrupted(db_filename: Union[str, Path]): - db_path = Path(db_filename) - marker_path = _get_corrupted_db_marker_path(db_path) - if marker_path.exists(): - _handle_corrupted_db(db_path) - - -def _handle_corrupted_db(db_path: Path): - if db_path.exists(): - logger.warning(f'Database file was marked as corrupted, removing it: {db_path}') - db_path.unlink() - - marker_path = _get_corrupted_db_marker_path(db_path) - if marker_path.exists(): - logger.warning(f'Removing the corrupted database marker: {marker_path}') - marker_path.unlink() - - -def _get_corrupted_db_marker_path(db_filename: Path) -> Path: - return Path(str(db_filename) + '.is_corrupted') - - -@contextmanager -def marking_corrupted_db(db_filename: Union[str, Path]): - try: - yield - except Exception as e: - if _is_malformed_db_exception(e): - db_path = Path(db_filename) - _mark_db_as_corrupted(db_path) - raise DatabaseIsCorrupted(str(db_path)) from e - raise - - -def _is_malformed_db_exception(exception): - return isinstance(exception, (core.DatabaseError, sqlite3.DatabaseError)) and 'malformed' in str(exception) - - -def _mark_db_as_corrupted(db_filename: Path): - if not db_filename.exists(): - raise RuntimeError(f'Corrupted database file not found: {db_filename}') - - marker_path = _get_corrupted_db_marker_path(db_filename) - marker_path.touch() - - - # pylint: disable=bad-staticmethod-argument def get_or_create(cls: Type[core.Entity], create_kwargs=None, **kwargs) -> core.Entity: """Get or create db entity. @@ -356,7 +311,6 @@ def _merge_stats(stats_iter: Iterable[StatDict]) -> StatDict: class TriblerSQLiteProvider(sqlite.SQLiteProvider): - pool: TriblerPool # It is impossible to override the __init__ method without breaking the `SQLiteProvider.get_pool` method's logic. # Therefore, we don't initialize a new attribute `_acquire_time` inside a class constructor method. @@ -384,36 +338,6 @@ def release_lock(self): lock_hold_duration = time.time() - acquire_time info.lock_hold_total_duration += lock_hold_duration - def set_transaction_mode(self, connection, cache): - with marking_corrupted_db(self.pool.filename): - return super().set_transaction_mode(connection, cache) - - def execute(self, cursor, sql, arguments=None, returning_id=False): - with marking_corrupted_db(self.pool.filename): - return super().execute(cursor, sql, arguments, returning_id) - - def mark_db_as_malformed(self): - filename = self.pool.filename - if not Path(filename).exists(): - raise RuntimeError(f'Corrupted database file not found: {filename!r}') - - marker_filename = filename + '.is_corrupted' - Path(marker_filename).touch() - - def get_pool(self, is_shared_memory_db, filename, create_db=False, **kwargs): - if not (is_shared_memory_db or filename == ':memory:'): - filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) # see the base method for details - handle_db_if_corrupted(filename) - return TriblerPool(is_shared_memory_db, filename, create_db, **kwargs) - - -class TriblerPool(SQLitePool): - # TriblerSQLiteProvider instantiates this class instead of a standard SQLitePool class. It allows to catch - # the "database is malformed" error when new connection is establishing to the database from the ORM - def _connect(self): - with marking_corrupted_db(self.filename): - return super()._connect() - db_session = TriblerDbSession() orm.db_session = orm.core.db_session = db_session @@ -421,7 +345,7 @@ def _connect(self): class TriblerDatabase(Database): # TriblerDatabase extends the functionality of the Database class in the following ways: - # * It adds handling of DatabaseError when the database file is corrupted + # * It adds handling the case when the database file is corrupted # * It accumulates and shows statistics on slow database queries def __init__(self): @@ -444,10 +368,6 @@ def bind(self, **kwargs): self._bind(TriblerSQLiteProvider, **kwargs) - def call_on_connect(self, con): - if self.provider is not None: - with marking_corrupted_db(self.provider.pool.filename): - super().call_on_connect(con) def track_slow_db_sessions(): TriblerDbSession.track_slow_db_sessions = True diff --git a/src/tribler/core/utilities/tests/test_pony_utils.py b/src/tribler/core/utilities/tests/test_pony_utils.py index f6b70393fcd..99b3e30fdcb 100644 --- a/src/tribler/core/utilities/tests/test_pony_utils.py +++ b/src/tribler/core/utilities/tests/test_pony_utils.py @@ -1,14 +1,12 @@ import sqlite3 from pathlib import Path -from unittest.mock import patch, Mock +from unittest.mock import patch import pytest from pony.orm.core import QueryStat, Required from tribler.core.utilities import pony_utils -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted, _mark_db_as_corrupted, get_db_version, \ - handle_db_if_corrupted, \ - marking_corrupted_db, table_exists +from tribler.core.utilities.pony_utils import get_db_version, table_exists EMPTY_DICT = {} @@ -133,41 +131,6 @@ def db_path_fixture(tmp_path: Path): return db_path -@patch('tribler.core.utilities.pony_utils._handle_corrupted_db') -def test_handle_db_if_corrupted__not_corrupted(handle_corrupted_db: Mock, db_path: Path): - # If the corruption marker is not found, the handling of the database is not performed - handle_db_if_corrupted(db_path) - handle_corrupted_db.assert_not_called() - - -def test_handle_db_if_corrupted__corrupted(db_path: Path): - # If the corruption marker is found, the corrupted database file is removed - marker_path = Path(str(db_path) + '.is_corrupted') - marker_path.touch() - - handle_db_if_corrupted(db_path) - assert not db_path.exists() - assert not marker_path.exists() - - -def test_marking_corrupted_db__not_malformed(db_path: Path): - # When the context manger encounters an exception not related to database corruption, it does nothing - with pytest.raises(ZeroDivisionError): - with marking_corrupted_db(db_path): - raise ZeroDivisionError() - - assert not Path(str(db_path) + '.is_corrupted').exists() - - -def test_marking_corrupted_db__malformed(db_path: Path): - # When the context manger encounters an exception not related to database corruption, it adds a corruption marker - with pytest.raises(DatabaseIsCorrupted): - with marking_corrupted_db(db_path): - raise sqlite3.DatabaseError('database disk image is malformed') - - assert Path(str(db_path) + '.is_corrupted').exists() - - def test_get_db_version__db_does_not_exist(tmp_path: Path): # When the database does not exist, the call to get_db_version generates RuntimeError db_path = tmp_path / 'doesnotexist.db' @@ -237,13 +200,3 @@ def test_get_db_version__corrupted_db(tmp_path: Path): with sqlite3.connect(db_path) as connection: assert not table_exists(connection.cursor(), 'MiscData') - - -def test_mark_db_as_corrupted_file_does_not_exist(tmp_path: Path): - # The database file apparently was corrupted, and `_mark_db_as_corrupted(db_path)` is called to mark it as such. - # But the function was not able to find the database file at the specified path. In this unnormal situation, - # it raises an exception. - - db_path = tmp_path / 'doesnotexist.db' - with pytest.raises(RuntimeError, match='^Corrupted database file not found: .*doesnotexist.db$'): - _mark_db_as_corrupted(db_path) diff --git a/src/tribler/gui/upgrade_manager.py b/src/tribler/gui/upgrade_manager.py index 769d7878748..fea476870ac 100644 --- a/src/tribler/gui/upgrade_manager.py +++ b/src/tribler/gui/upgrade_manager.py @@ -11,7 +11,7 @@ from tribler.core.config.tribler_config import TriblerConfig from tribler.core.upgrade.upgrade import TriblerUpgrader from tribler.core.upgrade.version_manager import TriblerVersion, VersionHistory, NoDiskSpaceAvailableError -from tribler.core.utilities.pony_utils import DatabaseIsCorrupted +from tribler.core.utilities.db_corruption_handling.base import DatabaseIsCorrupted from tribler.gui.defs import BUTTON_TYPE_NORMAL, CORRUPTED_DB_WAS_FIXED_MESSAGE, NO_DISK_SPACE_ERROR_MESSAGE, \ UPGRADE_CANCELLED_ERROR_TITLE from tribler.gui.dialogs.confirmationdialog import ConfirmationDialog