Skip to content

Commit

Permalink
Refactor websocket connection tests to use AsyncExitStack for improve…
Browse files Browse the repository at this point in the history
…d readability and flexibility
  • Loading branch information
tumblingman committed Nov 27, 2024
1 parent 73368fe commit cbebb66
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 62 deletions.
17 changes: 16 additions & 1 deletion tests/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class Communicator(WebsocketCommunicator):
except asyncio.TimeoutError:
break
"""
_connected = False

@property
def connected(self):
return self._connected

async def receive_output(self, timeout=1):
if self.future.done():
self.future.result() # Ensure exceptions are re-raised if future is complete
Expand All @@ -43,6 +49,14 @@ async def receive_output(self, timeout=1):
self.future.result()
raise e # Propagate the timeout exception

async def connect(self, timeout=1):
self._connected, subprotocol = await super().connect(timeout)
return self._connected, subprotocol

async def disconnect(self, code=1000, timeout=1):
await super().disconnect(code, timeout)
self._connected = False


@asynccontextmanager
async def connected_communicator(consumer, path: str = "/testws/") -> Awaitable[Communicator]:
Expand All @@ -67,4 +81,5 @@ async def connected_communicator(consumer, path: str = "/testws/") -> Awaitable[
assert connected, "Failed to connect to WebSocket"
yield communicator
finally:
await communicator.disconnect()
if communicator.connected:
await communicator.disconnect()
81 changes: 42 additions & 39 deletions tests/test_model_observer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import AsyncExitStack

import pytest
from channels import DEFAULT_CHANNEL_LAYER
from channels.db import database_sync_to_async
Expand Down Expand Up @@ -216,53 +218,54 @@ async def update_username(self, pk=None, name=None, **kwargs):
assert not await database_sync_to_async(get_user_model().objects.all().exists)()

# Test a normal connection
async with connected_communicator(TestOtherConsumer()) as communicator1:
async with connected_communicator(TestUserConsumer()) as communicator2:

u1 = await database_sync_to_async(get_user_model().objects.create)(
username="test1", email="42@example.com"
)
t1 = await database_sync_to_async(TestModel.objects.create)(name="test2")

await communicator1.send_json_to(
{"action": "subscribe_instance", "pk": t1.id, "request_id": 4}
)

response = await communicator1.receive_json_from()

assert response == {
"action": "subscribe_instance",
"errors": [],
"response_status": 201,
"request_id": 4,
"data": None,
}
async with AsyncExitStack() as stack:
communicator1 = await stack.enter_async_context(connected_communicator(TestOtherConsumer()))
communicator2 = await stack.enter_async_context(connected_communicator(TestUserConsumer()))

await communicator2.send_json_to(
{"action": "subscribe_instance", "pk": u1.id, "request_id": 4}
)
u1 = await database_sync_to_async(get_user_model().objects.create)(
username="test1", email="42@example.com"
)
t1 = await database_sync_to_async(TestModel.objects.create)(name="test2")

response = await communicator2.receive_json_from()
await communicator1.send_json_to(
{"action": "subscribe_instance", "pk": t1.id, "request_id": 4}
)

assert response == {
"action": "subscribe_instance",
"errors": [],
"response_status": 201,
"request_id": 4,
"data": None,
}
response = await communicator1.receive_json_from()

assert response == {
"action": "subscribe_instance",
"errors": [],
"response_status": 201,
"request_id": 4,
"data": None,
}

await communicator2.send_json_to(
{"action": "subscribe_instance", "pk": u1.id, "request_id": 4}
)

response = await communicator2.receive_json_from()

assert response == {
"action": "subscribe_instance",
"errors": [],
"response_status": 201,
"request_id": 4,
"data": None,
}

# update the user
# update the user

u1.username = "no not a value"
u1.username = "no not a value"

await database_sync_to_async(u1.save)()
await database_sync_to_async(u1.save)()

# user is updated
assert await communicator2.receive_json_from()
# user is updated
assert await communicator2.receive_json_from()

# test model is not
assert await communicator1.receive_nothing()
# test model is not
assert await communicator1.receive_nothing()


@pytest.mark.django_db(transaction=True)
Expand Down
50 changes: 28 additions & 22 deletions tests/test_observer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextlib import AsyncExitStack

import pytest
from asgiref.sync import async_to_sync
Expand Down Expand Up @@ -227,20 +228,22 @@ async def user_change_many_connections_wrapper(
):
await self.send_json(dict(body=message, action=action, type=message_type))

async with connected_communicator(TestConsumer()) as communicator2:
async with connected_communicator(TestConsumer()) as communicator1:
async with AsyncExitStack() as stack:
communicator1 = await stack.enter_async_context(connected_communicator(TestConsumer()))
communicator2 = await stack.enter_async_context(connected_communicator(TestConsumer()))
user = await database_sync_to_async(get_user_model().objects.create)(
username="test", email="test@example.com"
)

user = await database_sync_to_async(get_user_model().objects.create)(
username="test", email="test@example.com"
)
response = await communicator1.receive_json_from()

response = await communicator1.receive_json_from()
assert {
"action": "create",
"body": {"pk": user.pk},
"type": "user.change.many.connections.wrapper",
} == response

assert {
"action": "create",
"body": {"pk": user.pk},
"type": "user.change.many.connections.wrapper",
} == response
await communicator1.disconnect()

response = await communicator2.receive_json_from()

Expand Down Expand Up @@ -287,20 +290,23 @@ async def user_change_many_consumers_wrapper_2(
):
await self.send_json(dict(body=message, action=action, type=message_type))

async with connected_communicator(TestConsumer2()) as communicator2:
async with connected_communicator(TestConsumer()) as communicator1:
async with AsyncExitStack() as stack:
communicator1 = await stack.enter_async_context(connected_communicator(TestConsumer()))
communicator2 = await stack.enter_async_context(connected_communicator(TestConsumer2()))

user = await database_sync_to_async(get_user_model().objects.create)(
username="test", email="test@example.com"
)
user = await database_sync_to_async(get_user_model().objects.create)(
username="test", email="test@example.com"
)

response = await communicator1.receive_json_from()

response = await communicator1.receive_json_from()
assert {
"action": "create",
"body": {"pk": user.pk},
"type": "user.change.many.consumers.wrapper.1",
} == response

assert {
"action": "create",
"body": {"pk": user.pk},
"type": "user.change.many.consumers.wrapper.1",
} == response
await communicator1.disconnect()

response = await communicator2.receive_json_from()

Expand Down

0 comments on commit cbebb66

Please sign in to comment.