From 31d9bbee4fc21b0f6675fc798e9d4e8ed4bdfe61 Mon Sep 17 00:00:00 2001 From: Jev Date: Sat, 16 Nov 2024 21:39:01 +0000 Subject: [PATCH] add _on_init coro --- examples/node_system.py | 1 - src/roxbot/adapters/mqtt_adapter.py | 3 ++- src/roxbot/node.py | 12 +++++++++--- tests/test_interfaces.py | 6 ------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/node_system.py b/examples/node_system.py index 7d4ebac..637ff92 100755 --- a/examples/node_system.py +++ b/examples/node_system.py @@ -24,7 +24,6 @@ def __init__(self) -> None: super().__init__() # add coroutines to run in main() - self._coros.append(self._on_init) self._coros.append(self.talker_coro) async def _on_init(self) -> None: diff --git a/src/roxbot/adapters/mqtt_adapter.py b/src/roxbot/adapters/mqtt_adapter.py index aece33b..d61f1be 100644 --- a/src/roxbot/adapters/mqtt_adapter.py +++ b/src/roxbot/adapters/mqtt_adapter.py @@ -110,12 +110,13 @@ def publish_nowait(self, topic: str, data: JsonSerializableType) -> None: async def subscribe(self, topic: str) -> None: """subscribe to topic""" + self._log.info(f"Subscribing to {topic}") + await asyncio.wait_for(self._client_ready.wait(), timeout=1) if self._client is None: raise RuntimeError("MQTT client not initialized") - self._log.info(f"Subscribing to {topic}") await self._client.subscribe(topic) async def unsubscribe(self, topic: str) -> None: diff --git a/src/roxbot/node.py b/src/roxbot/node.py index b0102a9..f08f0b3 100644 --- a/src/roxbot/node.py +++ b/src/roxbot/node.py @@ -40,13 +40,14 @@ def __init__(self, name: str | None = None) -> None: self.mqtt = MqttAdapter(parent=self) # list of coroutines to run in main(). Append to this list in __init__ of derived class. Provide as a reference to the coro, not a call. - self._coros: List[Callable] = [ - self.mqtt.main, - ] + self._coros: List[Callable] = [] self._main_started = False self._tasks: List[asyncio.Task] = [] + async def _on_init(self) -> None: + """init coroutine to run in main(), override in derived class""" + async def main(self) -> None: """main coroutine""" self._log.debug("starting main") @@ -56,6 +57,11 @@ async def main(self) -> None: self._main_started = True + # start mqtt + self._tasks.append(asyncio.create_task(self.mqtt.main())) + + await self._on_init() + async with asyncio.TaskGroup() as tg: for coro in self._coros: self._log.info(f"starting {coro}") diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 8d31fee..644dced 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -30,19 +30,15 @@ async def main(self) -> None: def test_data_classes(): - now = time.time() - latlon = interfaces.GpsLatlon(1, 2) assert latlon.lat == 1 assert latlon.lon == 2 assert latlon.gps_qual == 0 - assert abs(latlon.ts - now) < 1 head = interfaces.GpsHeading(1, 2) assert head.heading == 1 assert head.heading_stdev == 2 - assert abs(head.ts - now) < 1 pos = interfaces.PositionData(1, 2) assert pos.lat == 1 @@ -50,10 +46,8 @@ def test_data_classes(): assert pos.x == 0 assert pos.y == 0 assert pos.gps_qual == 0 - assert abs(pos.ts - now) < 1 head = interfaces.HeadingData(1, 2, 3) assert head.heading == 1 assert head.heading_stdev == 2 assert head.theta == 3 - assert abs(head.ts - now) < 1