Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for fetching ICE servers from service handlers #717

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
STATE_CONNECTED,
)
from .google_report_state import GoogleReportState
from .ice_servers import IceServers
from .iot import CloudIoT
from .remote import RemoteUI
from .utils import UTC, gather_callbacks, parse_date, utcnow
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.remote = RemoteUI(self)
self.auth = CognitoAuth(self)
self.voice = Voice(self)
self.ice_servers = IceServers(self)

self._init_task: asyncio.Task | None = None

Expand Down
143 changes: 143 additions & 0 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Manage ICE servers."""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import logging
import random
import time
from typing import TYPE_CHECKING

from aiohttp import ClientResponseError
from aiohttp.hdrs import AUTHORIZATION, USER_AGENT
from webrtc_models import RTCIceServer

if TYPE_CHECKING:
from . import Cloud, _ClientT


_LOGGER = logging.getLogger(__name__)


class IceServers:
"""Class to manage ICE servers."""

def __init__(self, cloud: Cloud[_ClientT]) -> None:
"""Initialize ICE Servers."""
self.cloud = cloud
self._refresh_task: asyncio.Task | None = None
self._ice_servers: list[RTCIceServer] = []
self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None
self._ice_servers_listener_unregister: Callable[[], None] | None = None

async def _async_fetch_ice_servers(self) -> None:
"""Fetch ICE servers."""
if TYPE_CHECKING:
assert self.cloud.id_token is not None

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
AUTHORIZATION: self.cloud.id_token,
USER_AGENT: self.cloud.client.client_name,
},
) as resp:
resp.raise_for_status()

self._ice_servers = [
RTCIceServer(
urls=item["urls"],
username=item["username"],
credential=item["credential"],
)
for item in await resp.json()
]

if self._ice_servers_listener is not None:
await self._ice_servers_listener()

def _get_refresh_sleep_time(self) -> int:
"""Get the sleep time for refreshing ICE servers."""
timestamps = [
int(server.username.split(":")[0])
for server in self._ice_servers
if server.username is not None and ":" in server.username
]

if not timestamps:
return 3600 # 1 hour

if (expiration := min(timestamps) - int(time.time()) - 3600) < 0:
return random.randint(100, 300)

# 1 hour before the earliest expiration
return expiration

async def _async_refresh_ice_servers(self) -> None:
"""Handle ICE server refresh."""
while True:
try:
await self._async_fetch_ice_servers()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break

sleep_time = self._get_refresh_sleep_time()
await asyncio.sleep(sleep_time)

def _on_add_listener(self) -> None:
"""When the instance is connected."""
self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers())

def _on_remove_listener(self) -> None:
"""When the instance is disconnected."""
if self._refresh_task is not None:
self._refresh_task.cancel()
self._refresh_task = None

async def async_register_ice_servers_listener(
self,
register_ice_server_fn: Callable[
[list[RTCIceServer]],
Awaitable[Callable[[], None]],
],
) -> Callable[[], None]:
"""Register a listener for ICE servers and return unregister function."""
_LOGGER.debug("Registering ICE servers listener")

async def perform_ice_server_update() -> None:
"""Perform ICE server update by unregistering and registering servers."""
_LOGGER.debug("Updating ICE servers")

if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

if not self._ice_servers:
return

self._ice_servers_listener_unregister = await register_ice_server_fn(
self._ice_servers,
)

_LOGGER.debug("ICE servers updated")

def remove_listener() -> None:
"""Remove listener."""
if self._ice_servers_listener_unregister is not None:
self._ice_servers_listener_unregister()
self._ice_servers_listener_unregister = None

self._ice_servers = []
self._ice_servers_listener = None

self._on_remove_listener()

self._ice_servers_listener = perform_ice_server_update

self._on_add_listener()

return remove_listener
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pycognito==2024.5.1",
"PyJWT>=2.8.0",
"snitun==0.39.1",
"webrtc-models==0.0.0b2",
]
description = "Home Assistant cloud integration by Nabu Casa, Inc."
license = {text = "GPL v3"}
Expand Down
125 changes: 125 additions & 0 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Test the ICE servers module."""

import asyncio
import time

import pytest
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers


@pytest.fixture
def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
"""ICE servers API fixture."""
auth_cloud_mock.servicehandlers_server = "example.com/test"
auth_cloud_mock.id_token = "mock-id-token"
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
json=[
{
"urls": "turn:example.com:80",
"username": "12345678:test-user",
"credential": "secret-value",
},
],
)


async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: 0

async def register_ice_servers(ice_servers: list[RTCIceServer]):
nonlocal times_register_called_successfully

# These asserts will silently fail and variable will not be incremented
assert len(ice_servers) == 1
assert ice_servers[0].urls == "turn:example.com:80"
assert ice_servers[0].username == "12345678:test-user"
assert ice_servers[0].credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

unregister = await ice_servers_api.async_register_ice_servers_listener(
register_ice_servers,
)

# Let the periodic update run once
await asyncio.sleep(0)
# Let the periodic update run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

unregister()

# The periodic update should not run again
await asyncio.sleep(0)

assert times_register_called_successfully == 2

assert ice_servers_api._refresh_task is None
assert ice_servers_api._ice_servers == []
assert ice_servers_api._ice_servers_listener is None
assert ice_servers_api._ice_servers_listener_unregister is None


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

assert (
ice_servers_api._get_refresh_sleep_time()
== min_timestamp - int(time.time()) - 3600
)


def test_get_refresh_sleep_time_no_turn_servers(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
assert ice_servers_api._get_refresh_sleep_time() == 3600


def test_get_refresh_sleep_time_expiration_less_than_one_hour(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
min_timestamp = 10

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 100
assert refresh_time <= 300