Skip to content

Commit

Permalink
Refactor to isolate dependencies per transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
mardiros committed Feb 8, 2025
1 parent 889e464 commit 8f0b52a
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 79 deletions.
5 changes: 4 additions & 1 deletion src/messagebus/domain/model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from datetime import datetime
from typing import Any, Generic
from typing import Any, Generic, TypeVar
from uuid import UUID

from lastuuid import uuid7
Expand Down Expand Up @@ -69,3 +69,6 @@ class GenericEvent(Message[TMetadata]):
"""Command that use the default metadata."""
Event = GenericEvent[Metadata]
"""Event that use the default metadata."""


TMessage = TypeVar("TMessage", bound=Message[Any])
39 changes: 39 additions & 0 deletions src/messagebus/service/_async/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections.abc import Mapping, Sequence
from typing import Any, Generic

from messagebus.domain.model.message import TMessage
from messagebus.service._async.unit_of_work import TAsyncUow
from messagebus.typing import AsyncMessageHandler, P


class AsyncDependency:
"""Describe an async dependency"""

async def on_after_commit(self) -> None:
"""Method called when the unit of work transaction is has been commited."""

async def on_after_rollback(self) -> None:
"""Method called when the unit of work transaction is has been rolled back."""


class AsyncMessageHook(Generic[TMessage, TAsyncUow, P]):
callback: AsyncMessageHandler[TMessage, "TAsyncUow", P]
dependencies: Sequence[str]

def __init__(
self,
callback: AsyncMessageHandler[TMessage, "TAsyncUow", P],
dependencies: Sequence[str],
) -> None:
self.callback = callback
self.dependencies = dependencies

async def __call__(
self,
msg: TMessage,
uow: "TAsyncUow",
dependencies: Mapping[str, AsyncDependency],
) -> Any:
deps = {k: dependencies[k] for k in self.dependencies}
resp = await self.callback(msg, uow, **deps) # type: ignore
return resp
61 changes: 36 additions & 25 deletions src/messagebus/service/_async/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,24 @@
import inspect
import logging
from collections import defaultdict
from functools import partial
from collections.abc import Mapping
from typing import Any, Generic, cast

import venusian # type: ignore

from messagebus.domain.model import GenericCommand, GenericEvent, Message
from messagebus.typing import AsyncMessageHandler, P, TAsyncUow, TMessage

from .unit_of_work import AsyncUnitOfWorkTransaction, TRepositories
from messagebus.domain.model.message import TMessage
from messagebus.service._async.dependency import (
AsyncDependency,
AsyncMessageHandler,
AsyncMessageHook,
P,
)
from messagebus.service._async.unit_of_work import (
AsyncUnitOfWorkTransaction,
TAsyncUow,
TRepositories,
)

log = logging.getLogger(__name__)
VENUSIAN_CATEGORY = "messagebus"
Expand Down Expand Up @@ -55,38 +64,38 @@ class AsyncMessageBus(Generic[TRepositories]):

def __init__(self, **dependencies: Any) -> None:
self.commands_registry: dict[
type[GenericCommand[Any]],
AsyncMessageHandler[GenericCommand[Any], Any, ...],
type[GenericCommand[Any]], AsyncMessageHook[Any, Any, Any]
] = {}
self.events_registry: dict[
type[GenericEvent[Any]],
list[AsyncMessageHandler[GenericEvent[Any], Any, ...]],
type[GenericEvent[Any]], list[AsyncMessageHook[Any, Any, Any]]
] = defaultdict(list)
self.depencencies = dependencies or {}
self.dependencies = cast(
Mapping[str, type[AsyncDependency]], dependencies or {}
)

def add_listener(
self, msg_type: type[Message[Any]], callback: AsyncMessageHandler[Any, Any, P]
) -> None:
signature = inspect.signature(callback)
kwargs = {}
dependencies: list[str] = []
for idx, key in enumerate(signature.parameters):
if idx >= 2:
if key not in self.depencencies:
if key not in self.dependencies:
raise ConfigurationError(
f"Missing dependency in message bus: {key} for command "
f"type {msg_type.__name__}, listener: {callback.__name__}"
)
kwargs[key] = self.depencencies[key]
if kwargs:
callback = partial(callback, **kwargs) # type: ignore
dependencies.append(key)

msghook = AsyncMessageHook(callback, dependencies)
if issubclass(msg_type, GenericCommand):
if msg_type in self.commands_registry:
raise ConfigurationError(
f"{msg_type} command has been registered twice"
)
self.commands_registry[msg_type] = callback
self.commands_registry[msg_type] = msghook
elif issubclass(msg_type, GenericEvent):
self.events_registry[msg_type].append(callback)
self.events_registry[msg_type].append(msghook)
else:
raise ConfigurationError(
f"Invalid usage of the listen decorator: "
Expand All @@ -101,12 +110,13 @@ def remove_listener(
raise ConfigurationError(f"{msg_type} command has not been registered")
del self.commands_registry[msg_type]
elif issubclass(msg_type, GenericEvent):
try:
self.events_registry[msg_type].remove(callback)
except ValueError as exc:
raise ConfigurationError(
f"{msg_type} event has not been registered"
) from exc
msg_hooks = [
v for v in self.events_registry[msg_type] if v.callback == callback
]
if msg_hooks:
self.events_registry[msg_type].remove(msg_hooks[0])
else:
raise ConfigurationError(f"{msg_type} event has not been registered")
else:
raise ConfigurationError(
f"Invalid usage of the listen decorator: "
Expand All @@ -120,6 +130,7 @@ async def handle(
Notify listener of that event registered with `messagebus.add_listener`.
Return the first event from the command.
"""
dependencies = {k: uow.add_listener(v()) for k, v in self.dependencies.items()}
queue = [message]
idx = 0
ret = None
Expand All @@ -130,14 +141,14 @@ async def handle(
msg_type = type(message)
if msg_type in self.commands_registry:
cmdret = await self.commands_registry[msg_type]( # type: ignore
cast(GenericCommand[Any], message), uow
cast(GenericCommand[Any], message), uow, dependencies
)
if idx == 0:
ret = cmdret
queue.extend(uow.uow.collect_new_events())
elif msg_type in self.events_registry:
for callback in self.events_registry[msg_type]: # type: ignore
await callback(cast(GenericEvent[Any], message), uow)
for msghook in self.events_registry[msg_type]: # type: ignore
await msghook(cast(GenericEvent[Any], message), uow, dependencies)
queue.extend(uow.uow.collect_new_events())
await uow.eventstore.add(message)
idx += 1
Expand Down
23 changes: 22 additions & 1 deletion src/messagebus/service/_async/unit_of_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import enum
from collections.abc import Iterator
from types import TracebackType
from typing import Any, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from messagebus.domain.model import Message

if TYPE_CHECKING:
from messagebus.service._async.dependency import AsyncDependency
from messagebus.service._async.repository import (
AsyncAbstractRepository,
AsyncEventstoreAbstractRepository,
Expand Down Expand Up @@ -39,6 +42,7 @@ class AsyncUnitOfWorkTransaction(Generic[TRepositories]):
def __init__(self, uow: AsyncAbstractUnitOfWork[TRepositories]) -> None:
self.status = TransactionStatus.running
self.uow = uow
self._hooks: list[Any] = []

def __getattr__(self, name: str) -> TRepositories:
return getattr(self.uow, name)
Expand All @@ -47,15 +51,29 @@ def __getattr__(self, name: str) -> TRepositories:
def eventstore(self) -> AsyncEventstoreAbstractRepository:
return self.uow.eventstore

def add_listener(self, listener: AsyncDependency) -> AsyncDependency:
self._hooks.append(listener)
return listener

async def _on_after_commit(self) -> None:
for val in self._hooks:
await val.on_after_commit()

async def _on_after_rollback(self) -> None:
for val in self._hooks:
await val.on_after_rollback()

async def commit(self) -> None:
if self.status != TransactionStatus.running:
raise TransactionError(f"Transaction already closed ({self.status.value}).")
await self.uow.commit()
self.status = TransactionStatus.committed
await self._on_after_commit()

async def rollback(self) -> None:
await self.uow.rollback()
self.status = TransactionStatus.rolledback
await self._on_after_rollback()

async def __aenter__(self) -> AsyncUnitOfWorkTransaction[TRepositories]:
if self.status != TransactionStatus.running:
Expand Down Expand Up @@ -130,3 +148,6 @@ async def commit(self) -> None:
@abc.abstractmethod
async def rollback(self) -> None:
"""Rollback the transation."""


TAsyncUow = TypeVar("TAsyncUow", bound=AsyncAbstractUnitOfWork[Any])
39 changes: 39 additions & 0 deletions src/messagebus/service/_sync/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections.abc import Mapping, Sequence
from typing import Any, Generic

from messagebus.domain.model.message import TMessage
from messagebus.service._sync.unit_of_work import TSyncUow
from messagebus.typing import P, SyncMessageHandler


class SyncDependency:
"""Describe an async dependency"""

def on_after_commit(self) -> None:
"""Method called when the unit of work transaction is has been commited."""

def on_after_rollback(self) -> None:
"""Method called when the unit of work transaction is has been rolled back."""


class SyncMessageHook(Generic[TMessage, TSyncUow, P]):
callback: SyncMessageHandler[TMessage, "TSyncUow", P]
dependencies: Sequence[str]

def __init__(
self,
callback: SyncMessageHandler[TMessage, "TSyncUow", P],
dependencies: Sequence[str],
) -> None:
self.callback = callback
self.dependencies = dependencies

def __call__(
self,
msg: TMessage,
uow: "TSyncUow",
dependencies: Mapping[str, SyncDependency],
) -> Any:
deps = {k: dependencies[k] for k in self.dependencies}
resp = self.callback(msg, uow, **deps) # type: ignore
return resp
Loading

0 comments on commit 8f0b52a

Please sign in to comment.