Skip to content

Commit

Permalink
Fix asyncio.create_task() calls
Browse files Browse the repository at this point in the history
  • Loading branch information
drew2a committed Feb 27, 2023
1 parent 77d44a1 commit 0b5c04d
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 34 deletions.
5 changes: 3 additions & 2 deletions scripts/experiments/tunnel_community/hidden_peer_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs):
self.register_task('_graceful_shutdown', self._graceful_shutdown, delay=EXPERIMENT_RUN_TIME)

def _graceful_shutdown(self):
task = asyncio.create_task(self.on_tribler_shutdown())
task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))
tasks = self.async_group.add(self.on_tribler_shutdown())
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))

async def on_tribler_shutdown(self):
await self.shutdown_task_manager()
Expand Down
5 changes: 3 additions & 2 deletions scripts/experiments/tunnel_community/speed_test_exit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(self, *args, **kwargs):
self.output_file = 'speed_test_exit.txt'

def _graceful_shutdown(self):
task = asyncio.create_task(self.on_tribler_shutdown())
task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))
tasks = self.async_group.add(self.on_tribler_shutdown())
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: TinyTriblerService._graceful_shutdown(self))

async def on_tribler_shutdown(self):
await self.shutdown_task_manager()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from asyncio import create_task
from binascii import unhexlify

from aiohttp import ContentTypeError, web
Expand Down Expand Up @@ -225,5 +224,6 @@ async def get_torrent_health(self, request):
return RESTResponse({"error": f"Error processing timeout parameter: {e}"}, status=HTTP_BAD_REQUEST)

infohash = unhexlify(request.match_info['infohash'])
create_task(self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True))
check_coro = self.torrent_checker.check_torrent_health(infohash, timeout=timeout, scrape_now=True)
self.async_group.add(check_coro)
return RESTResponse({'checking': '1'})
20 changes: 8 additions & 12 deletions src/tribler/core/components/restapi/rest/events_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import time
from asyncio import CancelledError
Expand All @@ -8,7 +9,6 @@
from aiohttp_apispec import docs
from ipv8.REST.schema import schema
from ipv8.messaging.anonymization.tunnel import Circuit
from ipv8.taskmanager import TaskManager, task
from marshmallow.fields import Dict, String

from tribler.core import notifications
Expand Down Expand Up @@ -39,16 +39,15 @@ def passthrough(x):


@froze_it
class EventsEndpoint(RESTEndpoint, TaskManager):
class EventsEndpoint(RESTEndpoint):
"""
Important events in Tribler are returned over the events endpoint. This connection is held open. Each event is
pushed over this endpoint in the form of a JSON dictionary. Each JSON dictionary contains a type field that
indicates the type of the event. Individual events are separated by a newline character.
"""

def __init__(self, notifier: Notifier, public_key: str = None):
RESTEndpoint.__init__(self)
TaskManager.__init__(self)
super().__init__()
self.events_responses: List[RESTStreamResponse] = []
self.app.on_shutdown.append(self.on_shutdown)
self.undelivered_error: Optional[dict] = None
Expand All @@ -59,7 +58,8 @@ def __init__(self, notifier: Notifier, public_key: str = None):

def on_notification(self, topic, *args, **kwargs):
if topic in topics_to_send_to_gui:
self.write_data({"topic": topic.__name__, "args": args, "kwargs": kwargs})
data = {"topic": topic.__name__, "args": args, "kwargs": kwargs}
self.async_group.add(self.write_data(data))

def on_circuit_removed(self, circuit: Circuit, additional_info: str):
# The original notification contains non-JSON-serializable argument, so we send another one to GUI
Expand All @@ -69,10 +69,7 @@ def on_circuit_removed(self, circuit: Circuit, additional_info: str):
additional_info=additional_info)

async def on_shutdown(self, _):
await self.shutdown_task_manager()

async def shutdown(self):
await self.shutdown_task_manager()
await self.shutdown()

def setup_routes(self):
self.app.add_routes([web.get('', self.get_events)])
Expand Down Expand Up @@ -101,7 +98,6 @@ def encode_message(self, message: dict) -> bytes:
def has_connection_to_gui(self):
return bool(self.events_responses)

@task
async def write_data(self, message):
"""
Write data over the event socket if it's open.
Expand All @@ -124,7 +120,7 @@ async def write_data(self, message):
def on_tribler_exception(self, reported_error: ReportedError):
message = self.error_message(reported_error)
if self.has_connection_to_gui():
self.write_data(message)
self.async_group.add(self.write_data(message))
elif not self.undelivered_error:
# If there are several undelivered errors, we store the first error as more important and skip other
self.undelivered_error = message
Expand Down Expand Up @@ -170,7 +166,7 @@ async def get_events(self, request):

try:
while True:
await self.register_anonymous_task('event_sleep', lambda: None, delay=3600)
await asyncio.sleep(3600)
except CancelledError:
self.events_responses.remove(response)
return response
29 changes: 27 additions & 2 deletions src/tribler/core/components/restapi/rest/rest_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from __future__ import annotations

import json
import logging
from typing import Dict, TYPE_CHECKING

from aiohttp import web

from tribler.core.utilities.async_group import AsyncGroup

if TYPE_CHECKING:
from tribler.core.components.restapi.rest.events_endpoint import EventsEndpoint
from ipv8.REST.root_endpoint import RootEndpoint as IPV8RootEndpoint

HTTP_BAD_REQUEST = 400
HTTP_UNAUTHORIZED = 401
HTTP_NOT_FOUND = 404
Expand All @@ -14,16 +23,32 @@ class RESTEndpoint:
def __init__(self, middlewares=()):
self._logger = logging.getLogger(self.__class__.__name__)
self.app = web.Application(middlewares=middlewares, client_max_size=2 * 1024 ** 2)
self.endpoints = {}
self.endpoints: Dict[str, RESTEndpoint] = {}
self.async_group = AsyncGroup()
self.setup_routes()

self._shutdown = False

def setup_routes(self):
pass

def add_endpoint(self, prefix, endpoint):
def add_endpoint(self, prefix: str, endpoint: RESTEndpoint | EventsEndpoint | IPV8RootEndpoint):
self.endpoints[prefix] = endpoint
self.app.add_subapp(prefix, endpoint.app)

async def shutdown(self):
if self._shutdown:
return
self._shutdown = True

shutdown_group = AsyncGroup()
for endpoint in self.endpoints.values():
if isinstance(endpoint, RESTEndpoint):
shutdown_group.add(endpoint.shutdown()) # IPV8RootEndpoint doesn't have a shutdown method

await shutdown_group.wait()
await self.async_group.cancel()


class RESTResponse(web.Response):

Expand Down
4 changes: 2 additions & 2 deletions src/tribler/core/components/restapi/rest/shutdown_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback

def setup_routes(self):
self.app.add_routes([web.put('', self.shutdown)])
self.app.add_routes([web.put('', self.shutdown_request)])

@docs(
tags=["General"],
Expand All @@ -31,7 +31,7 @@ def setup_routes(self):
}
}
)
async def shutdown(self, request):
async def shutdown_request(self, _):
self._logger.info('Received a shutdown request from GUI')
self.shutdown_callback()
return RESTResponse({"shutdown": True})
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def test_on_tribler_exception_stores_only_first_error(endpoint, reported_e
assert endpoint.undelivered_error == endpoint.error_message(first_reported_error)


@patch.object(EventsEndpoint, 'register_anonymous_task', new=AsyncMock(side_effect=CancelledError))
@patch('asyncio.sleep', new=AsyncMock(side_effect=CancelledError))
@patch.object(RESTStreamResponse, 'prepare', new=AsyncMock())
@patch.object(RESTStreamResponse, 'write', new_callable=AsyncMock)
@patch.object(EventsEndpoint, 'encode_message')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, patch

from tribler.core.components.restapi.rest.rest_endpoint import RESTEndpoint
from tribler.core.utilities.async_group import AsyncGroup


# pylint: disable=protected-access

async def test_shutdown():
# In this test we check that all coros related to the Root Endpoint are cancelled
# during shutdown

async def coro():
...

root_endpoint = RESTEndpoint()
root_endpoint.async_group.add(coro())

# add 2 child endpoints with a single coro in each:
child_endpoints = [RESTEndpoint(), RESTEndpoint()]
for index, child_endpoint in enumerate(child_endpoints):
root_endpoint.add_endpoint(prefix=f'/{index}', endpoint=child_endpoint)
child_endpoint.async_group.add(coro())

def total_coro_count():
count = 0
for endpoint in child_endpoints + [root_endpoint]:
count += len(endpoint.async_group._futures)
return count

assert total_coro_count() == 3

await root_endpoint.shutdown()

assert total_coro_count() == 0


@patch.object(AsyncGroup, 'cancel', new_callable=AsyncMock)
async def test_multiple_shutdown_calls(async_group_cancel: AsyncMock):
# Test that if shutdown calls twice, only one call is processed
endpoint = RESTEndpoint()

await endpoint.shutdown()
await endpoint.shutdown()

async_group_cancel.assert_called_once()
4 changes: 2 additions & 2 deletions src/tribler/core/components/restapi/restapi_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def report_callback(reported_error: ReportedError):
async def shutdown(self):
await super().shutdown()

if self._events_endpoint:
await self._events_endpoint.shutdown()
if self.root_endpoint:
await self.root_endpoint.shutdown()

if self._core_exception_handler:
self._core_exception_handler.report_callback = None
Expand Down
5 changes: 4 additions & 1 deletion src/tribler/core/components/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tribler.core.components.component import Component, ComponentError, ComponentStartupException, \
MultipleComponentsFound
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.async_group import AsyncGroup
from tribler.core.utilities.crypto_patcher import patch_crypto_be_discovery
from tribler.core.utilities.install_dir import get_lib_path
from tribler.core.utilities.network_utils import default_network_utils
Expand All @@ -33,6 +34,7 @@ def __init__(self, config: TriblerConfig = None, components: List[Component] = (
self.config: TriblerConfig = config or TriblerConfig()
self.shutdown_event: Event = shutdown_event or Event()
self.notifier: Notifier = notifier or Notifier(loop=get_event_loop())
self.async_group = AsyncGroup()
self.components: Dict[Type[Component], Component] = {}
for component in components:
self.register(component.__class__, component)
Expand Down Expand Up @@ -104,7 +106,7 @@ async def exception_reraiser():
self.logger.info(f'Reraise startup exception: {self._startup_exception}')
raise self._startup_exception

get_event_loop().create_task(exception_reraiser())
self.async_group.add(exception_reraiser())

def set_startup_exception(self, exc: Exception):
if not self._startup_exception:
Expand All @@ -113,6 +115,7 @@ def set_startup_exception(self, exc: Exception):
async def shutdown(self):
self.logger.info("Stopping components")
await gather(*[create_task(component.stop()) for component in self.components.values()])
await self.async_group.cancel()
self.logger.info("All components are stopped")


Expand Down
8 changes: 6 additions & 2 deletions src/tribler/core/utilities/async_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from asyncio import CancelledError, Future
from asyncio import CancelledError, Future, Task
from contextlib import suppress
from typing import Iterable, List, Set

Expand All @@ -24,13 +24,17 @@ class AsyncGroup:
def __init__(self):
self._futures: Set[Future] = set()

def add(self, *coroutines):
def add(self, *coroutines) -> List[Task]:
"""Add a coroutine to the group.
"""
result = []
for coroutine in coroutines:
task = asyncio.create_task(coroutine)
self._futures.add(task)
task.add_done_callback(self._done_callback)
result.append(task)

return result

async def wait(self):
""" Wait for completion of all futures
Expand Down
6 changes: 4 additions & 2 deletions src/tribler/core/utilities/tests/test_async_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,23 @@ async def raise_exception():


async def test_add_single_coro(group: AsyncGroup):
group.add(
tasks = group.add(
void()
)

assert len(group._futures) == 1
assert len(tasks) == 1


async def test_add_iterable(group: AsyncGroup):
group.add(
tasks = group.add(
void(),
void(),
void()
)

assert len(group._futures) == 3
assert len(tasks) == 3


async def test_cancel(group: AsyncGroup):
Expand Down
14 changes: 10 additions & 4 deletions src/tribler/core/utilities/tiny_tribler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tribler.core.components.component import Component
from tribler.core.components.session import Session
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.async_group import AsyncGroup
from tribler.core.utilities.osutils import get_root_state_directory
from tribler.core.utilities.process_manager import ProcessKind, ProcessManager, TriblerProcess, \
set_global_process_manager
Expand All @@ -27,6 +28,8 @@ def __init__(self, components: List[Component], timeout_in_sec=None, state_dir=P
self.config = TriblerConfig(state_dir=state_dir.absolute())
self.timeout_in_sec = timeout_in_sec
self.components = components
self.async_group = AsyncGroup()
self._main_task = None

async def on_tribler_started(self):
"""Function will calls after the Tribler session is started
Expand All @@ -42,7 +45,7 @@ async def start_tribler():
await self._start_session()

if self.timeout_in_sec:
asyncio.create_task(self._terminate_by_timeout())
self.async_group.add(self._terminate_by_timeout())

self._enable_graceful_shutdown()
await self.on_tribler_started()
Expand All @@ -51,7 +54,9 @@ async def start_tribler():
if fragile:
make_async_loop_fragile(loop)

loop.create_task(start_tribler())
# the variable `self._main_task` is used here to prevent a naked `loop.create_task()` call
# more details: https://github.com/Tribler/tribler/issues/7299
self._main_task = loop.create_task(start_tribler())
try:
loop.run_forever()
finally:
Expand Down Expand Up @@ -97,8 +102,9 @@ async def _terminate_by_timeout(self):

def _graceful_shutdown(self):
self.logger.info("Shutdown gracefully")
task = asyncio.create_task(self.session.shutdown())
task.add_done_callback(lambda result: self._stop_event_loop())
tasks = self.async_group.add(self.session.shutdown())
shutdown_task = tasks[0]
shutdown_task.add_done_callback(lambda result: self._stop_event_loop())

def _stop_event_loop(self):
asyncio.get_running_loop().stop()
Expand Down

0 comments on commit 0b5c04d

Please sign in to comment.