Skip to content

Commit

Permalink
Add index for floor/label to the area registry (#114777)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Apr 4, 2024
1 parent aa52688 commit aedfd6c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 13 deletions.
47 changes: 43 additions & 4 deletions homeassistant/helpers/area_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,49 @@ async def _async_migrate_func(
return old_data


class AreaRegistryItems(NormalizedNameBaseRegistryItems[AreaEntry]):
"""Class to hold area registry items."""

def __init__(self) -> None:
"""Initialize the area registry items."""
super().__init__()
self._labels_index: dict[str, dict[str, Literal[True]]] = {}
self._floors_index: dict[str, dict[str, Literal[True]]] = {}

def _index_entry(self, key: str, entry: AreaEntry) -> None:
"""Index an entry."""
if entry.floor_id is not None:
self._floors_index.setdefault(entry.floor_id, {})[key] = True
for label in entry.labels:
self._labels_index.setdefault(label, {})[key] = True
super()._index_entry(key, entry)

def _unindex_entry(
self, key: str, replacement_entry: AreaEntry | None = None
) -> None:
entry = self.data[key]
if labels := entry.labels:
for label in labels:
self._unindex_entry_value(key, label, self._labels_index)
if floor_id := entry.floor_id:
self._unindex_entry_value(key, floor_id, self._floors_index)
return super()._unindex_entry(key, replacement_entry)

def get_areas_for_label(self, label: str) -> list[AreaEntry]:
"""Get areas for label."""
data = self.data
return [data[key] for key in self._labels_index.get(label, ())]

def get_areas_for_floor(self, floor: str) -> list[AreaEntry]:
"""Get areas for floor."""
data = self.data
return [data[key] for key in self._floors_index.get(floor, ())]


class AreaRegistry(BaseRegistry):
"""Class to hold a registry of areas."""

areas: NormalizedNameBaseRegistryItems[AreaEntry]
areas: AreaRegistryItems
_area_data: dict[str, AreaEntry]

def __init__(self, hass: HomeAssistant) -> None:
Expand Down Expand Up @@ -254,7 +293,7 @@ async def async_load(self) -> None:

data = await self._store.async_load()

areas = NormalizedNameBaseRegistryItems[AreaEntry]()
areas = AreaRegistryItems()

if data is not None:
for area in data["areas"]:
Expand Down Expand Up @@ -369,10 +408,10 @@ async def async_load(hass: HomeAssistant) -> None:
@callback
def async_entries_for_floor(registry: AreaRegistry, floor_id: str) -> list[AreaEntry]:
"""Return entries that match a floor."""
return [area for area in registry.areas.values() if floor_id == area.floor_id]
return registry.areas.get_areas_for_floor(floor_id)


@callback
def async_entries_for_label(registry: AreaRegistry, label_id: str) -> list[AreaEntry]:
"""Return entries that match a label."""
return [area for area in registry.areas.values() if label_id in area.labels]
return registry.areas.get_areas_for_label(label_id)
12 changes: 6 additions & 6 deletions homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,16 +537,16 @@ def async_extract_referenced_entity_ids( # noqa: C901
for device_entry in dev_reg.devices.get_devices_for_label(label_id):
selected.referenced_devices.add(device_entry.id)

# Find areas for targeted labels
for area_entry in area_reg.areas.values():
if area_entry.labels.intersection(selector.label_ids):
for area_entry in area_reg.areas.get_areas_for_label(label_id):
selected.referenced_areas.add(area_entry.id)

# Find areas for targeted floors
if selector.floor_ids:
for area_entry in area_reg.areas.values():
if area_entry.id and area_entry.floor_id in selector.floor_ids:
selected.referenced_areas.add(area_entry.id)
selected.referenced_areas.update(
area_entry.id
for floor_id in selector.floor_ids
for area_entry in area_reg.areas.get_areas_for_floor(floor_id)
)

# Find devices for targeted areas
selected.referenced_devices.update(selector.device_ids)
Expand Down
5 changes: 3 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import asyncio
from collections import OrderedDict
from collections.abc import AsyncGenerator, Generator, Mapping, Sequence
from contextlib import asynccontextmanager, contextmanager
from datetime import UTC, datetime, timedelta
Expand Down Expand Up @@ -649,7 +648,9 @@ def mock_area_registry(
fixture instead.
"""
registry = ar.AreaRegistry(hass)
registry.areas = mock_entries or OrderedDict()
registry.areas = ar.AreaRegistryItems()
for key, entry in mock_entries.items():
registry.areas[key] = entry

hass.data[ar.DATA_REGISTRY] = registry
return registry
Expand Down
2 changes: 1 addition & 1 deletion tests/components/conversation/test_default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ async def test_empty_aliases(
area_kitchen = area_registry.async_get_or_create("kitchen_id")
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
area_kitchen = area_registry.async_update(
area_kitchen.id, aliases={" "}, floor_id=floor_1
area_kitchen.id, aliases={" "}, floor_id=floor_1.floor_id
)

entry = MockConfigEntry()
Expand Down

0 comments on commit aedfd6c

Please sign in to comment.