From 8943960a808f9938653fa930d6bb3f70d752eb45 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Mon, 14 Oct 2024 17:32:18 +0300 Subject: [PATCH 01/14] Add function for fetching ICE servers from service handlers --- hass_nabucasa/cloud_api.py | 23 ++++++++++++++++++ tests/test_cloud_api.py | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/hass_nabucasa/cloud_api.py b/hass_nabucasa/cloud_api.py index da8c66e1f..691d01fd6 100644 --- a/hass_nabucasa/cloud_api.py +++ b/hass_nabucasa/cloud_api.py @@ -52,6 +52,14 @@ class FilesHandlerListEntry(TypedDict): tags: dict[str, Any] +class IceServer(TypedDict): + """ICE Server.""" + + urls: str + username: str + credential: str + + def _do_log_response(resp: ClientResponse, content: str = "") -> None: """Log the response.""" meth = _LOGGER.debug if resp.status < 400 else _LOGGER.warning @@ -349,3 +357,18 @@ async def async_resolve_cname(cloud: Cloud[_ClientT], hostname: str) -> list[str resp.raise_for_status() data: list[str] = await resp.json() return data + + +@_check_token +async def async_ice_servers(cloud: Cloud[_ClientT]) -> list[IceServer]: + """Resolve ICE Servers.""" + if TYPE_CHECKING: + assert cloud.id_token is not None + resp = await cloud.websession.get( + f"https://{cloud.servicehandlers_server}/webrtc/ice_servers", + headers={"authorization": cloud.id_token, USER_AGENT: cloud.client.client_name}, + ) + _do_log_response(resp) + resp.raise_for_status() + data: list[IceServer] = await resp.json() + return data diff --git a/tests/test_cloud_api.py b/tests/test_cloud_api.py index 971f4519e..63c9ac282 100644 --- a/tests/test_cloud_api.py +++ b/tests/test_cloud_api.py @@ -393,3 +393,53 @@ async def test_async_files_upload_details_error( } assert "Fetched https://example.com/files/upload_details (400) Boom!" in caplog.text + + +async def test_async_ice_servers( + auth_cloud_mock: MagicMock, + aioclient_mock: Generator[AiohttpClientMocker, Any, None], + caplog: pytest.LogCaptureFixture, +): + """Test the async_ice_servers function.""" + ice_servers_response = [ + { + "urls": "turn:example.com:3478", + "username": "test-username", + "credential": "test-credential", + }, + ] + + aioclient_mock.get( + "https://example.com/webrtc/ice_servers", + json=ice_servers_response, + ) + auth_cloud_mock.id_token = "mock-id-token" + auth_cloud_mock.servicehandlers_server = "example.com" + + details = await cloud_api.async_ice_servers(cloud=auth_cloud_mock) + + assert len(aioclient_mock.mock_calls) == 1 + + assert details == ice_servers_response + assert "Fetched https://example.com/webrtc/ice_servers (200)" in caplog.text + + +async def test_async_ice_servers_error( + auth_cloud_mock: MagicMock, + aioclient_mock: Generator[AiohttpClientMocker, Any, None], + caplog: pytest.LogCaptureFixture, +): + """Test the async_ice_servers function with error fetching ice servers.""" + aioclient_mock.get( + "https://example.com/webrtc/ice_servers", + status=400, + ) + auth_cloud_mock.id_token = "mock-id-token" + auth_cloud_mock.servicehandlers_server = "example.com" + + with pytest.raises(ClientResponseError): + await cloud_api.async_ice_servers(cloud=auth_cloud_mock) + + assert len(aioclient_mock.mock_calls) == 1 + + assert "Fetched https://example.com/webrtc/ice_servers (400)" in caplog.text From 4fac79cda3050de72e2a1a85b1a9064a012483b3 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Mon, 21 Oct 2024 23:07:48 +0300 Subject: [PATCH 02/14] Refactor code for fetching webrtc servers --- hass_nabucasa/__init__.py | 2 + hass_nabucasa/ice_servers.py | 149 +++++++++++++++++++++++++++++++++++ tests/test_cloud_api.py | 50 ------------ 3 files changed, 151 insertions(+), 50 deletions(-) create mode 100644 hass_nabucasa/ice_servers.py diff --git a/hass_nabucasa/__init__.py b/hass_nabucasa/__init__.py index f5e66422d..e2eb892ac 100644 --- a/hass_nabucasa/__init__.py +++ b/hass_nabucasa/__init__.py @@ -26,6 +26,7 @@ STATE_CONNECTED, ) from .google_report_state import GoogleReportState +from .ice_servers import IceServers from .iot import CloudIoT from .remote import RemoteUI from .utils import UTC, gather_callbacks, parse_date, utcnow @@ -75,6 +76,7 @@ def __init__( self.remote = RemoteUI(self) self.auth = CognitoAuth(self) self.voice = Voice(self) + self.ice_servers = IceServers(self) self._init_task: asyncio.Task | None = None diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py new file mode 100644 index 000000000..7a4d63e73 --- /dev/null +++ b/hass_nabucasa/ice_servers.py @@ -0,0 +1,149 @@ +"""Class to manage ICE servers.""" + +import asyncio +from collections.abc import Awaitable, Callable +import logging +import time +from typing import TYPE_CHECKING, TypedDict + +from aiohttp import ClientResponseError +from aiohttp.hdrs import AUTHORIZATION, USER_AGENT + +if TYPE_CHECKING: + from . import Cloud, _ClientT + + +_LOGGER = logging.getLogger(__name__) + + +class IceServer(TypedDict): + """ICE Server.""" + + urls: str + username: str + credential: str + + +class IceServersListener(TypedDict): + """ICE Servers Listener.""" + + register_ice_server_fn: Callable[[list[IceServer]], Awaitable[None]] + servers_unregister: list[Callable[[], None]] + + +class IceServers: + """Class to manage ICE servers.""" + + def __init__(self, cloud: Cloud[_ClientT]) -> None: + """Initialize ICE Servers.""" + self.cloud = cloud + self._refresh_task: asyncio.Task | None = None + self._ice_servers: list[IceServer] = [] + self._ice_servers_listeners: dict[str, IceServersListener] = {} + + cloud.iot.register_on_connect(self.on_connect) + cloud.iot.register_on_disconnect(self.on_disconnect) + + async def _async_fetch_ice_servers(self) -> None: + """Fetch ICE servers.""" + async with self.cloud.websession.get( + f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers", + headers={ + AUTHORIZATION: self.cloud.id_token, + USER_AGENT: self.cloud.client.client_name, + }, + ) as resp: + if resp.status >= 400: + _LOGGER.error("Failed to fetch ICE servers: %s", resp.status) + + resp.raise_for_status() + data: list[IceServer] = await resp.json() + + self._ice_servers = data + + for listener_id in self._ice_servers_listeners: + await self._perform_ice_server_listener_update(listener_id) + + def _get_refresh_sleep_time(self) -> int: + """Get the sleep time for refreshing ICE servers.""" + timestamps = [ + int(server["username"].split(":")[0]) + for server in self._ice_servers + if server["urls"].startswith("turn:") + ] + + if not timestamps: + return 3600 # 1 hour + + # 1 hour before the earliest expiration + return min(timestamps) - int(time.time()) - 3600 + + async def _async_refresh_ice_servers(self) -> None: + """Handle ICE server refresh.""" + while True: + try: + await self._async_fetch_ice_servers() + except ClientResponseError as err: + _LOGGER.error("Can't refresh ICE servers: %s", err) + except asyncio.CancelledError: + # Task is canceled, stop it. + break + + sleep_time = self._get_refresh_sleep_time() + await asyncio.sleep(sleep_time) + + async def on_connect(self) -> None: + """When the instance is connected.""" + self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers()) + + async def on_disconnect(self) -> None: + """When the instance is disconnected.""" + if self._refresh_task is not None: + self._refresh_task.cancel() + self._refresh_task = None + + async def _perform_ice_server_listener_update(self, listener_id: str) -> None: + """Perform ICE server listener update.""" + _LOGGER.debug("Performing ICE servers listener update: %s", listener_id) + + listener_obj = self._ice_servers_listeners.get(listener_id) + if listener_obj is None: + return + + register_ice_server_fn = listener_obj["register_ice_server_fn"] + servers_unregister = listener_obj["servers_unregister"] + + for unregister in servers_unregister: + await unregister() + + if not self._ice_servers: + self._ice_servers_listeners[listener_id]["servers_unregister"] = [] + return + + self._ice_servers_listeners[listener_id]["servers_unregister"] = [ + await register_ice_server_fn(ice_server) for ice_server in self._ice_servers + ] + + _LOGGER.debug("ICE servers listener update done: %s", str(self._ice_servers)) + + async def async_register_ice_servers_listener( + self, + register_ice_server_fn: Callable[[list[IceServer]], Awaitable[None]], + ) -> None: + """Register a listener for ICE servers.""" + listener_id = str(id(register_ice_server_fn)) + + _LOGGER.debug("Registering ICE servers listener: %s", listener_id) + + def remove_listener() -> None: + """Remove listener.""" + self._ice_servers_listeners.pop(listener_id, None) + + self._ice_servers_listeners[listener_id] = { + "register_ice_server_fn": register_ice_server_fn, + "servers_unregister": [], + } + + await self._perform_ice_server_listener_update(listener_id) + + return remove_listener diff --git a/tests/test_cloud_api.py b/tests/test_cloud_api.py index 63c9ac282..971f4519e 100644 --- a/tests/test_cloud_api.py +++ b/tests/test_cloud_api.py @@ -393,53 +393,3 @@ async def test_async_files_upload_details_error( } assert "Fetched https://example.com/files/upload_details (400) Boom!" in caplog.text - - -async def test_async_ice_servers( - auth_cloud_mock: MagicMock, - aioclient_mock: Generator[AiohttpClientMocker, Any, None], - caplog: pytest.LogCaptureFixture, -): - """Test the async_ice_servers function.""" - ice_servers_response = [ - { - "urls": "turn:example.com:3478", - "username": "test-username", - "credential": "test-credential", - }, - ] - - aioclient_mock.get( - "https://example.com/webrtc/ice_servers", - json=ice_servers_response, - ) - auth_cloud_mock.id_token = "mock-id-token" - auth_cloud_mock.servicehandlers_server = "example.com" - - details = await cloud_api.async_ice_servers(cloud=auth_cloud_mock) - - assert len(aioclient_mock.mock_calls) == 1 - - assert details == ice_servers_response - assert "Fetched https://example.com/webrtc/ice_servers (200)" in caplog.text - - -async def test_async_ice_servers_error( - auth_cloud_mock: MagicMock, - aioclient_mock: Generator[AiohttpClientMocker, Any, None], - caplog: pytest.LogCaptureFixture, -): - """Test the async_ice_servers function with error fetching ice servers.""" - aioclient_mock.get( - "https://example.com/webrtc/ice_servers", - status=400, - ) - auth_cloud_mock.id_token = "mock-id-token" - auth_cloud_mock.servicehandlers_server = "example.com" - - with pytest.raises(ClientResponseError): - await cloud_api.async_ice_servers(cloud=auth_cloud_mock) - - assert len(aioclient_mock.mock_calls) == 1 - - assert "Fetched https://example.com/webrtc/ice_servers (400)" in caplog.text From 77e85ac7e99bad13e11720b774ae35970492e082 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Mon, 21 Oct 2024 23:25:48 +0300 Subject: [PATCH 03/14] Fix types issues --- hass_nabucasa/ice_servers.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 7a4d63e73..1a4c3116a 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -1,4 +1,6 @@ -"""Class to manage ICE servers.""" +"""Manage ICE servers.""" + +from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable @@ -27,7 +29,7 @@ class IceServer(TypedDict): class IceServersListener(TypedDict): """ICE Servers Listener.""" - register_ice_server_fn: Callable[[list[IceServer]], Awaitable[None]] + register_ice_server_fn: Callable[[IceServer], Awaitable[Callable[[], None]]] servers_unregister: list[Callable[[], None]] @@ -46,6 +48,9 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None: async def _async_fetch_ice_servers(self) -> None: """Fetch ICE servers.""" + if TYPE_CHECKING: + assert self.cloud.id_token is not None + async with self.cloud.websession.get( f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers", headers={ @@ -114,7 +119,7 @@ async def _perform_ice_server_listener_update(self, listener_id: str) -> None: servers_unregister = listener_obj["servers_unregister"] for unregister in servers_unregister: - await unregister() + unregister() if not self._ice_servers: self._ice_servers_listeners[listener_id]["servers_unregister"] = [] @@ -128,8 +133,8 @@ async def _perform_ice_server_listener_update(self, listener_id: str) -> None: async def async_register_ice_servers_listener( self, - register_ice_server_fn: Callable[[list[IceServer]], Awaitable[None]], - ) -> None: + register_ice_server_fn: Callable[[IceServer], Awaitable[Callable[[], None]]], + ) -> Callable[[], None]: """Register a listener for ICE servers.""" listener_id = str(id(register_ice_server_fn)) From e545e7fb0a0fc0443af0c9cbaef904011e1c4703 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Tue, 22 Oct 2024 11:45:55 +0300 Subject: [PATCH 04/14] Remove old code --- hass_nabucasa/cloud_api.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/hass_nabucasa/cloud_api.py b/hass_nabucasa/cloud_api.py index 691d01fd6..da8c66e1f 100644 --- a/hass_nabucasa/cloud_api.py +++ b/hass_nabucasa/cloud_api.py @@ -52,14 +52,6 @@ class FilesHandlerListEntry(TypedDict): tags: dict[str, Any] -class IceServer(TypedDict): - """ICE Server.""" - - urls: str - username: str - credential: str - - def _do_log_response(resp: ClientResponse, content: str = "") -> None: """Log the response.""" meth = _LOGGER.debug if resp.status < 400 else _LOGGER.warning @@ -357,18 +349,3 @@ async def async_resolve_cname(cloud: Cloud[_ClientT], hostname: str) -> list[str resp.raise_for_status() data: list[str] = await resp.json() return data - - -@_check_token -async def async_ice_servers(cloud: Cloud[_ClientT]) -> list[IceServer]: - """Resolve ICE Servers.""" - if TYPE_CHECKING: - assert cloud.id_token is not None - resp = await cloud.websession.get( - f"https://{cloud.servicehandlers_server}/webrtc/ice_servers", - headers={"authorization": cloud.id_token, USER_AGENT: cloud.client.client_name}, - ) - _do_log_response(resp) - resp.raise_for_status() - data: list[IceServer] = await resp.json() - return data From 431c39a797ce0861d4529ff993e0bdfd6cddf3be Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Tue, 22 Oct 2024 16:20:39 +0300 Subject: [PATCH 05/14] Migrate to support only one listener --- hass_nabucasa/ice_servers.py | 78 +++++++++++++++--------------------- 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 1a4c3116a..c51a919a7 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -4,9 +4,10 @@ import asyncio from collections.abc import Awaitable, Callable +from dataclasses import dataclass import logging import time -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING from aiohttp import ClientResponseError from aiohttp.hdrs import AUTHORIZATION, USER_AGENT @@ -18,7 +19,8 @@ _LOGGER = logging.getLogger(__name__) -class IceServer(TypedDict): +@dataclass +class IceServer: """ICE Server.""" urls: str @@ -26,13 +28,6 @@ class IceServer(TypedDict): credential: str -class IceServersListener(TypedDict): - """ICE Servers Listener.""" - - register_ice_server_fn: Callable[[IceServer], Awaitable[Callable[[], None]]] - servers_unregister: list[Callable[[], None]] - - class IceServers: """Class to manage ICE servers.""" @@ -41,7 +36,8 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None: self.cloud = cloud self._refresh_task: asyncio.Task | None = None self._ice_servers: list[IceServer] = [] - self._ice_servers_listeners: dict[str, IceServersListener] = {} + self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None + self._ice_servers_listener_unregister: list[Callable[[], None]] = [] cloud.iot.register_on_connect(self.on_connect) cloud.iot.register_on_disconnect(self.on_disconnect) @@ -66,15 +62,15 @@ async def _async_fetch_ice_servers(self) -> None: self._ice_servers = data - for listener_id in self._ice_servers_listeners: - await self._perform_ice_server_listener_update(listener_id) + if self._ice_servers_listener is not None: + await self._ice_servers_listener() def _get_refresh_sleep_time(self) -> int: """Get the sleep time for refreshing ICE servers.""" timestamps = [ - int(server["username"].split(":")[0]) + int(server.username.split(":")[0]) for server in self._ice_servers - if server["urls"].startswith("turn:") + if server.urls.startswith("turn:") ] if not timestamps: @@ -107,48 +103,38 @@ async def on_disconnect(self) -> None: self._refresh_task.cancel() self._refresh_task = None - async def _perform_ice_server_listener_update(self, listener_id: str) -> None: - """Perform ICE server listener update.""" - _LOGGER.debug("Performing ICE servers listener update: %s", listener_id) - - listener_obj = self._ice_servers_listeners.get(listener_id) - if listener_obj is None: - return - - register_ice_server_fn = listener_obj["register_ice_server_fn"] - servers_unregister = listener_obj["servers_unregister"] - - for unregister in servers_unregister: - unregister() - - if not self._ice_servers: - self._ice_servers_listeners[listener_id]["servers_unregister"] = [] - return - - self._ice_servers_listeners[listener_id]["servers_unregister"] = [ - await register_ice_server_fn(ice_server) for ice_server in self._ice_servers - ] - - _LOGGER.debug("ICE servers listener update done: %s", str(self._ice_servers)) - async def async_register_ice_servers_listener( self, register_ice_server_fn: Callable[[IceServer], Awaitable[Callable[[], None]]], ) -> Callable[[], None]: """Register a listener for ICE servers.""" - listener_id = str(id(register_ice_server_fn)) + _LOGGER.debug("Registering ICE servers listener") + + async def perform_ice_server_update() -> None: + """Perform ICE server update.""" + _LOGGER.debug("Updating ICE servers") + + for unregister in self._ice_servers_listener_unregister: + unregister() + + if not self._ice_servers: + self._ice_servers_listener_unregister = [] + return + + self._ice_servers_listener_unregister = [ + await register_ice_server_fn(ice_server) + for ice_server in self._ice_servers + ] - _LOGGER.debug("Registering ICE servers listener: %s", listener_id) + _LOGGER.debug("ICE servers updated") def remove_listener() -> None: """Remove listener.""" - self._ice_servers_listeners.pop(listener_id, None) + self._ice_servers_listener = None - self._ice_servers_listeners[listener_id] = { - "register_ice_server_fn": register_ice_server_fn, - "servers_unregister": [], - } + self._ice_servers_listener = perform_ice_server_update - await self._perform_ice_server_listener_update(listener_id) + if self._ice_servers: + await self._ice_servers_listener() return remove_listener From ceccd80af21dbaee3a8df1fd0e974b91556a691b Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Tue, 22 Oct 2024 21:24:12 +0300 Subject: [PATCH 06/14] PR suggestion fixes --- hass_nabucasa/ice_servers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index c51a919a7..631f2042d 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -19,7 +19,7 @@ _LOGGER = logging.getLogger(__name__) -@dataclass +@dataclass(frozen=True) class IceServer: """ICE Server.""" @@ -54,13 +54,16 @@ async def _async_fetch_ice_servers(self) -> None: USER_AGENT: self.cloud.client.client_name, }, ) as resp: - if resp.status >= 400: - _LOGGER.error("Failed to fetch ICE servers: %s", resp.status) - resp.raise_for_status() - data: list[IceServer] = await resp.json() - self._ice_servers = data + self._ice_servers = [ + IceServer( + urls=item["urls"], + username=item["username"], + credential=item["credential"], + ) + for item in await resp.json() + ] if self._ice_servers_listener is not None: await self._ice_servers_listener() From eeefdae683a1073fc0e5de9e3b7e58b83994aaca Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 12:11:53 +0300 Subject: [PATCH 07/14] Minor code improvements, add tests --- hass_nabucasa/ice_servers.py | 34 ++++----- tests/test_ice_servers.py | 133 +++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 20 deletions(-) create mode 100644 tests/test_ice_servers.py diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 631f2042d..cdad3fb63 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -4,13 +4,13 @@ import asyncio from collections.abc import Awaitable, Callable -from dataclasses import dataclass import logging import time from typing import TYPE_CHECKING from aiohttp import ClientResponseError from aiohttp.hdrs import AUTHORIZATION, USER_AGENT +from webrtc_models import RTCIceServer if TYPE_CHECKING: from . import Cloud, _ClientT @@ -19,15 +19,6 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True) -class IceServer: - """ICE Server.""" - - urls: str - username: str - credential: str - - class IceServers: """Class to manage ICE servers.""" @@ -35,13 +26,10 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None: """Initialize ICE Servers.""" self.cloud = cloud self._refresh_task: asyncio.Task | None = None - self._ice_servers: list[IceServer] = [] + self._ice_servers: list[RTCIceServer] = [] self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None self._ice_servers_listener_unregister: list[Callable[[], None]] = [] - cloud.iot.register_on_connect(self.on_connect) - cloud.iot.register_on_disconnect(self.on_disconnect) - async def _async_fetch_ice_servers(self) -> None: """Fetch ICE servers.""" if TYPE_CHECKING: @@ -57,7 +45,7 @@ async def _async_fetch_ice_servers(self) -> None: resp.raise_for_status() self._ice_servers = [ - IceServer( + RTCIceServer( urls=item["urls"], username=item["username"], credential=item["credential"], @@ -96,11 +84,11 @@ async def _async_refresh_ice_servers(self) -> None: sleep_time = self._get_refresh_sleep_time() await asyncio.sleep(sleep_time) - async def on_connect(self) -> None: + def _on_add_listener(self) -> None: """When the instance is connected.""" self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers()) - async def on_disconnect(self) -> None: + def _on_remove_listener(self) -> None: """When the instance is disconnected.""" if self._refresh_task is not None: self._refresh_task.cancel() @@ -108,7 +96,7 @@ async def on_disconnect(self) -> None: async def async_register_ice_servers_listener( self, - register_ice_server_fn: Callable[[IceServer], Awaitable[Callable[[], None]]], + register_ice_server_fn: Callable[[RTCIceServer], Awaitable[Callable[[], None]]], ) -> Callable[[], None]: """Register a listener for ICE servers.""" _LOGGER.debug("Registering ICE servers listener") @@ -133,11 +121,17 @@ async def perform_ice_server_update() -> None: def remove_listener() -> None: """Remove listener.""" + for unregister in self._ice_servers_listener_unregister: + unregister() + + self._ice_servers = [] self._ice_servers_listener = None + self._ice_servers_listener_unregister = [] + + self._on_remove_listener() self._ice_servers_listener = perform_ice_server_update - if self._ice_servers: - await self._ice_servers_listener() + self._on_add_listener() return remove_listener diff --git a/tests/test_ice_servers.py b/tests/test_ice_servers.py new file mode 100644 index 000000000..e668be335 --- /dev/null +++ b/tests/test_ice_servers.py @@ -0,0 +1,133 @@ +"""Test the ICE servers module.""" + +import asyncio +import time + +import pytest +from webrtc_models import RTCIceServer + +from hass_nabucasa import ice_servers + + +@pytest.fixture +def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers: + """ICE servers API fixture.""" + auth_cloud_mock.servicehandlers_server = "example.com/test" + auth_cloud_mock.id_token = "mock-id-token" + return ice_servers.IceServers(auth_cloud_mock) + + +@pytest.fixture(autouse=True) +def mock_ice_servers(aioclient_mock): + """Mock ICE servers.""" + aioclient_mock.get( + "https://example.com/test/webrtc/ice_servers", + json=[ + { + "urls": "turn:example.com:80", + "username": "12345678:test-user", + "credential": "secret-value", + }, + ], + ) + + +async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update( + ice_servers_api: ice_servers.IceServers, +): + """Test that registering an ICE servers listener triggers a periodic update.""" + times_register_called_successfully = 0 + + async def register_ice_server(ice_server: RTCIceServer): + nonlocal times_register_called_successfully + + # There asserts will silently fail and variable will not be incremented + assert ice_server.urls == "turn:example.com:80" + assert ice_server.username == "12345678:test-user" + assert ice_server.credential == "secret-value" + + times_register_called_successfully += 1 + + def unregister(): + pass + + return unregister + + unregister = await ice_servers_api.async_register_ice_servers_listener( + register_ice_server, + ) + + # Let the periodic update run once + await asyncio.sleep(0) + # Let the periodic update run again + await asyncio.sleep(0) + + unregister() + + assert times_register_called_successfully == 2 + + assert ice_servers_api._refresh_task is None + assert ice_servers_api._ice_servers == [] + assert ice_servers_api._ice_servers_listener is None + assert ice_servers_api._ice_servers_listener_unregister == [] + + +async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_update( + ice_servers_api: ice_servers.IceServers, +): + """Test that deregistering an ICE servers listener stops the periodic update.""" + times_register_called_successfully = 0 + + async def register_ice_server(ice_server: RTCIceServer): + nonlocal times_register_called_successfully + + # There asserts will silently fail and variable will not be incremented + assert ice_server.urls == "turn:example.com:80" + assert ice_server.username == "12345678:test-user" + assert ice_server.credential == "secret-value" + + times_register_called_successfully += 1 + + def unregister(): + pass + + return unregister + + unregister = await ice_servers_api.async_register_ice_servers_listener( + register_ice_server, + ) + + # Let the periodic update run once + await asyncio.sleep(0) + + unregister() + + # The periodic update should not run again + await asyncio.sleep(0) + + assert times_register_called_successfully == 1 + + assert ice_servers_api._refresh_task is None + assert ice_servers_api._ice_servers == [] + assert ice_servers_api._ice_servers_listener is None + assert ice_servers_api._ice_servers_listener_unregister == [] + + +def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): + """Test get refresh sleep time.""" + assert ice_servers_api._get_refresh_sleep_time() == 3600 + + min_timestamp = 12345678 + + ice_servers_api._ice_servers = [ + RTCIceServer(urls="turn:example.com:80", username="1234567890:test-user"), + RTCIceServer( + urls="turn:example.com:80", + username=f"{min_timestamp!s}:test-user", + ), + ] + + assert ( + ice_servers_api._get_refresh_sleep_time() + == min_timestamp - int(time.time()) - 3600 + ) From a27348f26e478a4678442b2214b74c088f2c2d35 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 12:35:08 +0300 Subject: [PATCH 08/14] Add webrtc-models dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9ae5cd103..6211a3dd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pycognito==2024.5.1", "PyJWT>=2.8.0", "snitun==0.39.1", + "webrtc-models==0.0.0b2", ] description = "Home Assistant cloud integration by Nabu Casa, Inc." license = {text = "GPL v3"} From c42daf518e1bbb8422767002a5102713c574adda Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 13:46:39 +0300 Subject: [PATCH 09/14] Add minimum refresh time constraint, improve tests --- hass_nabucasa/ice_servers.py | 6 +++++- tests/test_ice_servers.py | 37 ++++++++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index cdad3fb63..60702f090 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -5,6 +5,7 @@ import asyncio from collections.abc import Awaitable, Callable import logging +import random import time from typing import TYPE_CHECKING @@ -67,8 +68,11 @@ def _get_refresh_sleep_time(self) -> int: if not timestamps: return 3600 # 1 hour + if (expiration := min(timestamps) - int(time.time()) - 3600) < 0: + return random.randint(100, 300) + # 1 hour before the earliest expiration - return min(timestamps) - int(time.time()) - 3600 + return expiration async def _async_refresh_ice_servers(self) -> None: """Handle ICE server refresh.""" diff --git a/tests/test_ice_servers.py b/tests/test_ice_servers.py index e668be335..35f898e51 100644 --- a/tests/test_ice_servers.py +++ b/tests/test_ice_servers.py @@ -38,6 +38,8 @@ async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_u """Test that registering an ICE servers listener triggers a periodic update.""" times_register_called_successfully = 0 + ice_servers_api._get_refresh_sleep_time = lambda: -1 + async def register_ice_server(ice_server: RTCIceServer): nonlocal times_register_called_successfully @@ -78,6 +80,8 @@ async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_up """Test that deregistering an ICE servers listener stops the periodic update.""" times_register_called_successfully = 0 + ice_servers_api._get_refresh_sleep_time = lambda: -1 + async def register_ice_server(ice_server: RTCIceServer): nonlocal times_register_called_successfully @@ -115,12 +119,10 @@ def unregister(): def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): """Test get refresh sleep time.""" - assert ice_servers_api._get_refresh_sleep_time() == 3600 - - min_timestamp = 12345678 + min_timestamp = 8888888888 ice_servers_api._ice_servers = [ - RTCIceServer(urls="turn:example.com:80", username="1234567890:test-user"), + RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"), RTCIceServer( urls="turn:example.com:80", username=f"{min_timestamp!s}:test-user", @@ -131,3 +133,30 @@ def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): ice_servers_api._get_refresh_sleep_time() == min_timestamp - int(time.time()) - 3600 ) + + +def test_get_refresh_sleep_time_no_turn_servers( + ice_servers_api: ice_servers.IceServers, +): + """Test get refresh sleep time.""" + assert ice_servers_api._get_refresh_sleep_time() == 3600 + + +def test_get_refresh_sleep_time_expiration_less_than_one_hour( + ice_servers_api: ice_servers.IceServers, +): + """Test get refresh sleep time.""" + min_timestamp = 10 + + ice_servers_api._ice_servers = [ + RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"), + RTCIceServer( + urls="turn:example.com:80", + username=f"{min_timestamp!s}:test-user", + ), + ] + + refresh_time = ice_servers_api._get_refresh_sleep_time() + + assert refresh_time >= 100 + assert refresh_time <= 300 From 168fbeb5e07ea620394e3183fae99ad5754ab2e8 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 13:52:57 +0300 Subject: [PATCH 10/14] Improve sleep time timestamps check --- hass_nabucasa/ice_servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 60702f090..5b2e3ab00 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -62,7 +62,7 @@ def _get_refresh_sleep_time(self) -> int: timestamps = [ int(server.username.split(":")[0]) for server in self._ice_servers - if server.urls.startswith("turn:") + if server.username is not None and ":" in server.username ] if not timestamps: From 1385434b4b91c8353e47506a669c95973340ea6e Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 17:15:47 +0300 Subject: [PATCH 11/14] Add support for listener that supports list of ICE servers --- hass_nabucasa/ice_servers.py | 30 ++++++++++++++++-------------- tests/test_ice_servers.py | 26 ++++++++++++++------------ 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 5b2e3ab00..1870dcc96 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -29,7 +29,7 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None: self._refresh_task: asyncio.Task | None = None self._ice_servers: list[RTCIceServer] = [] self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None - self._ice_servers_listener_unregister: list[Callable[[], None]] = [] + self._ice_servers_listener_unregister: Callable[[], None] | None = None async def _async_fetch_ice_servers(self) -> None: """Fetch ICE servers.""" @@ -100,37 +100,39 @@ def _on_remove_listener(self) -> None: async def async_register_ice_servers_listener( self, - register_ice_server_fn: Callable[[RTCIceServer], Awaitable[Callable[[], None]]], + register_ice_server_fn: Callable[ + [list[RTCIceServer]], + Awaitable[Callable[[], None]], + ], ) -> Callable[[], None]: - """Register a listener for ICE servers.""" + """Register a listener for ICE servers and return unregister function.""" _LOGGER.debug("Registering ICE servers listener") async def perform_ice_server_update() -> None: - """Perform ICE server update.""" + """Perform ICE server update by unregistering and registering servers.""" _LOGGER.debug("Updating ICE servers") - for unregister in self._ice_servers_listener_unregister: - unregister() + if self._ice_servers_listener_unregister is not None: + self._ice_servers_listener_unregister() if not self._ice_servers: - self._ice_servers_listener_unregister = [] + self._ice_servers_listener_unregister = None return - self._ice_servers_listener_unregister = [ - await register_ice_server_fn(ice_server) - for ice_server in self._ice_servers - ] + self._ice_servers_listener_unregister = await register_ice_server_fn( + self._ice_servers, + ) _LOGGER.debug("ICE servers updated") def remove_listener() -> None: """Remove listener.""" - for unregister in self._ice_servers_listener_unregister: - unregister() + if self._ice_servers_listener_unregister is not None: + self._ice_servers_listener_unregister() self._ice_servers = [] self._ice_servers_listener = None - self._ice_servers_listener_unregister = [] + self._ice_servers_listener_unregister = None self._on_remove_listener() diff --git a/tests/test_ice_servers.py b/tests/test_ice_servers.py index 35f898e51..ab8cc309c 100644 --- a/tests/test_ice_servers.py +++ b/tests/test_ice_servers.py @@ -40,13 +40,14 @@ async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_u ice_servers_api._get_refresh_sleep_time = lambda: -1 - async def register_ice_server(ice_server: RTCIceServer): + async def register_ice_servers(ice_servers: list[RTCIceServer]): nonlocal times_register_called_successfully # There asserts will silently fail and variable will not be incremented - assert ice_server.urls == "turn:example.com:80" - assert ice_server.username == "12345678:test-user" - assert ice_server.credential == "secret-value" + assert len(ice_servers) == 1 + assert ice_servers[0].urls == "turn:example.com:80" + assert ice_servers[0].username == "12345678:test-user" + assert ice_servers[0].credential == "secret-value" times_register_called_successfully += 1 @@ -56,7 +57,7 @@ def unregister(): return unregister unregister = await ice_servers_api.async_register_ice_servers_listener( - register_ice_server, + register_ice_servers, ) # Let the periodic update run once @@ -71,7 +72,7 @@ def unregister(): assert ice_servers_api._refresh_task is None assert ice_servers_api._ice_servers == [] assert ice_servers_api._ice_servers_listener is None - assert ice_servers_api._ice_servers_listener_unregister == [] + assert ice_servers_api._ice_servers_listener_unregister == None async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_update( @@ -82,13 +83,14 @@ async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_up ice_servers_api._get_refresh_sleep_time = lambda: -1 - async def register_ice_server(ice_server: RTCIceServer): + async def register_ice_servers(ice_servers: list[RTCIceServer]): nonlocal times_register_called_successfully # There asserts will silently fail and variable will not be incremented - assert ice_server.urls == "turn:example.com:80" - assert ice_server.username == "12345678:test-user" - assert ice_server.credential == "secret-value" + assert len(ice_servers) == 1 + assert ice_servers[0].urls == "turn:example.com:80" + assert ice_servers[0].username == "12345678:test-user" + assert ice_servers[0].credential == "secret-value" times_register_called_successfully += 1 @@ -98,7 +100,7 @@ def unregister(): return unregister unregister = await ice_servers_api.async_register_ice_servers_listener( - register_ice_server, + register_ice_servers, ) # Let the periodic update run once @@ -114,7 +116,7 @@ def unregister(): assert ice_servers_api._refresh_task is None assert ice_servers_api._ice_servers == [] assert ice_servers_api._ice_servers_listener is None - assert ice_servers_api._ice_servers_listener_unregister == [] + assert ice_servers_api._ice_servers_listener_unregister == None def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): From ed7ca6db1200c67ac84c6f7dae97bb805b856b25 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 17:18:13 +0300 Subject: [PATCH 12/14] Move listener unregister clearance in condition --- hass_nabucasa/ice_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hass_nabucasa/ice_servers.py b/hass_nabucasa/ice_servers.py index 1870dcc96..e9aa8dbe4 100644 --- a/hass_nabucasa/ice_servers.py +++ b/hass_nabucasa/ice_servers.py @@ -114,9 +114,9 @@ async def perform_ice_server_update() -> None: if self._ice_servers_listener_unregister is not None: self._ice_servers_listener_unregister() + self._ice_servers_listener_unregister = None if not self._ice_servers: - self._ice_servers_listener_unregister = None return self._ice_servers_listener_unregister = await register_ice_server_fn( @@ -129,10 +129,10 @@ def remove_listener() -> None: """Remove listener.""" if self._ice_servers_listener_unregister is not None: self._ice_servers_listener_unregister() + self._ice_servers_listener_unregister = None self._ice_servers = [] self._ice_servers_listener = None - self._ice_servers_listener_unregister = None self._on_remove_listener() From 0a422e9b60489d61394e1eac0d414156526e40f9 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Thu, 24 Oct 2024 17:19:55 +0300 Subject: [PATCH 13/14] Fix test None check --- tests/test_ice_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_ice_servers.py b/tests/test_ice_servers.py index ab8cc309c..3ad70bcba 100644 --- a/tests/test_ice_servers.py +++ b/tests/test_ice_servers.py @@ -72,7 +72,7 @@ def unregister(): assert ice_servers_api._refresh_task is None assert ice_servers_api._ice_servers == [] assert ice_servers_api._ice_servers_listener is None - assert ice_servers_api._ice_servers_listener_unregister == None + assert ice_servers_api._ice_servers_listener_unregister is None async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_update( @@ -116,7 +116,7 @@ def unregister(): assert ice_servers_api._refresh_task is None assert ice_servers_api._ice_servers == [] assert ice_servers_api._ice_servers_listener is None - assert ice_servers_api._ice_servers_listener_unregister == None + assert ice_servers_api._ice_servers_listener_unregister is None def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers): From 25802ba536fa5305fda2efbb0426f2c0d23e54b1 Mon Sep 17 00:00:00 2001 From: Krisjanis Lejejs Date: Fri, 25 Oct 2024 16:42:42 +0300 Subject: [PATCH 14/14] Improve tests based on PR reviews --- tests/test_ice_servers.py | 45 +++------------------------------------ 1 file changed, 3 insertions(+), 42 deletions(-) diff --git a/tests/test_ice_servers.py b/tests/test_ice_servers.py index 3ad70bcba..ca248ce33 100644 --- a/tests/test_ice_servers.py +++ b/tests/test_ice_servers.py @@ -38,12 +38,12 @@ async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_u """Test that registering an ICE servers listener triggers a periodic update.""" times_register_called_successfully = 0 - ice_servers_api._get_refresh_sleep_time = lambda: -1 + ice_servers_api._get_refresh_sleep_time = lambda: 0 async def register_ice_servers(ice_servers: list[RTCIceServer]): nonlocal times_register_called_successfully - # There asserts will silently fail and variable will not be incremented + # These asserts will silently fail and variable will not be incremented assert len(ice_servers) == 1 assert ice_servers[0].urls == "turn:example.com:80" assert ice_servers[0].username == "12345678:test-user" @@ -65,53 +65,14 @@ def unregister(): # Let the periodic update run again await asyncio.sleep(0) - unregister() - assert times_register_called_successfully == 2 - assert ice_servers_api._refresh_task is None - assert ice_servers_api._ice_servers == [] - assert ice_servers_api._ice_servers_listener is None - assert ice_servers_api._ice_servers_listener_unregister is None - - -async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_update( - ice_servers_api: ice_servers.IceServers, -): - """Test that deregistering an ICE servers listener stops the periodic update.""" - times_register_called_successfully = 0 - - ice_servers_api._get_refresh_sleep_time = lambda: -1 - - async def register_ice_servers(ice_servers: list[RTCIceServer]): - nonlocal times_register_called_successfully - - # There asserts will silently fail and variable will not be incremented - assert len(ice_servers) == 1 - assert ice_servers[0].urls == "turn:example.com:80" - assert ice_servers[0].username == "12345678:test-user" - assert ice_servers[0].credential == "secret-value" - - times_register_called_successfully += 1 - - def unregister(): - pass - - return unregister - - unregister = await ice_servers_api.async_register_ice_servers_listener( - register_ice_servers, - ) - - # Let the periodic update run once - await asyncio.sleep(0) - unregister() # The periodic update should not run again await asyncio.sleep(0) - assert times_register_called_successfully == 1 + assert times_register_called_successfully == 2 assert ice_servers_api._refresh_task is None assert ice_servers_api._ice_servers == []