From 8f0b52a3af730a4143a9ac5d0e1a8d64167a0270 Mon Sep 17 00:00:00 2001 From: Guillaume Gauvrit Date: Sat, 8 Feb 2025 16:46:41 +0100 Subject: [PATCH] Refactor to isolate dependencies per transaction --- src/messagebus/domain/model/message.py | 5 +- src/messagebus/service/_async/dependency.py | 39 ++++++++++++ src/messagebus/service/_async/registry.py | 61 +++++++++++-------- src/messagebus/service/_async/unit_of_work.py | 23 ++++++- src/messagebus/service/_sync/dependency.py | 39 ++++++++++++ src/messagebus/service/_sync/registry.py | 59 ++++++++++-------- src/messagebus/service/_sync/unit_of_work.py | 23 ++++++- tests/_async/conftest.py | 7 ++- tests/_async/test_registry.py | 21 ++++--- tests/_sync/conftest.py | 7 ++- tests/_sync/test_registry.py | 21 ++++--- 11 files changed, 226 insertions(+), 79 deletions(-) create mode 100644 src/messagebus/service/_async/dependency.py create mode 100644 src/messagebus/service/_sync/dependency.py diff --git a/src/messagebus/domain/model/message.py b/src/messagebus/domain/model/message.py index 66bcb4e..0818e3e 100644 --- a/src/messagebus/domain/model/message.py +++ b/src/messagebus/domain/model/message.py @@ -6,7 +6,7 @@ """ from datetime import datetime -from typing import Any, Generic +from typing import Any, Generic, TypeVar from uuid import UUID from lastuuid import uuid7 @@ -69,3 +69,6 @@ class GenericEvent(Message[TMetadata]): """Command that use the default metadata.""" Event = GenericEvent[Metadata] """Event that use the default metadata.""" + + +TMessage = TypeVar("TMessage", bound=Message[Any]) diff --git a/src/messagebus/service/_async/dependency.py b/src/messagebus/service/_async/dependency.py new file mode 100644 index 0000000..ee68d4c --- /dev/null +++ b/src/messagebus/service/_async/dependency.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping, Sequence +from typing import Any, Generic + +from messagebus.domain.model.message import TMessage +from messagebus.service._async.unit_of_work import TAsyncUow +from messagebus.typing import AsyncMessageHandler, P + + +class AsyncDependency: + """Describe an async dependency""" + + async def on_after_commit(self) -> None: + """Method called when the unit of work transaction is has been commited.""" + + async def on_after_rollback(self) -> None: + """Method called when the unit of work transaction is has been rolled back.""" + + +class AsyncMessageHook(Generic[TMessage, TAsyncUow, P]): + callback: AsyncMessageHandler[TMessage, "TAsyncUow", P] + dependencies: Sequence[str] + + def __init__( + self, + callback: AsyncMessageHandler[TMessage, "TAsyncUow", P], + dependencies: Sequence[str], + ) -> None: + self.callback = callback + self.dependencies = dependencies + + async def __call__( + self, + msg: TMessage, + uow: "TAsyncUow", + dependencies: Mapping[str, AsyncDependency], + ) -> Any: + deps = {k: dependencies[k] for k in self.dependencies} + resp = await self.callback(msg, uow, **deps) # type: ignore + return resp diff --git a/src/messagebus/service/_async/registry.py b/src/messagebus/service/_async/registry.py index 8cf2980..8c33592 100644 --- a/src/messagebus/service/_async/registry.py +++ b/src/messagebus/service/_async/registry.py @@ -7,15 +7,24 @@ import inspect import logging from collections import defaultdict -from functools import partial +from collections.abc import Mapping from typing import Any, Generic, cast import venusian # type: ignore from messagebus.domain.model import GenericCommand, GenericEvent, Message -from messagebus.typing import AsyncMessageHandler, P, TAsyncUow, TMessage - -from .unit_of_work import AsyncUnitOfWorkTransaction, TRepositories +from messagebus.domain.model.message import TMessage +from messagebus.service._async.dependency import ( + AsyncDependency, + AsyncMessageHandler, + AsyncMessageHook, + P, +) +from messagebus.service._async.unit_of_work import ( + AsyncUnitOfWorkTransaction, + TAsyncUow, + TRepositories, +) log = logging.getLogger(__name__) VENUSIAN_CATEGORY = "messagebus" @@ -55,38 +64,38 @@ class AsyncMessageBus(Generic[TRepositories]): def __init__(self, **dependencies: Any) -> None: self.commands_registry: dict[ - type[GenericCommand[Any]], - AsyncMessageHandler[GenericCommand[Any], Any, ...], + type[GenericCommand[Any]], AsyncMessageHook[Any, Any, Any] ] = {} self.events_registry: dict[ - type[GenericEvent[Any]], - list[AsyncMessageHandler[GenericEvent[Any], Any, ...]], + type[GenericEvent[Any]], list[AsyncMessageHook[Any, Any, Any]] ] = defaultdict(list) - self.depencencies = dependencies or {} + self.dependencies = cast( + Mapping[str, type[AsyncDependency]], dependencies or {} + ) def add_listener( self, msg_type: type[Message[Any]], callback: AsyncMessageHandler[Any, Any, P] ) -> None: signature = inspect.signature(callback) - kwargs = {} + dependencies: list[str] = [] for idx, key in enumerate(signature.parameters): if idx >= 2: - if key not in self.depencencies: + if key not in self.dependencies: raise ConfigurationError( f"Missing dependency in message bus: {key} for command " f"type {msg_type.__name__}, listener: {callback.__name__}" ) - kwargs[key] = self.depencencies[key] - if kwargs: - callback = partial(callback, **kwargs) # type: ignore + dependencies.append(key) + + msghook = AsyncMessageHook(callback, dependencies) if issubclass(msg_type, GenericCommand): if msg_type in self.commands_registry: raise ConfigurationError( f"{msg_type} command has been registered twice" ) - self.commands_registry[msg_type] = callback + self.commands_registry[msg_type] = msghook elif issubclass(msg_type, GenericEvent): - self.events_registry[msg_type].append(callback) + self.events_registry[msg_type].append(msghook) else: raise ConfigurationError( f"Invalid usage of the listen decorator: " @@ -101,12 +110,13 @@ def remove_listener( raise ConfigurationError(f"{msg_type} command has not been registered") del self.commands_registry[msg_type] elif issubclass(msg_type, GenericEvent): - try: - self.events_registry[msg_type].remove(callback) - except ValueError as exc: - raise ConfigurationError( - f"{msg_type} event has not been registered" - ) from exc + msg_hooks = [ + v for v in self.events_registry[msg_type] if v.callback == callback + ] + if msg_hooks: + self.events_registry[msg_type].remove(msg_hooks[0]) + else: + raise ConfigurationError(f"{msg_type} event has not been registered") else: raise ConfigurationError( f"Invalid usage of the listen decorator: " @@ -120,6 +130,7 @@ async def handle( Notify listener of that event registered with `messagebus.add_listener`. Return the first event from the command. """ + dependencies = {k: uow.add_listener(v()) for k, v in self.dependencies.items()} queue = [message] idx = 0 ret = None @@ -130,14 +141,14 @@ async def handle( msg_type = type(message) if msg_type in self.commands_registry: cmdret = await self.commands_registry[msg_type]( # type: ignore - cast(GenericCommand[Any], message), uow + cast(GenericCommand[Any], message), uow, dependencies ) if idx == 0: ret = cmdret queue.extend(uow.uow.collect_new_events()) elif msg_type in self.events_registry: - for callback in self.events_registry[msg_type]: # type: ignore - await callback(cast(GenericEvent[Any], message), uow) + for msghook in self.events_registry[msg_type]: # type: ignore + await msghook(cast(GenericEvent[Any], message), uow, dependencies) queue.extend(uow.uow.collect_new_events()) await uow.eventstore.add(message) idx += 1 diff --git a/src/messagebus/service/_async/unit_of_work.py b/src/messagebus/service/_async/unit_of_work.py index b82a80b..3f6c4bb 100644 --- a/src/messagebus/service/_async/unit_of_work.py +++ b/src/messagebus/service/_async/unit_of_work.py @@ -6,9 +6,12 @@ import enum from collections.abc import Iterator from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from messagebus.domain.model import Message + +if TYPE_CHECKING: + from messagebus.service._async.dependency import AsyncDependency from messagebus.service._async.repository import ( AsyncAbstractRepository, AsyncEventstoreAbstractRepository, @@ -39,6 +42,7 @@ class AsyncUnitOfWorkTransaction(Generic[TRepositories]): def __init__(self, uow: AsyncAbstractUnitOfWork[TRepositories]) -> None: self.status = TransactionStatus.running self.uow = uow + self._hooks: list[Any] = [] def __getattr__(self, name: str) -> TRepositories: return getattr(self.uow, name) @@ -47,15 +51,29 @@ def __getattr__(self, name: str) -> TRepositories: def eventstore(self) -> AsyncEventstoreAbstractRepository: return self.uow.eventstore + def add_listener(self, listener: AsyncDependency) -> AsyncDependency: + self._hooks.append(listener) + return listener + + async def _on_after_commit(self) -> None: + for val in self._hooks: + await val.on_after_commit() + + async def _on_after_rollback(self) -> None: + for val in self._hooks: + await val.on_after_rollback() + async def commit(self) -> None: if self.status != TransactionStatus.running: raise TransactionError(f"Transaction already closed ({self.status.value}).") await self.uow.commit() self.status = TransactionStatus.committed + await self._on_after_commit() async def rollback(self) -> None: await self.uow.rollback() self.status = TransactionStatus.rolledback + await self._on_after_rollback() async def __aenter__(self) -> AsyncUnitOfWorkTransaction[TRepositories]: if self.status != TransactionStatus.running: @@ -130,3 +148,6 @@ async def commit(self) -> None: @abc.abstractmethod async def rollback(self) -> None: """Rollback the transation.""" + + +TAsyncUow = TypeVar("TAsyncUow", bound=AsyncAbstractUnitOfWork[Any]) diff --git a/src/messagebus/service/_sync/dependency.py b/src/messagebus/service/_sync/dependency.py new file mode 100644 index 0000000..86b8e70 --- /dev/null +++ b/src/messagebus/service/_sync/dependency.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping, Sequence +from typing import Any, Generic + +from messagebus.domain.model.message import TMessage +from messagebus.service._sync.unit_of_work import TSyncUow +from messagebus.typing import P, SyncMessageHandler + + +class SyncDependency: + """Describe an async dependency""" + + def on_after_commit(self) -> None: + """Method called when the unit of work transaction is has been commited.""" + + def on_after_rollback(self) -> None: + """Method called when the unit of work transaction is has been rolled back.""" + + +class SyncMessageHook(Generic[TMessage, TSyncUow, P]): + callback: SyncMessageHandler[TMessage, "TSyncUow", P] + dependencies: Sequence[str] + + def __init__( + self, + callback: SyncMessageHandler[TMessage, "TSyncUow", P], + dependencies: Sequence[str], + ) -> None: + self.callback = callback + self.dependencies = dependencies + + def __call__( + self, + msg: TMessage, + uow: "TSyncUow", + dependencies: Mapping[str, SyncDependency], + ) -> Any: + deps = {k: dependencies[k] for k in self.dependencies} + resp = self.callback(msg, uow, **deps) # type: ignore + return resp diff --git a/src/messagebus/service/_sync/registry.py b/src/messagebus/service/_sync/registry.py index cdf69ff..6310fe9 100644 --- a/src/messagebus/service/_sync/registry.py +++ b/src/messagebus/service/_sync/registry.py @@ -7,15 +7,24 @@ import inspect import logging from collections import defaultdict -from functools import partial +from collections.abc import Mapping from typing import Any, Generic, cast import venusian # type: ignore from messagebus.domain.model import GenericCommand, GenericEvent, Message -from messagebus.typing import P, SyncMessageHandler, TMessage, TSyncUow - -from .unit_of_work import SyncUnitOfWorkTransaction, TRepositories +from messagebus.domain.model.message import TMessage +from messagebus.service._sync.dependency import ( + P, + SyncDependency, + SyncMessageHandler, + SyncMessageHook, +) +from messagebus.service._sync.unit_of_work import ( + SyncUnitOfWorkTransaction, + TRepositories, + TSyncUow, +) log = logging.getLogger(__name__) VENUSIAN_CATEGORY = "messagebus" @@ -55,38 +64,36 @@ class SyncMessageBus(Generic[TRepositories]): def __init__(self, **dependencies: Any) -> None: self.commands_registry: dict[ - type[GenericCommand[Any]], - SyncMessageHandler[GenericCommand[Any], Any, ...], + type[GenericCommand[Any]], SyncMessageHook[Any, Any, Any] ] = {} self.events_registry: dict[ - type[GenericEvent[Any]], - list[SyncMessageHandler[GenericEvent[Any], Any, ...]], + type[GenericEvent[Any]], list[SyncMessageHook[Any, Any, Any]] ] = defaultdict(list) - self.depencencies = dependencies or {} + self.dependencies = cast(Mapping[str, type[SyncDependency]], dependencies or {}) def add_listener( self, msg_type: type[Message[Any]], callback: SyncMessageHandler[Any, Any, P] ) -> None: signature = inspect.signature(callback) - kwargs = {} + dependencies: list[str] = [] for idx, key in enumerate(signature.parameters): if idx >= 2: - if key not in self.depencencies: + if key not in self.dependencies: raise ConfigurationError( f"Missing dependency in message bus: {key} for command " f"type {msg_type.__name__}, listener: {callback.__name__}" ) - kwargs[key] = self.depencencies[key] - if kwargs: - callback = partial(callback, **kwargs) # type: ignore + dependencies.append(key) + + msghook = SyncMessageHook(callback, dependencies) if issubclass(msg_type, GenericCommand): if msg_type in self.commands_registry: raise ConfigurationError( f"{msg_type} command has been registered twice" ) - self.commands_registry[msg_type] = callback + self.commands_registry[msg_type] = msghook elif issubclass(msg_type, GenericEvent): - self.events_registry[msg_type].append(callback) + self.events_registry[msg_type].append(msghook) else: raise ConfigurationError( f"Invalid usage of the listen decorator: " @@ -101,12 +108,13 @@ def remove_listener( raise ConfigurationError(f"{msg_type} command has not been registered") del self.commands_registry[msg_type] elif issubclass(msg_type, GenericEvent): - try: - self.events_registry[msg_type].remove(callback) - except ValueError as exc: - raise ConfigurationError( - f"{msg_type} event has not been registered" - ) from exc + msg_hooks = [ + v for v in self.events_registry[msg_type] if v.callback == callback + ] + if msg_hooks: + self.events_registry[msg_type].remove(msg_hooks[0]) + else: + raise ConfigurationError(f"{msg_type} event has not been registered") else: raise ConfigurationError( f"Invalid usage of the listen decorator: " @@ -120,6 +128,7 @@ def handle( Notify listener of that event registered with `messagebus.add_listener`. Return the first event from the command. """ + dependencies = {k: uow.add_listener(v()) for k, v in self.dependencies.items()} queue = [message] idx = 0 ret = None @@ -130,14 +139,14 @@ def handle( msg_type = type(message) if msg_type in self.commands_registry: cmdret = self.commands_registry[msg_type]( # type: ignore - cast(GenericCommand[Any], message), uow + cast(GenericCommand[Any], message), uow, dependencies ) if idx == 0: ret = cmdret queue.extend(uow.uow.collect_new_events()) elif msg_type in self.events_registry: - for callback in self.events_registry[msg_type]: # type: ignore - callback(cast(GenericEvent[Any], message), uow) + for msghook in self.events_registry[msg_type]: # type: ignore + msghook(cast(GenericEvent[Any], message), uow, dependencies) queue.extend(uow.uow.collect_new_events()) uow.eventstore.add(message) idx += 1 diff --git a/src/messagebus/service/_sync/unit_of_work.py b/src/messagebus/service/_sync/unit_of_work.py index 203038c..b4b17c2 100644 --- a/src/messagebus/service/_sync/unit_of_work.py +++ b/src/messagebus/service/_sync/unit_of_work.py @@ -6,9 +6,12 @@ import enum from collections.abc import Iterator from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from messagebus.domain.model import Message + +if TYPE_CHECKING: + from messagebus.service._sync.dependency import SyncDependency from messagebus.service._sync.repository import ( SyncAbstractRepository, SyncEventstoreAbstractRepository, @@ -39,6 +42,7 @@ class SyncUnitOfWorkTransaction(Generic[TRepositories]): def __init__(self, uow: SyncAbstractUnitOfWork[TRepositories]) -> None: self.status = TransactionStatus.running self.uow = uow + self._hooks: list[Any] = [] def __getattr__(self, name: str) -> TRepositories: return getattr(self.uow, name) @@ -47,15 +51,29 @@ def __getattr__(self, name: str) -> TRepositories: def eventstore(self) -> SyncEventstoreAbstractRepository: return self.uow.eventstore + def add_listener(self, listener: SyncDependency) -> SyncDependency: + self._hooks.append(listener) + return listener + + def _on_after_commit(self) -> None: + for val in self._hooks: + val.on_after_commit() + + def _on_after_rollback(self) -> None: + for val in self._hooks: + val.on_after_rollback() + def commit(self) -> None: if self.status != TransactionStatus.running: raise TransactionError(f"Transaction already closed ({self.status.value}).") self.uow.commit() self.status = TransactionStatus.committed + self._on_after_commit() def rollback(self) -> None: self.uow.rollback() self.status = TransactionStatus.rolledback + self._on_after_rollback() def __enter__(self) -> SyncUnitOfWorkTransaction[TRepositories]: if self.status != TransactionStatus.running: @@ -130,3 +148,6 @@ def commit(self) -> None: @abc.abstractmethod def rollback(self) -> None: """Rollback the transation.""" + + +TSyncUow = TypeVar("TSyncUow", bound=SyncAbstractUnitOfWork[Any]) diff --git a/tests/_async/conftest.py b/tests/_async/conftest.py index f5c6b45..26460a4 100644 --- a/tests/_async/conftest.py +++ b/tests/_async/conftest.py @@ -17,6 +17,7 @@ Message, Metadata, ) +from messagebus.service._async.dependency import AsyncDependency from messagebus.service._async.eventstream import ( AsyncAbstractEventstreamTransport, AsyncEventstreamPublisher, @@ -46,7 +47,7 @@ class DummyModel(GenericModel[MyMetadata]): counter: int = Field(0) -class Notifier: +class Notifier(AsyncDependency): inbox: ClassVar[list[str]] = [] def send_message(self, message: str): @@ -208,11 +209,11 @@ async def uow_with_eventstore( @pytest.fixture def notifier(): - return Notifier() + return Notifier @pytest.fixture -def bus(notifier: Notifier) -> AsyncMessageBus[Repositories]: +def bus(notifier: type[Notifier]) -> AsyncMessageBus[Repositories]: return AsyncMessageBus(notifier=notifier) diff --git a/tests/_async/test_registry.py b/tests/_async/test_registry.py index 143f4d5..43b3444 100644 --- a/tests/_async/test_registry.py +++ b/tests/_async/test_registry.py @@ -1,4 +1,3 @@ -import functools from typing import Any import pytest @@ -9,6 +8,7 @@ DummyCommand, DummyEvent, DummyModel, + Notifier, Repositories, ) from tests._async.handlers import dummy @@ -153,10 +153,12 @@ def test_scan(bus: AsyncMessageBus[Any]): bus.scan("tests._async.handlers") assert DummyCommand in bus.commands_registry - assert bus.commands_registry[DummyCommand] == dummy.handler + assert bus.commands_registry[DummyCommand].callback == dummy.handler assert DummyEvent in bus.events_registry - assert bus.events_registry[DummyEvent] == [dummy.handler_evt1, dummy.handler_evt2] + assert len(bus.events_registry[DummyEvent]) == 2 + assert bus.events_registry[DummyEvent][0].callback == dummy.handler_evt1 + assert bus.events_registry[DummyEvent][1].callback == dummy.handler_evt2 def test_scan_relative(bus: AsyncMessageBus[Any]): @@ -171,11 +173,11 @@ def test_scan_relative(bus: AsyncMessageBus[Any]): async def listen_command_with_dependency( cmd: DummyCommand, uow: AsyncUnitOfWorkTransaction[Repositories], - dummy_dict: dict[str, str], + dummy_dep: Notifier, ) -> DummyModel: """This command raise an event played by the message bus.""" foo = DummyModel(id=cmd.id, counter=0) - dummy_dict["foo"] = "bar" + dummy_dep.send_message("foobar") return foo @@ -183,13 +185,12 @@ async def test_messagebus_dependency( uow: AsyncUnitOfWorkTransaction[Repositories], ): d: dict[str, str] = {} - bus = AsyncMessageBus[Repositories](dummy_dict=d) + bus = AsyncMessageBus[Repositories](dummy_dep=d) bus.add_listener(DummyCommand, listen_command_with_dependency) - assert isinstance(bus.commands_registry[DummyCommand], functools.partial) assert ( - bus.commands_registry[DummyCommand].keywords # type: ignore - == {"dummy_dict": d} + bus.commands_registry[DummyCommand].callback == listen_command_with_dependency ) + assert bus.commands_registry[DummyCommand].dependencies == ["dummy_dep"] async def test_messagebus_dependency_error_missing_deps( @@ -199,6 +200,6 @@ async def test_messagebus_dependency_error_missing_deps( with pytest.raises(ConfigurationError) as ctx: bus.add_listener(DummyCommand, listen_command_with_dependency) assert ( - str(ctx.value) == "Missing dependency in message bus: dummy_dict for " + str(ctx.value) == "Missing dependency in message bus: dummy_dep for " "command type DummyCommand, listener: listen_command_with_dependency" ) diff --git a/tests/_sync/conftest.py b/tests/_sync/conftest.py index fd5dfaa..f915394 100644 --- a/tests/_sync/conftest.py +++ b/tests/_sync/conftest.py @@ -17,6 +17,7 @@ Message, Metadata, ) +from messagebus.service._sync.dependency import SyncDependency from messagebus.service._sync.eventstream import ( SyncAbstractEventstreamTransport, SyncEventstreamPublisher, @@ -46,7 +47,7 @@ class DummyModel(GenericModel[MyMetadata]): counter: int = Field(0) -class Notifier: +class Notifier(SyncDependency): inbox: ClassVar[list[str]] = [] def send_message(self, message: str): @@ -208,11 +209,11 @@ def uow_with_eventstore( @pytest.fixture def notifier(): - return Notifier() + return Notifier @pytest.fixture -def bus(notifier: Notifier) -> SyncMessageBus[Repositories]: +def bus(notifier: type[Notifier]) -> SyncMessageBus[Repositories]: return SyncMessageBus(notifier=notifier) diff --git a/tests/_sync/test_registry.py b/tests/_sync/test_registry.py index 39d9b9a..fdbb2e1 100644 --- a/tests/_sync/test_registry.py +++ b/tests/_sync/test_registry.py @@ -1,4 +1,3 @@ -import functools from typing import Any import pytest @@ -8,6 +7,7 @@ DummyCommand, DummyEvent, DummyModel, + Notifier, Repositories, SyncUnitOfWorkTransaction, ) @@ -151,10 +151,12 @@ def test_scan(bus: SyncMessageBus[Any]): bus.scan("tests._sync.handlers") assert DummyCommand in bus.commands_registry - assert bus.commands_registry[DummyCommand] == dummy.handler + assert bus.commands_registry[DummyCommand].callback == dummy.handler assert DummyEvent in bus.events_registry - assert bus.events_registry[DummyEvent] == [dummy.handler_evt1, dummy.handler_evt2] + assert len(bus.events_registry[DummyEvent]) == 2 + assert bus.events_registry[DummyEvent][0].callback == dummy.handler_evt1 + assert bus.events_registry[DummyEvent][1].callback == dummy.handler_evt2 def test_scan_relative(bus: SyncMessageBus[Any]): @@ -169,11 +171,11 @@ def test_scan_relative(bus: SyncMessageBus[Any]): def listen_command_with_dependency( cmd: DummyCommand, uow: SyncUnitOfWorkTransaction[Repositories], - dummy_dict: dict[str, str], + dummy_dep: Notifier, ) -> DummyModel: """This command raise an event played by the message bus.""" foo = DummyModel(id=cmd.id, counter=0) - dummy_dict["foo"] = "bar" + dummy_dep.send_message("foobar") return foo @@ -181,13 +183,12 @@ def test_messagebus_dependency( uow: SyncUnitOfWorkTransaction[Repositories], ): d: dict[str, str] = {} - bus = SyncMessageBus[Repositories](dummy_dict=d) + bus = SyncMessageBus[Repositories](dummy_dep=d) bus.add_listener(DummyCommand, listen_command_with_dependency) - assert isinstance(bus.commands_registry[DummyCommand], functools.partial) assert ( - bus.commands_registry[DummyCommand].keywords # type: ignore - == {"dummy_dict": d} + bus.commands_registry[DummyCommand].callback == listen_command_with_dependency ) + assert bus.commands_registry[DummyCommand].dependencies == ["dummy_dep"] def test_messagebus_dependency_error_missing_deps( @@ -197,6 +198,6 @@ def test_messagebus_dependency_error_missing_deps( with pytest.raises(ConfigurationError) as ctx: bus.add_listener(DummyCommand, listen_command_with_dependency) assert ( - str(ctx.value) == "Missing dependency in message bus: dummy_dict for " + str(ctx.value) == "Missing dependency in message bus: dummy_dep for " "command type DummyCommand, listener: listen_command_with_dependency" )