Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for fetching ICE servers from service handlers #717

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hass_nabucasa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
STATE_CONNECTED,
)
from .google_report_state import GoogleReportState
from .ice_servers import IceServers
from .iot import CloudIoT
from .remote import RemoteUI
from .utils import UTC, gather_callbacks, parse_date, utcnow
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.remote = RemoteUI(self)
self.auth = CognitoAuth(self)
self.voice = Voice(self)
self.ice_servers = IceServers(self)

self._init_task: asyncio.Task | None = None

Expand Down
141 changes: 141 additions & 0 deletions hass_nabucasa/ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Manage ICE servers."""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
import logging
import random
import time
from typing import TYPE_CHECKING

from aiohttp import ClientResponseError
from aiohttp.hdrs import AUTHORIZATION, USER_AGENT
from webrtc_models import RTCIceServer

if TYPE_CHECKING:
from . import Cloud, _ClientT


_LOGGER = logging.getLogger(__name__)


class IceServers:
"""Class to manage ICE servers."""

def __init__(self, cloud: Cloud[_ClientT]) -> None:
"""Initialize ICE Servers."""
self.cloud = cloud
self._refresh_task: asyncio.Task | None = None
self._ice_servers: list[RTCIceServer] = []
self._ice_servers_listener: Callable[[], Awaitable[None]] | None = None
self._ice_servers_listener_unregister: list[Callable[[], None]] = []

async def _async_fetch_ice_servers(self) -> None:
"""Fetch ICE servers."""
if TYPE_CHECKING:
assert self.cloud.id_token is not None

async with self.cloud.websession.get(
f"https://{self.cloud.servicehandlers_server}/webrtc/ice_servers",
headers={
AUTHORIZATION: self.cloud.id_token,
USER_AGENT: self.cloud.client.client_name,
},
) as resp:
resp.raise_for_status()

self._ice_servers = [
RTCIceServer(
urls=item["urls"],
username=item["username"],
credential=item["credential"],
)
for item in await resp.json()
]

if self._ice_servers_listener is not None:
await self._ice_servers_listener()

def _get_refresh_sleep_time(self) -> int:
"""Get the sleep time for refreshing ICE servers."""
timestamps = [
int(server.username.split(":")[0])
for server in self._ice_servers
if server.username is not None and ":" in server.username
]

if not timestamps:
return 3600 # 1 hour

if (expiration := min(timestamps) - int(time.time()) - 3600) < 0:
return random.randint(100, 300)

# 1 hour before the earliest expiration
return expiration

async def _async_refresh_ice_servers(self) -> None:
"""Handle ICE server refresh."""
while True:
try:
await self._async_fetch_ice_servers()
except ClientResponseError as err:
_LOGGER.error("Can't refresh ICE servers: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break

sleep_time = self._get_refresh_sleep_time()
await asyncio.sleep(sleep_time)

def _on_add_listener(self) -> None:
"""When the instance is connected."""
self._refresh_task = asyncio.create_task(self._async_refresh_ice_servers())

def _on_remove_listener(self) -> None:
"""When the instance is disconnected."""
if self._refresh_task is not None:
self._refresh_task.cancel()
self._refresh_task = None

async def async_register_ice_servers_listener(
self,
register_ice_server_fn: Callable[[RTCIceServer], Awaitable[Callable[[], None]]],
) -> Callable[[], None]:
"""Register a listener for ICE servers."""
_LOGGER.debug("Registering ICE servers listener")

async def perform_ice_server_update() -> None:
"""Perform ICE server update."""
_LOGGER.debug("Updating ICE servers")

for unregister in self._ice_servers_listener_unregister:
unregister()

if not self._ice_servers:
self._ice_servers_listener_unregister = []
return

self._ice_servers_listener_unregister = [
await register_ice_server_fn(ice_server)
for ice_server in self._ice_servers
]

_LOGGER.debug("ICE servers updated")

def remove_listener() -> None:
"""Remove listener."""
for unregister in self._ice_servers_listener_unregister:
unregister()

self._ice_servers = []
self._ice_servers_listener = None
self._ice_servers_listener_unregister = []

self._on_remove_listener()

self._ice_servers_listener = perform_ice_server_update

self._on_add_listener()

return remove_listener
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pycognito==2024.5.1",
"PyJWT>=2.8.0",
"snitun==0.39.1",
"webrtc-models==0.0.0b2",
]
description = "Home Assistant cloud integration by Nabu Casa, Inc."
license = {text = "GPL v3"}
Expand Down
162 changes: 162 additions & 0 deletions tests/test_ice_servers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Test the ICE servers module."""

import asyncio
import time

import pytest
from webrtc_models import RTCIceServer

from hass_nabucasa import ice_servers


@pytest.fixture
def ice_servers_api(auth_cloud_mock) -> ice_servers.IceServers:
"""ICE servers API fixture."""
auth_cloud_mock.servicehandlers_server = "example.com/test"
auth_cloud_mock.id_token = "mock-id-token"
return ice_servers.IceServers(auth_cloud_mock)


@pytest.fixture(autouse=True)
def mock_ice_servers(aioclient_mock):
"""Mock ICE servers."""
aioclient_mock.get(
"https://example.com/test/webrtc/ice_servers",
json=[
{
"urls": "turn:example.com:80",
"username": "12345678:test-user",
"credential": "secret-value",
},
],
)


async def test_ice_servers_listener_registration_triggers_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
):
"""Test that registering an ICE servers listener triggers a periodic update."""
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: -1
klejejs marked this conversation as resolved.
Show resolved Hide resolved

async def register_ice_server(ice_server: RTCIceServer):
nonlocal times_register_called_successfully

# There asserts will silently fail and variable will not be incremented
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why will they silently fail? Typo for there -> these.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is a callback, the test runner does not detect it failing in error logs; it just stops the further execution of the test. This could be due to try/catch in the code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that's just normal test behavior when using assert. I'd not add a comment about that. It's confusing to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand now. This is happening in a separate task, not the same as the test, and that's why the test output is different. Is that correct?

assert ice_server.urls == "turn:example.com:80"
assert ice_server.username == "12345678:test-user"
assert ice_server.credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

unregister = await ice_servers_api.async_register_ice_servers_listener(
register_ice_server,
)

# Let the periodic update run once
await asyncio.sleep(0)
# Let the periodic update run again
await asyncio.sleep(0)

unregister()

assert times_register_called_successfully == 2

assert ice_servers_api._refresh_task is None
assert ice_servers_api._ice_servers == []
assert ice_servers_api._ice_servers_listener is None
assert ice_servers_api._ice_servers_listener_unregister == []


async def test_ice_servers_listener_deregistration_stops_periodic_ice_servers_update(
ice_servers_api: ice_servers.IceServers,
):
"""Test that deregistering an ICE servers listener stops the periodic update."""
klejejs marked this conversation as resolved.
Show resolved Hide resolved
times_register_called_successfully = 0

ice_servers_api._get_refresh_sleep_time = lambda: -1

async def register_ice_server(ice_server: RTCIceServer):
nonlocal times_register_called_successfully

# There asserts will silently fail and variable will not be incremented
assert ice_server.urls == "turn:example.com:80"
assert ice_server.username == "12345678:test-user"
assert ice_server.credential == "secret-value"

times_register_called_successfully += 1

def unregister():
pass

return unregister

unregister = await ice_servers_api.async_register_ice_servers_listener(
register_ice_server,
)

# Let the periodic update run once
await asyncio.sleep(0)

unregister()

# The periodic update should not run again
await asyncio.sleep(0)

assert times_register_called_successfully == 1

assert ice_servers_api._refresh_task is None
assert ice_servers_api._ice_servers == []
assert ice_servers_api._ice_servers_listener is None
assert ice_servers_api._ice_servers_listener_unregister == []


def test_get_refresh_sleep_time(ice_servers_api: ice_servers.IceServers):
"""Test get refresh sleep time."""
min_timestamp = 8888888888

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="9999999999:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

assert (
ice_servers_api._get_refresh_sleep_time()
== min_timestamp - int(time.time()) - 3600
)


def test_get_refresh_sleep_time_no_turn_servers(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
assert ice_servers_api._get_refresh_sleep_time() == 3600


def test_get_refresh_sleep_time_expiration_less_than_one_hour(
ice_servers_api: ice_servers.IceServers,
):
"""Test get refresh sleep time."""
min_timestamp = 10

ice_servers_api._ice_servers = [
RTCIceServer(urls="turn:example.com:80", username="12345678:test-user"),
RTCIceServer(
urls="turn:example.com:80",
username=f"{min_timestamp!s}:test-user",
),
]

refresh_time = ice_servers_api._get_refresh_sleep_time()

assert refresh_time >= 100
assert refresh_time <= 300