Skip to content

Commit

Permalink
Before and after request on websockets
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Feb 16, 2025
1 parent 40bb043 commit 024de7d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/en/docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ hide:

# Release Notes

### 3.6.7

### Added

- `before_request` and `after_request` WebSocketGateway handler added.

### 3.6.6

### Added
Expand Down
17 changes: 17 additions & 0 deletions esmerald/routing/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,23 @@ def __init__(
Since the default Lilya Route handler does not understand the Esmerald handlers,
the Gateway bridges both functionalities and adds an extra "flair" to be compliant with both class based views and decorated function views.
"""
self.before_request = before_request if before_request is not None else []
self.after_request = after_request if after_request is not None else []

if self.before_request:
if handler.before_request is None:
handler.before_request = []

for before in self.before_request:
handler.before_request.insert(0, before)

if self.after_request:
if handler.after_request is None:
handler.after_request = []

for after in self.after_request:
handler.after_request.insert(0, after)

self._interceptors: Union[List["Interceptor"], "VoidType"] = Void
self.handler = cast("Callable", handler)
self.dependencies = dependencies or {}
Expand Down
18 changes: 18 additions & 0 deletions esmerald/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,6 +2879,15 @@ async def handle_dispatch(self, scope: "Scope", receive: "Receive", send: "Send"
None
"""

for before_request in self.before_request:
if inspect.isclass(before_request):
before_request = before_request()

if is_async_callable(before_request):
await before_request(scope, receive, send)
else:
await run_in_threadpool(before_request, scope, receive, send)

if self.get_interceptors():
await self.intercept(scope, receive, send)

Expand All @@ -2901,6 +2910,15 @@ async def handle_dispatch(self, scope: "Scope", receive: "Receive", send: "Send"
else:
await fn(**kwargs)

for after_request in self.after_request:
if inspect.isclass(after_request):
after_request = after_request()

if is_async_callable(after_request):
await after_request(scope, receive, send)
else:
await run_in_threadpool(after_request, scope, receive, send)

async def get_kwargs(self, websocket: WebSocket) -> Any:
"""Resolves the required kwargs from the request data.
Expand Down

0 comments on commit 024de7d

Please sign in to comment.