From 8271992aac00d196dd242909bcc97d1425c5b07e Mon Sep 17 00:00:00 2001 From: Ro'e Katz Date: Thu, 5 Sep 2024 13:58:19 +0300 Subject: [PATCH] fixup! Refactor PubSub, Server & Publishers --- .../opal-common/opal_common/async_utils.py | 3 ++- .../opal_server/policy/watcher/task.py | 22 +++++-------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/packages/opal-common/opal_common/async_utils.py b/packages/opal-common/opal_common/async_utils.py index b9714d70..4a8383ea 100644 --- a/packages/opal-common/opal_common/async_utils.py +++ b/packages/opal-common/opal_common/async_utils.py @@ -110,7 +110,8 @@ def add_task(self, f): async def join(self, cancel=False): if cancel: for t in self._tasks: - t.cancel() + if not t.done(): + t.cancel() await asyncio.gather(*self._tasks, return_exceptions=True) self._tasks.clear() diff --git a/packages/opal-server/opal_server/policy/watcher/task.py b/packages/opal-server/opal_server/policy/watcher/task.py index 2cb90b97..0f7d8b39 100644 --- a/packages/opal-server/opal_server/policy/watcher/task.py +++ b/packages/opal-server/opal_server/policy/watcher/task.py @@ -6,6 +6,7 @@ from fastapi_websocket_pubsub import Topic from opal_common.logger import logger from opal_common.sources.base_policy_source import BasePolicySource +from opal_common.async_utils import TasksPool from opal_server.config import opal_server_config from opal_server.pubsub import PubSub @@ -14,20 +15,13 @@ class BasePolicyWatcherTask: """Manages the asyncio tasks of the policy watcher.""" def __init__(self, pubsub: PubSub): - self._tasks: List[asyncio.Task] = [] + self._tasks = TasksPool() self._should_stop: Optional[asyncio.Event] = None self._pubsub = pubsub - self._webhook_tasks: List[asyncio.Task] = [] async def _on_webhook(self, topic: Topic, data: Any): - logger.info(f"Webhook listener triggered ({len(self._webhook_tasks)})") - # TODO: Use TasksPool - for task in self._webhook_tasks: - if task.done(): - # Clean references to finished tasks - self._webhook_tasks.remove(task) - - self._webhook_tasks.append(asyncio.create_task(self.trigger(topic, data))) + logger.info("Webhook listener triggered") + self._tasks.add_task(self.trigger(topic, data)) async def _listen_to_webhook_notifications(self): # Webhook api route can be hit randomly in all workers, so it publishes a message to the webhook topic. @@ -50,10 +44,7 @@ async def start(self): async def stop(self): """stops all policy watcher tasks.""" logger.info("Stopping policy watcher") - for task in self._tasks + self._webhook_tasks: - if not task.done(): - task.cancel() - await asyncio.gather(*self._tasks, return_exceptions=True) + await self._tasks.join() async def trigger(self, topic: Topic, data: Any): """triggers the policy watcher from outside to check for changes (git @@ -64,7 +55,6 @@ async def _fail(self, exc: Exception): """called when the watcher fails, and stops all tasks gracefully.""" logger.error("policy watcher failed with exception: {err}", err=repr(exc)) # trigger uvicorn graceful shutdown - # TODO: Seriously? os.kill(os.getpid(), signal.SIGTERM) @@ -76,7 +66,7 @@ def __init__(self, policy_source: BasePolicySource, *args, **kwargs): async def start(self): await super().start() self._watcher.add_on_failure_callback(self._fail) - self._tasks.append(asyncio.create_task(self._watcher.run())) + self._tasks.add_task(self._watcher.run()) async def stop(self): await self._watcher.stop()