Skip to content

Commit

Permalink
Recover mqtt abbrevations optimizations (home-assistant#118762)
Browse files Browse the repository at this point in the history
Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
2 people authored and dgomes committed Jun 4, 2024
1 parent b140a39 commit 245e788
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 61 deletions.
143 changes: 85 additions & 58 deletions homeassistant/components/mqtt/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
from .util import async_forward_entry_setup_and_setup_discovery

ABBREVIATIONS_SET = set(ABBREVIATIONS)
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)

_LOGGER = logging.getLogger(__name__)

TOPIC_MATCHER = re.compile(
Expand Down Expand Up @@ -105,6 +109,82 @@ def async_log_discovery_origin_info(
)


@callback
def _replace_abbreviations(
payload: Any | dict[str, Any],
abbreviations: dict[str, str],
abbreviations_set: set[str],
) -> None:
"""Replace abbreviations in an MQTT discovery payload."""
if not isinstance(payload, dict):
return
for key in abbreviations_set.intersection(payload):
payload[abbreviations[key]] = payload.pop(key)


@callback
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
"""Replace all abbreviations in an MQTT discovery payload."""

_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)

if CONF_ORIGIN in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_ORIGIN],
ORIGIN_ABBREVIATIONS,
ORIGIN_ABBREVIATIONS_SET,
)

if CONF_DEVICE in discovery_payload:
_replace_abbreviations(
discovery_payload[CONF_DEVICE],
DEVICE_ABBREVIATIONS,
DEVICE_ABBREVIATIONS_SET,
)

if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)


@callback
def _replace_topic_base(discovery_payload: dict[str, Any]) -> None:
"""Replace topic base in MQTT discovery data."""
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"


@callback
def _valid_origin_info(discovery_payload: MQTTDiscoveryPayload) -> bool:
"""Parse and validate origin info from a single component discovery payload."""
if CONF_ORIGIN not in discovery_payload:
return True
try:
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception as exc: # noqa:BLE001
_LOGGER.warning(
"Unable to parse origin information from discovery message: %s, got %s",
exc,
discovery_payload[CONF_ORIGIN],
)
return False
return True


async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic: str, config_entry: ConfigEntry
) -> None:
Expand Down Expand Up @@ -168,67 +248,14 @@ def async_discovery_message_received(msg: ReceiveMessage) -> None: # noqa: C901
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
_replace_all_abbreviations(discovery_payload)
if not _valid_origin_info(discovery_payload):
return
if TOPIC_BASE in discovery_payload:
_replace_topic_base(discovery_payload)
else:
discovery_payload = MQTTDiscoveryPayload({})

for key in list(discovery_payload):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
discovery_payload[key] = discovery_payload.pop(abbreviated_key)

if CONF_DEVICE in discovery_payload:
device = discovery_payload[CONF_DEVICE]
for key in list(device):
abbreviated_key = key
key = DEVICE_ABBREVIATIONS.get(key, key)
device[key] = device.pop(abbreviated_key)

if CONF_ORIGIN in discovery_payload:
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
try:
for key in list(origin_info):
abbreviated_key = key
key = ORIGIN_ABBREVIATIONS.get(key, key)
origin_info[key] = origin_info.pop(abbreviated_key)
MQTT_ORIGIN_INFO_SCHEMA(discovery_payload[CONF_ORIGIN])
except Exception: # noqa: BLE001
_LOGGER.warning(
"Unable to parse origin information "
"from discovery message, got %s",
discovery_payload[CONF_ORIGIN],
)
return

if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if isinstance(availability_conf, dict):
for key in list(availability_conf):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
availability_conf[key] = availability_conf.pop(abbreviated_key)

if TOPIC_BASE in discovery_payload:
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if not isinstance(availability_conf, dict):
continue
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{topic[:-1]}{base}"

# If present, the node_id will be included in the discovered object id
discovery_id = f"{node_id} {object_id}" if node_id else object_id
discovery_hash = (component, discovery_id)
Expand Down
4 changes: 1 addition & 3 deletions tests/components/mqtt/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,7 @@ async def test_discovery_with_invalid_integration_info(
state = hass.states.get("binary_sensor.beer")

assert state is None
assert (
"Unable to parse origin information from discovery message, got" in caplog.text
)
assert "Unable to parse origin information from discovery message" in caplog.text


async def test_discover_fan(
Expand Down

0 comments on commit 245e788

Please sign in to comment.