
Commit

Merge pull request #7150 from drew2a/fix/run_threaded
Extract the `run_threaded` function to `pony_utils.py`
drew2a committed Nov 8, 2022
2 parents 3c77811 + 2d79443, commit 29292a4
Showing 12 changed files with 118 additions and 101 deletions.
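
This merge extracts `run_threaded` (together with its companion `disconnect_thread`) out of `MetadataStore` in `store.py` and into the shared `tribler.core.utilities.pony_utils` module, so that blocking Pony ORM work can be pushed to a worker thread for any `orm.Database` instance, not only the metadata store. Call sites change from `await mds.run_threaded(func, ...)` to `await run_threaded(mds.db, func, ...)`, and the previously protected `MetadataStore._db` attribute becomes the public `MetadataStore.db`. The new `pony_utils.py` itself is collapsed in this view; the following is a plausible sketch of the extracted helper, reconstructed from the two method bodies removed from `store.py` below, not the verbatim new code:

# Sketch only: reconstructed from the removed MetadataStore.run_threaded
# and MetadataStore.disconnect_thread shown in the store.py diff below.
import threading
from asyncio import get_event_loop

from pony import orm


async def run_threaded(db: orm.Database, func, *args, **kwargs):
    def wrapper():
        try:
            return func(*args, **kwargs)
        finally:
            # Ugly workaround for closing threadpool connections: each worker
            # thread opens its own SQLite connection, which must be dropped
            # when the task finishes (formerly disconnect_thread).
            is_main = isinstance(threading.current_thread(), threading._MainThread)  # pylint: disable=protected-access
            if not is_main:
                db.disconnect()

    return await get_event_loop().run_in_executor(None, wrapper)

Taking the database as an explicit argument is the point of the refactoring: the helper no longer needs a `MetadataStore` instance, which is what lets `process_rpc_query` below reuse it with `self.knowledge_db.instance` as the target database.
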
scripts/seedbox/disseminator.py (2 changes: 1 addition & 1 deletion)

@@ -159,7 +159,7 @@ def commit(self):
     def flush(self):
         _logger.debug('Flush')
 
-        self.community.mds._db.flush()  # pylint: disable=protected-access
+        self.community.mds.db.flush()
 
 
 class Service(TinyTriblerService):

@@ -14,6 +14,7 @@
 from tribler.core.components.metadata_store.db.serialization import CHANNEL_TORRENT
 from tribler.core.components.metadata_store.db.store import MetadataStore
 from tribler.core.utilities.notifier import Notifier
+from tribler.core.utilities.pony_utils import run_threaded
 from tribler.core.utilities.simpledefs import DLSTATUS_SEEDING, NTFY
 from tribler.core.utilities.unicode import hexlify

@@ -283,7 +284,7 @@ def _process_download():
             mds.process_channel_dir(channel_dirname, channel.public_key, channel.id_, external_thread=True)
 
         try:
-            await mds.run_threaded(_process_download)
+            await run_threaded(mds.db, _process_download)
         except Exception as e:  # pylint: disable=broad-except  # pragma: no cover
             self._logger.error("Error when processing channel dir download: %s", e)

src/tribler/core/components/metadata_store/db/store.py (85 changes: 33 additions & 52 deletions)

@@ -1,13 +1,10 @@
 import logging
 import re
-import threading
-from asyncio import get_event_loop
 from datetime import datetime, timedelta
 from time import sleep, time
 from typing import Optional, Union
 
 from lz4.frame import LZ4FrameDecompressor
-
 from pony import orm
 from pony.orm import db_session, desc, left_join, raw_sql, select
 from pony.orm.dbproviders.sqlite import keep_exception

@@ -50,12 +47,11 @@
 from tribler.core.exceptions import InvalidSignatureException
 from tribler.core.utilities.notifier import Notifier
 from tribler.core.utilities.path_util import Path
-from tribler.core.utilities.pony_utils import get_max, get_or_create
+from tribler.core.utilities.pony_utils import get_max, get_or_create, run_threaded
 from tribler.core.utilities.search_utils import torrent_rank
 from tribler.core.utilities.unicode import hexlify
 from tribler.core.utilities.utilities import MEMORY_DB
 
-
 BETA_DB_VERSIONS = [0, 1, 2, 3, 4, 5]
 CURRENT_DB_VERSION = 14

@@ -164,12 +160,12 @@ def __init__(
         # We have to dynamically define/init ORM-managed entities here to be able to support
         # multiple sessions in Tribler. ORM-managed classes are bound to the database instance
         # at definition.
-        self._db = orm.Database()
+        self.db = orm.Database()
 
         # This attribute is internally called by Pony on startup, though pylint cannot detect it
         # with the static analysis.
         # pylint: disable=unused-variable
-        @self._db.on_connect(provider='sqlite')
+        @self.db.on_connect(provider='sqlite')
         def on_connect(_, connection):
             cursor = connection.cursor()
             cursor.execute("PRAGMA journal_mode = WAL")

@@ -189,31 +185,31 @@ def on_connect(_, connection):
 
         # pylint: enable=unused-variable
 
-        self.MiscData = misc.define_binding(self._db)
+        self.MiscData = misc.define_binding(self.db)
 
-        self.TrackerState = tracker_state.define_binding(self._db)
-        self.TorrentState = torrent_state.define_binding(self._db)
+        self.TrackerState = tracker_state.define_binding(self.db)
+        self.TorrentState = torrent_state.define_binding(self.db)
 
-        self.ChannelNode = channel_node.define_binding(self._db, logger=self._logger, key=my_key)
+        self.ChannelNode = channel_node.define_binding(self.db, logger=self._logger, key=my_key)
 
-        self.MetadataNode = metadata_node.define_binding(self._db)
-        self.CollectionNode = collection_node.define_binding(self._db)
+        self.MetadataNode = metadata_node.define_binding(self.db)
+        self.CollectionNode = collection_node.define_binding(self.db)
         self.TorrentMetadata = torrent_metadata.define_binding(
-            self._db,
+            self.db,
             notifier=notifier,
             tag_processor_version=tag_processor_version
         )
-        self.ChannelMetadata = channel_metadata.define_binding(self._db)
+        self.ChannelMetadata = channel_metadata.define_binding(self.db)
 
-        self.JsonNode = json_node.define_binding(self._db, db_version)
-        self.ChannelDescription = channel_description.define_binding(self._db)
+        self.JsonNode = json_node.define_binding(self.db, db_version)
+        self.ChannelDescription = channel_description.define_binding(self.db)
 
-        self.BinaryNode = binary_node.define_binding(self._db, db_version)
-        self.ChannelThumbnail = channel_thumbnail.define_binding(self._db)
+        self.BinaryNode = binary_node.define_binding(self.db, db_version)
+        self.ChannelThumbnail = channel_thumbnail.define_binding(self.db)
 
-        self.ChannelVote = channel_vote.define_binding(self._db)
-        self.ChannelPeer = channel_peer.define_binding(self._db)
-        self.Vsids = vsids.define_binding(self._db)
+        self.ChannelVote = channel_vote.define_binding(self.db)
+        self.ChannelPeer = channel_peer.define_binding(self.db)
+        self.Vsids = vsids.define_binding(self.db)
 
         self.ChannelMetadata._channels_dir = channels_dir  # pylint: disable=protected-access

@@ -224,13 +220,13 @@
         create_db = not db_filename.is_file()
         db_path_string = str(db_filename)
 
-        self._db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0)
-        self._db.generate_mapping(
+        self.db.bind(provider='sqlite', filename=db_path_string, create_db=create_db, timeout=120.0)
+        self.db.generate_mapping(
             create_tables=create_db, check_tables=check_tables
         )  # Must be run out of session scope
         if create_db:
             with db_session(ddl=True):
-                self._db.execute(sql_create_fts_table)
+                self.db.execute(sql_create_fts_table)
                 self.create_fts_triggers()
                 self.create_torrentstate_triggers()
                 self.create_partial_indexes()

@@ -245,15 +241,6 @@
             default_vsids = self.Vsids.create_default_vsids()
         self.ChannelMetadata.votes_scaling = default_vsids.max_val
 
-    async def run_threaded(self, func, *args, **kwargs):
-        def wrapper():
-            try:
-                return func(*args, **kwargs)
-            finally:
-                self.disconnect_thread()
-
-        return await get_event_loop().run_in_executor(None, wrapper)
-
     def set_value(self, key: str, value: str):
         key_value = get_or_create(self.MiscData, name=key)
         key_value.value = value

@@ -263,14 +250,14 @@ def get_value(self, key: str, default: Optional[str] = None) -> Optional[str]:
         return data.value if data else default
 
     def drop_indexes(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute("select name from sqlite_master where type='index' and name like 'idx_%'")
         for [index_name] in cursor.fetchall():
             cursor.execute(f"drop index {index_name}")
 
     def get_objects_to_create(self):
-        connection = self._db.get_connection()
-        schema = self._db.schema
+        connection = self.db.get_connection()
+        schema = self.db.schema
         provider = schema.provider
         created_tables = set()
         result = []

@@ -284,28 +271,28 @@ def get_db_file_size(self):
         return 0 if self.db_path is MEMORY_DB else Path(self.db_path).size()
 
     def drop_fts_triggers(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute("select name from sqlite_master where type='trigger' and name like 'fts_%'")
         for [trigger_name] in cursor.fetchall():
             cursor.execute(f"drop trigger {trigger_name}")
 
     def create_fts_triggers(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute(sql_add_fts_trigger_insert)
         cursor.execute(sql_add_fts_trigger_delete)
         cursor.execute(sql_add_fts_trigger_update)
 
     def fill_fts_index(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute("insert into FtsIndex(rowid, title) select rowid, title from ChannelNode")
 
     def create_torrentstate_triggers(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute(sql_add_torrentstate_trigger_after_insert)
         cursor.execute(sql_add_torrentstate_trigger_after_update)
 
     def create_partial_indexes(self):
-        cursor = self._db.get_connection().cursor()
+        cursor = self.db.get_connection().cursor()
         cursor.execute(sql_create_partial_index_channelnode_subscribed)
         cursor.execute(sql_create_partial_index_channelnode_metadata_type)


@@ -332,13 +319,7 @@ def vote_bump(self, public_key, id_, voter_pk):
 
     def shutdown(self):
         self._shutting_down = True
-        self._db.disconnect()
-
-    def disconnect_thread(self):
-        # Ugly workaround for closing threadpool connections
-        # Remark: maybe subclass ThreadPoolExecutor to handle this automatically?
-        if not isinstance(threading.current_thread(), threading._MainThread):  # pylint: disable=W0212
-            self._db.disconnect()
+        self.db.disconnect()
 
     @staticmethod
     def get_list_of_channel_blobs_to_process(dirname, start_timestamp):

@@ -467,7 +448,7 @@ def process_mdblob_file(self, filepath, **kwargs):
 
     async def process_compressed_mdblob_threaded(self, compressed_data, **kwargs):
         try:
-            return await self.run_threaded(self.process_compressed_mdblob, compressed_data, **kwargs)
+            return await run_threaded(self.db, self.process_compressed_mdblob, compressed_data, **kwargs)
         except Exception as e:  # pylint: disable=broad-except  # pragma: no cover
             self._logger.warning("DB transaction error when tried to process compressed mdblob: %s", str(e))
             return None

@@ -787,7 +768,7 @@ def get_entries_query(
         return pony_query
 
     async def get_entries_threaded(self, **kwargs):
-        return await self.run_threaded(self.get_entries, **kwargs)
+        return await run_threaded(self.db, self.get_entries, **kwargs)
 
     @db_session
     def get_entries(self, first=1, last=None, **kwargs):

@@ -838,7 +819,7 @@ def get_auto_complete_terms(self, text, max_terms, limit=10):
         suggestion_re = re.compile(suggestion_pattern, re.UNICODE)
 
         with db_session:
-            titles = self._db.select("""
+            titles = self.db.select("""
                 cn.title
                 FROM ChannelNode cn
                 INNER JOIN FtsIndex ON cn.rowid = FtsIndex.rowid

src/tribler/core/components/metadata_store/db/tests/test_store.py (25 changes: 12 additions & 13 deletions)

@@ -6,12 +6,10 @@
 from datetime import datetime
 from unittest.mock import patch
 
+import pytest
 from ipv8.keyvault.crypto import default_eccrypto
-
 from pony.orm import db_session
 
-import pytest
-
 from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import (
     CHANNEL_DIR_NAME_LENGTH,
     entries_to_chunk,

@@ -29,8 +27,10 @@
 from tribler.core.components.metadata_store.tests.test_channel_download import CHANNEL_METADATA_UPDATED
 from tribler.core.tests.tools.common import TESTS_DATA_DIR
 from tribler.core.utilities.path_util import Path
+from tribler.core.utilities.pony_utils import run_threaded
 from tribler.core.utilities.utilities import random_infohash
 
+
 # pylint: disable=protected-access,unused-argument


@@ -269,13 +269,12 @@ def test_process_forbidden_payload(metadata_store):
 def test_process_payload(metadata_store):
     sender_key = default_eccrypto.generate_key("curve25519")
     for md_class in (
-        metadata_store.ChannelMetadata,
-        metadata_store.TorrentMetadata,
-        metadata_store.CollectionNode,
-        metadata_store.ChannelDescription,
-        metadata_store.ChannelThumbnail,
+            metadata_store.ChannelMetadata,
+            metadata_store.TorrentMetadata,
+            metadata_store.CollectionNode,
+            metadata_store.ChannelDescription,
+            metadata_store.ChannelThumbnail,
     ):
-
         node, node_payload, node_deleted_payload = get_payloads(md_class, sender_key)
         node_dict = node.to_dict()
         node.delete()

@@ -333,8 +332,8 @@ def test_process_payload_with_known_channel_public_key(metadata_store):
 
     # Check accepting a payload with matching public key
     assert (
-        metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state
-        == ObjState.NEW_OBJECT
+            metadata_store.process_payload(payload, channel_public_key=key1.pub().key_to_bin()[10:])[0].obj_state
+            == ObjState.NEW_OBJECT
     )
     assert metadata_store.TorrentMetadata.get()

@@ -465,8 +464,8 @@ def f1(a, b, *, c, d):
             return threading.get_ident()
         raise ThreadedTestException('test exception')
 
-    result = await metadata_store.run_threaded(f1, 1, 2, c=3, d=4)
+    result = await run_threaded(metadata_store.db, f1, 1, 2, c=3, d=4)
     assert result != thread_id
 
     with pytest.raises(ThreadedTestException, match='^test exception$'):
-        await metadata_store.run_threaded(f1, 1, 2, c=5, d=6)
+        await run_threaded(metadata_store.db, f1, 1, 2, c=5, d=6)
@@ -23,6 +23,7 @@
 from tribler.core.components.metadata_store.utils import RequestTimeoutException
 from tribler.core.components.knowledge.community.knowledge_validator import is_valid_resource
 from tribler.core.components.knowledge.db.knowledge_db import ResourceType
+from tribler.core.utilities.pony_utils import run_threaded
 from tribler.core.utilities.unicode import hexlify
 
 BINARY_FIELDS = ("infohash", "channel_pk")

@@ -213,12 +214,13 @@ async def process_rpc_query(self, sanitized_parameters: Dict[str, Any]) -> List:
         :raises ValueError: if no JSON could be decoded.
         :raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
         """
+        # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
+        tags = sanitized_parameters.pop('tags', None)
         if self.knowledge_db:
-            # tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
-            tags = sanitized_parameters.pop('tags', None)
-
-            infohash_set = await self.mds.run_threaded(self.search_for_tags, tags)
-            if infohash_set:
-                sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}
+            infohash_set = await run_threaded(self.knowledge_db.instance, self.search_for_tags, tags)
+            if infohash_set:
+                sanitized_parameters['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}
 
         return await self.mds.get_entries_threaded(**sanitized_parameters)

@@ -14,6 +14,7 @@
 from tribler.core.components.metadata_store.restapi.metadata_schema import MetadataParameters, MetadataSchema
 from tribler.core.components.restapi.rest.rest_endpoint import HTTP_BAD_REQUEST, RESTResponse
 from tribler.core.components.knowledge.db.knowledge_db import ResourceType
+from tribler.core.utilities.pony_utils import run_threaded
 from tribler.core.utilities.utilities import froze_it
 
 SNIPPETS_TO_SHOW = 3  # The number of snippets we return from the search results

@@ -151,7 +152,7 @@ def search_db():
             if infohash_set:
                 sanitized['infohash_set'] = {bytes.fromhex(s) for s in infohash_set}
 
-            search_results, total, max_rowid = await mds.run_threaded(search_db)
+            search_results, total, max_rowid = await run_threaded(mds.db, search_db)
         except Exception as e:  # pylint: disable=broad-except; # pragma: no cover
             self._logger.exception("Error while performing DB search: %s: %s", type(e).__name__, e)
             return RESTResponse(status=HTTP_BAD_REQUEST)

@@ -5,14 +5,11 @@
 from pathlib import Path
 from unittest.mock import Mock, patch
 
+import pytest
 from ipv8.keyvault.crypto import default_eccrypto
-
 from lz4.frame import LZ4FrameDecompressor
-
 from pony.orm import ObjectNotFound, db_session
 
-import pytest
-
 from tribler.core.components.libtorrent.torrentdef import TorrentDef
 from tribler.core.components.metadata_store.db.orm_bindings.channel_metadata import (
     CHANNEL_DIR_NAME_LENGTH,

@@ -31,6 +28,7 @@
 from tribler.core.utilities.simpledefs import CHANNEL_STATE
 from tribler.core.utilities.utilities import random_infohash
 
+
 # pylint: disable=protected-access


@@ -80,7 +78,7 @@ def mds_with_some_torrents_fixture(metadata_store):
     # torrent6 aaa zzz
 
     def save():
-        metadata_store._db.flush()  # pylint: disable=W0212
+        metadata_store.db.flush()
 
     def new_channel(**kwargs):
         params = dict(subscribed=True, share=True, status=NEW, infohash=random_infohash())