Skip to content

Commit

Permalink
fixup! Refactor PubSub, Server & Publishers
Browse files Browse the repository at this point in the history
  • Loading branch information
roekatz committed Sep 5, 2024
1 parent 859d8b9 commit 8271992
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
3 changes: 2 additions & 1 deletion packages/opal-common/opal_common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def add_task(self, f):
async def join(self, cancel=False):
if cancel:
for t in self._tasks:
t.cancel()
if not t.done():
t.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()

Expand Down
22 changes: 6 additions & 16 deletions packages/opal-server/opal_server/policy/watcher/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi_websocket_pubsub import Topic
from opal_common.logger import logger
from opal_common.sources.base_policy_source import BasePolicySource
from opal_common.async_utils import TasksPool
from opal_server.config import opal_server_config
from opal_server.pubsub import PubSub

Expand All @@ -14,20 +15,13 @@ class BasePolicyWatcherTask:
"""Manages the asyncio tasks of the policy watcher."""

def __init__(self, pubsub: PubSub):
self._tasks: List[asyncio.Task] = []
self._tasks = TasksPool()
self._should_stop: Optional[asyncio.Event] = None
self._pubsub = pubsub
self._webhook_tasks: List[asyncio.Task] = []

async def _on_webhook(self, topic: Topic, data: Any):
logger.info(f"Webhook listener triggered ({len(self._webhook_tasks)})")
# TODO: Use TasksPool
for task in self._webhook_tasks:
if task.done():
# Clean references to finished tasks
self._webhook_tasks.remove(task)

self._webhook_tasks.append(asyncio.create_task(self.trigger(topic, data)))
logger.info("Webhook listener triggered")
self._tasks.add_task(self.trigger(topic, data))

async def _listen_to_webhook_notifications(self):
# Webhook api route can be hit randomly in all workers, so it publishes a message to the webhook topic.
Expand All @@ -50,10 +44,7 @@ async def start(self):
async def stop(self):
"""stops all policy watcher tasks."""
logger.info("Stopping policy watcher")
for task in self._tasks + self._webhook_tasks:
if not task.done():
task.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
await self._tasks.join()

async def trigger(self, topic: Topic, data: Any):
"""triggers the policy watcher from outside to check for changes (git
Expand All @@ -64,7 +55,6 @@ async def _fail(self, exc: Exception):
"""called when the watcher fails, and stops all tasks gracefully."""
logger.error("policy watcher failed with exception: {err}", err=repr(exc))
# trigger uvicorn graceful shutdown
# TODO: Seriously?
os.kill(os.getpid(), signal.SIGTERM)


Expand All @@ -76,7 +66,7 @@ def __init__(self, policy_source: BasePolicySource, *args, **kwargs):
async def start(self):
await super().start()
self._watcher.add_on_failure_callback(self._fail)
self._tasks.append(asyncio.create_task(self._watcher.run()))
self._tasks.add_task(self._watcher.run())

async def stop(self):
await self._watcher.stop()
Expand Down

0 comments on commit 8271992

Please sign in to comment.