From aedfd6c983abe465b37343b83de30a8abd8cdcd1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 3 Apr 2024 21:04:26 -1000 Subject: [PATCH] Add index for floor/label to the area registry (#114777) --- homeassistant/helpers/area_registry.py | 47 +++++++++++++++++-- homeassistant/helpers/service.py | 12 ++--- tests/common.py | 5 +- .../conversation/test_default_agent.py | 2 +- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index fc535bed61012..24f58c56d2f30 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -87,10 +87,49 @@ async def _async_migrate_func( return old_data +class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]): + """Class to hold area registry items.""" + + def __init__(self) -> None: + """Initialize the area registry items.""" + super().__init__() + self._labels_index: dict[str, dict[str, Literal[True]]] = {} + self._floors_index: dict[str, dict[str, Literal[True]]] = {} + + def _index_entry(self, key: str, entry: AreaEntry) -> None: + """Index an entry.""" + if entry.floor_id is not None: + self._floors_index.setdefault(entry.floor_id, {})[key] = True + for label in entry.labels: + self._labels_index.setdefault(label, {})[key] = True + super()._index_entry(key, entry) + + def _unindex_entry( + self, key: str, replacement_entry: AreaEntry | None = None + ) -> None: + entry = self.data[key] + if labels := entry.labels: + for label in labels: + self._unindex_entry_value(key, label, self._labels_index) + if floor_id := entry.floor_id: + self._unindex_entry_value(key, floor_id, self._floors_index) + return super()._unindex_entry(key, replacement_entry) + + def get_areas_for_label(self, label: str) -> list[AreaEntry]: + """Get areas for label.""" + data = self.data + return [data[key] for key in self._labels_index.get(label, ())] + + def get_areas_for_floor(self, floor: str) -> list[AreaEntry]: + """Get areas for floor.""" + data = self.data + return [data[key] for key in self._floors_index.get(floor, ())] + + class AreaRegistry(BaseRegistry): """Class to hold a registry of areas.""" - areas: NormalizedNameBaseRegistryItems[AreaEntry] + areas: AreaRegistryItems _area_data: dict[str, AreaEntry] def __init__(self, hass: HomeAssistant) -> None: @@ -254,7 +293,7 @@ async def async_load(self) -> None: data = await self._store.async_load() - areas = NormalizedNameBaseRegistryItems[AreaEntry]() + areas = AreaRegistryItems() if data is not None: for area in data["areas"]: @@ -369,10 +408,10 @@ async def async_load(hass: HomeAssistant) -> None: @callback def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]: """Return entries that match a floor.""" - return [area for area in registry.areas.values() if floor_id == area.floor_id] + return registry.areas.get_areas_for_floor(floor_id) @callback def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]: """Return entries that match a label.""" - return [area for area in registry.areas.values() if label_id in area.labels] + return registry.areas.get_areas_for_label(label_id) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 00dfea235499c..9af02402bc0bc 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -537,16 +537,16 @@ def async_extract_referenced_entity_ids( # noqa: C901 for device_entry in dev_reg.devices.get_devices_for_label(label_id): selected.referenced_devices.add(device_entry.id) - # Find areas for targeted labels - for area_entry in area_reg.areas.values(): - if area_entry.labels.intersection(selector.label_ids): + for area_entry in area_reg.areas.get_areas_for_label(label_id): selected.referenced_areas.add(area_entry.id) # Find areas for targeted floors if selector.floor_ids: - for area_entry in area_reg.areas.values(): - if area_entry.id and area_entry.floor_id in selector.floor_ids: - selected.referenced_areas.add(area_entry.id) + selected.referenced_areas.update( + area_entry.id + for floor_id in selector.floor_ids + for area_entry in area_reg.areas.get_areas_for_floor(floor_id) + ) # Find devices for targeted areas selected.referenced_devices.update(selector.device_ids) diff --git a/tests/common.py b/tests/common.py index d3bcdcbd004d0..db96e36f7ecd8 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -from collections import OrderedDict from collections.abc import AsyncGenerator, Generator, Mapping, Sequence from contextlib import asynccontextmanager, contextmanager from datetime import UTC, datetime, timedelta @@ -649,7 +648,9 @@ def mock_area_registry( fixture instead. """ registry = ar.AreaRegistry(hass) - registry.areas = mock_entries or OrderedDict() + registry.areas = ar.AreaRegistryItems() + for key, entry in mock_entries.items(): + registry.areas[key] = entry hass.data[ar.DATA_REGISTRY] = registry return registry diff --git a/tests/components/conversation/test_default_agent.py b/tests/components/conversation/test_default_agent.py index 474198cb8a3ff..9048a1259c595 100644 --- a/tests/components/conversation/test_default_agent.py +++ b/tests/components/conversation/test_default_agent.py @@ -823,7 +823,7 @@ async def test_empty_aliases( area_kitchen = area_registry.async_get_or_create("kitchen_id") area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") area_kitchen = area_registry.async_update( - area_kitchen.id, aliases={" "}, floor_id=floor_1 + area_kitchen.id, aliases={" "}, floor_id=floor_1.floor_id ) entry = MockConfigEntry()