Skip to content

Commit

Permalink
Remove await statements from with db_session blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Nov 15, 2022
1 parent 2bcac99 commit f33d03a
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from asyncio import Future
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from ipv8.util import succeed

from pony.orm import db_session

import pytest

from tribler.core.components.gigachannel_manager.gigachannel_manager import GigaChannelManager
from tribler.core.components.libtorrent.torrentdef import TorrentDef
from tribler.core.components.metadata_store.db.orm_bindings.channel_node import NEW
Expand Down Expand Up @@ -113,32 +111,40 @@ def test_updated_my_channel(personal_channel, gigachannel_manager, tmpdir):
gigachannel_manager.download_manager.start_download.assert_called_once()


async def test_check_and_regen_personal_channel_torrent(personal_channel, gigachannel_manager):
async def test_check_and_regen_personal_channel_torrent_wait(personal_channel, gigachannel_manager):
with db_session:
chan_pk, chan_id = personal_channel.public_key, personal_channel.id_
chan_download = MagicMock()

async def mock_wait(*_):
pass
# Test wait for status OK
await gigachannel_manager.check_and_regen_personal_channel_torrent(
channel_pk=chan_pk,
channel_id=chan_id,
channel_download=MagicMock(wait_for_status=AsyncMock()),
timeout=0.5
)

chan_download.wait_for_status = mock_wait
# Test wait for status OK
await gigachannel_manager.check_and_regen_personal_channel_torrent(chan_pk, chan_id, chan_download, timeout=0.5)

async def mock_wait_2(*_):
await asyncio.sleep(3)
async def test_check_and_regen_personal_channel_torrent_sleep(personal_channel, gigachannel_manager):
with db_session:
chan_pk, chan_id = personal_channel.public_key, personal_channel.id_

chan_download.wait_for_status = mock_wait_2
# Test timeout waiting for seeding state and then regen
async def mock_wait(*_):
await asyncio.sleep(3)

f = MagicMock()
f = MagicMock()

async def mock_regen(*_):
f()
async def mock_regen(*_):
f()

gigachannel_manager.regenerate_channel_torrent = mock_regen
await gigachannel_manager.check_and_regen_personal_channel_torrent(chan_pk, chan_id, chan_download, timeout=0.5)
f.assert_called_once()
with patch.object(GigaChannelManager, 'regenerate_channel_torrent', mock_regen):
# Test timeout waiting for seeding state and then regen
await gigachannel_manager.check_and_regen_personal_channel_torrent(
channel_pk=chan_pk,
channel_id=chan_id,
channel_download=MagicMock(wait_for_status=mock_wait),
timeout=0.5
)
f.assert_called_once()


async def test_check_channels_updates(personal_channel, gigachannel_manager, metadata_store):
Expand Down Expand Up @@ -220,10 +226,10 @@ def mock_process_channel_dir(c, _):

# Manually fire the channel updates checking routine
gigachannel_manager.check_channels_updates()
await gigachannel_manager.process_queued_channels()
await gigachannel_manager.process_queued_channels()

# The queue should be empty afterwards
assert not gigachannel_manager.channels_processing_queue
# The queue should be empty afterwards
assert not gigachannel_manager.channels_processing_queue


async def test_remove_cruft_channels(torrent_template, personal_channel, gigachannel_manager, metadata_store):
Expand Down Expand Up @@ -319,7 +325,7 @@ def mock_remove(infohash, remove_content=False):


async def test_reject_malformed_channel(
gigachannel_manager, metadata_store
gigachannel_manager, metadata_store
): # pylint: disable=unused-argument, redefined-outer-name
global initiated_download
with db_session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ async def test_get_contents_count(add_fake_torrents_channels, mock_dlmgr, rest_a
mock_dlmgr.get_download = lambda _: None
with db_session:
chan = metadata_store.ChannelMetadata.select().first()
json_dict = await do_request(rest_api, f'channels/{hexlify(chan.public_key)}/123?include_total=1')

json_dict = await do_request(rest_api, f'channels/{hexlify(chan.public_key)}/123?include_total=1')
assert json_dict['total'] == 5


Expand Down Expand Up @@ -311,12 +312,12 @@ async def test_get_channel_contents_by_type(metadata_store, my_channel, mock_dlm
with db_session:
metadata_store.CollectionNode(title='some_folder', origin_id=my_channel.id_)

json_dict = await do_request(
rest_api,
'channels/%s/%d?metadata_type=%d&metadata_type=%d'
% (hexlify(my_channel.public_key), my_channel.id_, COLLECTION_NODE, REGULAR_TORRENT),
expected_code=200,
)
json_dict = await do_request(
rest_api,
'channels/%s/%d?metadata_type=%d&metadata_type=%d'
% (hexlify(my_channel.public_key), my_channel.id_, COLLECTION_NODE, REGULAR_TORRENT),
expected_code=200,
)

assert len(json_dict['results']) == 10
assert 'status' in json_dict['results'][0]
Expand Down Expand Up @@ -431,12 +432,12 @@ async def test_add_torrents_no_channel(metadata_store, my_channel, rest_api):
with db_session:
my_chan = metadata_store.ChannelMetadata.get_my_channels().first()
my_chan.delete()
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
expected_code=404,
)
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
expected_code=404,
)


async def test_add_torrents_no_dir(my_channel, rest_api):
Expand Down Expand Up @@ -516,17 +517,17 @@ async def test_add_torrent_duplicate(my_channel, rest_api):
tdef = TorrentDef.load(TORRENT_UBUNTU_FILE)
my_channel.add_torrent_to_channel(tdef, {'description': 'blabla'})

with open(TORRENT_UBUNTU_FILE, "rb") as torrent_file:
base64_content = base64.b64encode(torrent_file.read()).decode('utf-8')
with open(TORRENT_UBUNTU_FILE, "rb") as torrent_file:
base64_content = base64.b64encode(torrent_file.read()).decode('utf-8')

post_params = {'torrent': base64_content}
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
post_data=post_params,
expected_code=200,
)
post_params = {'torrent': base64_content}
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
post_data=post_params,
expected_code=200,
)


async def test_add_torrent(my_channel, rest_api):
Expand All @@ -537,12 +538,12 @@ async def test_add_torrent(my_channel, rest_api):
base64_content = base64.b64encode(torrent_file.read())

post_params = {'torrent': base64_content.decode('utf-8')}
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
post_data=post_params,
)
await do_request(
rest_api,
f'channels/{hexlify(my_channel.public_key)}/{my_channel.id_}/torrents',
request_type='PUT',
post_data=post_params,
)


async def test_add_torrent_invalid_uri(my_channel, rest_api):
Expand Down Expand Up @@ -696,10 +697,11 @@ async def test_get_channel_thumbnail(rest_api, metadata_store):
)
endpoint = f'channels/{hexlify(chan.public_key)}/{chan.id_}/thumbnail'
url = f'/{endpoint}'
async with rest_api.request("GET", url, ssl=False) as response:
assert response.status == 200
assert await response.read() == PNG_DATA
assert response.headers["Content-Type"] == "image/png"

async with rest_api.request("GET", url, ssl=False) as response:
assert response.status == 200
assert await response.read() == PNG_DATA
assert response.headers["Content-Type"] == "image/png"


async def test_get_my_channel_tags(metadata_store, mock_dlmgr_get_download, my_channel,
Expand Down Expand Up @@ -735,12 +737,12 @@ async def test_get_my_channel_tags_xxx(metadata_store, knowledge_db, mock_dlmgr_
tags = ["totally safe", "wrongterm", "wRonGtErM", "a wrongterm b"]
tag_torrent(infohash, knowledge_db, tags=tags)

json_dict = await do_request(
rest_api,
'channels/%s/%d?metadata_type=%d&hide_xxx=1'
% (hexlify(my_channel.public_key), chan.id_, REGULAR_TORRENT),
expected_code=200,
)
json_dict = await do_request(
rest_api,
'channels/%s/%d?metadata_type=%d&hide_xxx=1'
% (hexlify(my_channel.public_key), chan.id_, REGULAR_TORRENT),
expected_code=200,
)

assert len(json_dict['results']) == 1
print(json_dict)
Expand Down

0 comments on commit f33d03a

Please sign in to comment.