Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ice servers to async WebRTC offer #4

Merged
merged 7 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ repos:
name: 🐶 Ruff lint
args:
- --fix
# - --unsafe-fixes

- id: ruff-format
name: 🐶 Ruff format
Expand Down
17 changes: 12 additions & 5 deletions go2rtc_client/ws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from go2rtc_client.exceptions import handle_error

from .messages import BaseMessage
from .messages import BaseMessage, ReceiveMessages, SendMessages, WebRTC, WsMessage

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
self._params = params
self._client: ClientWebSocketResponse | None = None
self._rx_task: asyncio.Task[None] | None = None
self._subscribers: list[Callable[[BaseMessage], None]] = []
self._subscribers: list[Callable[[ReceiveMessages], None]] = []
self._connect_lock = asyncio.Lock()

@property
Expand Down Expand Up @@ -77,7 +77,7 @@ async def close(self) -> None:
await client.close()

@handle_error
async def send(self, message: BaseMessage) -> None:
async def send(self, message: SendMessages) -> None:
"""Send a message."""
if not self.connected:
await self.connect()
Expand All @@ -90,10 +90,15 @@ async def send(self, message: BaseMessage) -> None:
def _process_text_message(self, data: Any) -> None:
"""Process text message."""
try:
message = BaseMessage.from_json(data)
message: WsMessage = BaseMessage.from_json(data)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Invalid message received: %s", data)
else:
if isinstance(message, WebRTC):
message = message.value
if not isinstance(message, ReceiveMessages):
_LOGGER.error("Received unexpected message: %s", message)
return
for subscriber in self._subscribers:
try:
subscriber(message)
Expand Down Expand Up @@ -134,7 +139,9 @@ async def _receive_messages(self) -> None:
if self.connected:
await self.close()

def subscribe(self, callback: Callable[[BaseMessage], None]) -> Callable[[], None]:
def subscribe(
self, callback: Callable[[ReceiveMessages], None]
) -> Callable[[], None]:
"""Subscribe to messages."""

def _unsubscribe() -> None:
Expand Down
74 changes: 59 additions & 15 deletions go2rtc_client/ws/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,34 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar
from typing import Annotated, Any, ClassVar

from mashumaro import field_options
from mashumaro.config import BaseConfig
from mashumaro.mixins.orjson import DataClassORJSONMixin
from mashumaro.types import Discriminator
from webrtc_models import (
RTCIceServer, # noqa: TCH002 # Mashumaro needs the import to generate the correct code
)


@dataclass(frozen=True)
class BaseMessage(DataClassORJSONMixin):
"""Base message class."""
class WsMessage:
"""Websocket message."""

TYPE: ClassVar[str]

def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
"""Add type to serialized dict."""
# ClassVar will not serialize by default
d["type"] = self.TYPE
return d


@dataclass(frozen=True)
class BaseMessage(WsMessage, DataClassORJSONMixin):
"""Base message class."""

class Config(BaseConfig):
"""Config for BaseMessage."""

Expand All @@ -27,12 +41,6 @@ class Config(BaseConfig):
variant_tagger_fn=lambda cls: cls.TYPE,
)

def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
"""Add type to serialized dict."""
# ClassVar will not serialize by default
d["type"] = self.TYPE
return d


@dataclass(frozen=True)
class WebRTCCandidate(BaseMessage):
Expand All @@ -43,19 +51,55 @@ class WebRTCCandidate(BaseMessage):


@dataclass(frozen=True)
class WebRTCOffer(BaseMessage):
class WebRTC(BaseMessage):
"""WebRTC message."""

TYPE = "webrtc"
value: Annotated[
WebRTCOffer | WebRTCValue,
Discriminator(
field="type",
include_subtypes=True,
variant_tagger_fn=lambda cls: cls.TYPE,
),
]


@dataclass(frozen=True)
class WebRTCValue(WsMessage):
"""WebRTC value for WebRTC message."""

sdp: str


@dataclass(frozen=True)
class WebRTCOffer(WebRTCValue):
"""WebRTC offer message."""

TYPE = "webrtc/offer"
offer: str = field(metadata=field_options(alias="value"))
TYPE = "offer"
ice_servers: list[RTCIceServer]

def __pre_serialize__(self) -> WebRTCOffer:
"""Pre serialize.

Go2rtc supports only ice_servers with urls as list of strings.
"""
for server in self.ice_servers:
if isinstance(server.urls, str):
server.urls = [server.urls]

return self

def to_json(self, **kwargs: Any) -> str:
"""Convert to json."""
return WebRTC(self).to_json(**kwargs)


@dataclass(frozen=True)
class WebRTCAnswer(BaseMessage):
class WebRTCAnswer(WebRTCValue):
"""WebRTC answer message."""

TYPE = "webrtc/answer"
answer: str = field(metadata=field_options(alias="value"))
TYPE = "answer"


@dataclass(frozen=True)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"awesomeversion>=24.6.0",
"mashumaro~=3.13",
"orjson>=3.10.7",
"webrtc-models>=0.1.0",
]
version = "0.0.0"

Expand Down
64 changes: 53 additions & 11 deletions tests/ws/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,18 @@
from aiohttp.web import WebSocketResponse
from multidict import CIMultiDict, CIMultiDictProxy
import pytest
from webrtc_models import RTCIceServer
from yarl import URL

from go2rtc_client.exceptions import Go2RtcClientError
from go2rtc_client.ws.client import Go2RtcWsClient
from go2rtc_client.ws.messages import BaseMessage, WebRTCAnswer, WebRTCCandidate
from go2rtc_client.ws import (
Go2RtcWsClient,
ReceiveMessages,
SendMessages,
WebRTCAnswer,
WebRTCCandidate,
WebRTCOffer,
)


class TestServer:
Expand Down Expand Up @@ -111,7 +118,27 @@ async def test_connect_parallel(server: TestServer) -> None:
assert client.connected


async def test_send(ws_client: Go2RtcWsClient, server: TestServer) -> None:
@pytest.mark.parametrize(
("message", "expected"),
[
(WebRTCCandidate("test"), '{"value":"test","type":"webrtc/candidate"}'),
(
WebRTCOffer("test", []),
'{"value":{"sdp":"test","ice_servers":[],"type":"offer"},"type":"webrtc"}',
),
(
WebRTCOffer("test", [RTCIceServer("url")]),
'{"value":{"sdp":"test","ice_servers":[{"urls":["url"]}],"type":"offer"},"type":"webrtc"}',
),
(
WebRTCOffer("test", [RTCIceServer(["url1", "url2"])]),
'{"value":{"sdp":"test","ice_servers":[{"urls":["url1","url2"]}],"type":"offer"},"type":"webrtc"}',
),
],
)
async def test_send(
ws_client: Go2RtcWsClient, server: TestServer, message: SendMessages, expected: str
) -> None:
"""Test sending a message through the WebSocket."""
received_message = None

Expand All @@ -121,28 +148,31 @@ def on_message(msg: WSMessage) -> None:

server.on_message = on_message

await ws_client.send(WebRTCCandidate("test"))
await ws_client.send(message)
await asyncio.sleep(0.1)
assert received_message == '{"value":"test","type":"webrtc/candidate"}'
assert received_message == expected


@pytest.mark.parametrize(
("message", "expected"),
[
('{"value":"test","type":"webrtc/candidate"}', WebRTCCandidate("test")),
('{"value":"test","type":"webrtc/answer"}', WebRTCAnswer("test")),
(
'{"value":{"type":"answer", "sdp":"test"},"type":"webrtc"}',
WebRTCAnswer("test"),
),
],
)
async def test_receive(
ws_client_connected: Go2RtcWsClient,
server: TestServer,
message: str,
expected: BaseMessage,
expected: ReceiveMessages,
) -> None:
"""Test receiving a message through the WebSocket."""
received_message = None

def on_message(message: BaseMessage) -> None:
def on_message(message: ReceiveMessages) -> None:
nonlocal received_message
received_message = message

Expand Down Expand Up @@ -230,7 +260,7 @@ async def test_subscribe_unsubscribe(ws_client: Go2RtcWsClient) -> None:
# pylint: disable=protected-access
assert ws_client._subscribers == []

def on_message(_: BaseMessage) -> None:
def on_message(_: ReceiveMessages) -> None:
pass

unsub = ws_client.subscribe(on_message)
Expand All @@ -249,14 +279,14 @@ async def test_subscriber_raised(
) -> None:
"""Test any exception raised by any subscriber will be handled."""

def on_message_raise(_: BaseMessage) -> None:
def on_message_raise(_: ReceiveMessages) -> None:
raise ValueError

ws_client_connected.subscribe(on_message_raise)

received_message = None

def on_message(message: BaseMessage) -> None:
def on_message(message: ReceiveMessages) -> None:
nonlocal received_message
received_message = message

Expand Down Expand Up @@ -294,6 +324,18 @@ def on_message(message: BaseMessage) -> None:
WSMessage(WSMsgType.ERROR, "error", None),
("go2rtc_client.ws.client", logging.ERROR, "Error received: error"),
),
(
WSMessage(
WSMsgType.TEXT,
'{"value":{"sdp":"test","ice_servers":[],"type":"offer"},"type":"webrtc"}',
None,
),
(
"go2rtc_client.ws.client",
logging.ERROR,
"Received unexpected message: WebRTCOffer(sdp='test', ice_servers=[])",
),
),
],
)
async def test_unexpected_messages(
Expand Down
15 changes: 15 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.