Skip to content

Commit

Permalink
Sentence trigger (#94613)
Browse files Browse the repository at this point in the history
* Add async_register_trigger_sentences for default agent

* Add trigger response and trigger handler

* Check callback in test

* Clean up and move response to callback

* Add trigger test

* Drop TriggerAction

* Test we pass sentence to callback

* Match triggers once, allow multiple sentences

* Don't use trigger id

* Use async callback

* No response for now

* Use asyncio.gather for callback responses

* Fix after rebase

* Use a list for trigger sentences

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
  • Loading branch information
synesthesiam and balloob authored Jun 22, 2023
1 parent 29ef925 commit d811fa0
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 2 deletions.
114 changes: 113 additions & 1 deletion homeassistant/components/conversation/default_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import asyncio
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
import functools
import logging
from pathlib import Path
import re
Expand Down Expand Up @@ -42,6 +43,9 @@
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]

REGEX_TYPE = type(re.compile(""))
TRIGGER_CALLBACK_TYPE = Callable[ # pylint: disable=invalid-name
[str], Awaitable[str | None]
]


def json_load(fp: IO[str]) -> JsonObjectType:
Expand All @@ -60,6 +64,14 @@ class LanguageIntents:
loaded_components: set[str]


@dataclass(slots=True)
class TriggerData:
"""List of sentences and the callback for a trigger."""

sentences: list[str]
callback: TRIGGER_CALLBACK_TYPE


def _get_language_variations(language: str) -> Iterable[str]:
"""Generate language codes with and without region."""
yield language
Expand Down Expand Up @@ -110,6 +122,10 @@ def __init__(self, hass: core.HomeAssistant) -> None:
self._config_intents: dict[str, Any] = {}
self._slot_lists: dict[str, SlotList] | None = None

# Sentences that will trigger a callback (skipping intent recognition)
self._trigger_sentences: list[TriggerData] = []
self._trigger_intents: Intents | None = None

@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
Expand Down Expand Up @@ -174,6 +190,9 @@ async def async_recognize(

async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
if trigger_result := await self._match_triggers(user_input.text):
return trigger_result

language = user_input.language or self.hass.config.language
conversation_id = None # Not supported

Expand Down Expand Up @@ -605,6 +624,99 @@ def _get_error_text(
response_str = lang_intents.error_responses.get(response_key)
return response_str or _DEFAULT_ERROR_TEXT

def register_trigger(
self,
sentences: list[str],
callback: TRIGGER_CALLBACK_TYPE,
) -> core.CALLBACK_TYPE:
"""Register a list of sentences that will trigger a callback when recognized."""
trigger_data = TriggerData(sentences=sentences, callback=callback)
self._trigger_sentences.append(trigger_data)

# Force rebuild on next use
self._trigger_intents = None

unregister = functools.partial(self._unregister_trigger, trigger_data)
return unregister

def _rebuild_trigger_intents(self) -> None:
"""Rebuild the HassIL intents object from the current trigger sentences."""
intents_dict = {
"language": self.hass.config.language,
"intents": {
# Use trigger data index as a virtual intent name for HassIL.
# This works because the intents are rebuilt on every
# register/unregister.
str(trigger_id): {"data": [{"sentences": trigger_data.sentences}]}
for trigger_id, trigger_data in enumerate(self._trigger_sentences)
},
}

self._trigger_intents = Intents.from_dict(intents_dict)
_LOGGER.debug("Rebuilt trigger intents: %s", intents_dict)

def _unregister_trigger(self, trigger_data: TriggerData) -> None:
"""Unregister a set of trigger sentences."""
self._trigger_sentences.remove(trigger_data)

# Force rebuild on next use
self._trigger_intents = None

async def _match_triggers(self, sentence: str) -> ConversationResult | None:
"""Try to match sentence against registered trigger sentences.
Calls the registered callbacks if there's a match and returns a positive
conversation result.
"""
if not self._trigger_sentences:
# No triggers registered
return None

if self._trigger_intents is None:
# Need to rebuild intents before matching
self._rebuild_trigger_intents()

assert self._trigger_intents is not None

matched_triggers: set[int] = set()
for result in recognize_all(sentence, self._trigger_intents):
trigger_id = int(result.intent.name)
if trigger_id in matched_triggers:
# Already matched a sentence from this trigger
break

matched_triggers.add(trigger_id)

if not matched_triggers:
# Sentence did not match any trigger sentences
return None

_LOGGER.debug(
"'%s' matched %s trigger(s): %s",
sentence,
len(matched_triggers),
matched_triggers,
)

# Gather callback responses in parallel
trigger_responses = await asyncio.gather(
*(
self._trigger_sentences[trigger_id].callback(sentence)
for trigger_id in matched_triggers
)
)

# Use last non-empty result as speech response
speech: str | None = None
for trigger_response in trigger_responses:
speech = speech or trigger_response

response = intent.IntentResponse(language=self.hass.config.language)
response.response_type = intent.IntentResponseType.ACTION_DONE
response.async_set_speech(speech or "")

return ConversationResult(response=response)


def _make_error_result(
language: str,
Expand Down
59 changes: 59 additions & 0 deletions homeassistant/components/conversation/trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Offer sentence based automation rules."""
from __future__ import annotations

from typing import Any

import voluptuous as vol

from homeassistant.const import CONF_COMMAND, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType

from . import HOME_ASSISTANT_AGENT, _get_agent_manager
from .const import DOMAIN
from .default_agent import DefaultAgent

TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): DOMAIN,
vol.Required(CONF_COMMAND): vol.All(cv.ensure_list, [cv.string]),
}
)


async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Listen for events based on configuration."""
trigger_data = trigger_info["trigger_data"]
sentences = config.get(CONF_COMMAND, [])

job = HassJob(action)

@callback
async def call_action(sentence: str) -> str | None:
"""Call action with right context."""
trigger_input: dict[str, Any] = { # Satisfy type checker
**trigger_data,
"platform": DOMAIN,
"sentence": sentence,
}

# Wait for the automation to complete
if future := hass.async_run_hass_job(
job,
{"trigger": trigger_input},
):
await future

return None

default_agent = await _get_agent_manager(hass).async_get_agent(HOME_ASSISTANT_AGENT)
assert isinstance(default_agent, DefaultAgent)

return default_agent.register_trigger(sentences, call_action)
44 changes: 43 additions & 1 deletion tests/components/conversation/test_default_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test for the default agent."""
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import pytest

Expand Down Expand Up @@ -223,3 +223,45 @@ async def test_unexposed_entities_skipped(
assert result.response.response_type == intent.IntentResponseType.QUERY_ANSWER
assert len(result.response.matched_states) == 1
assert result.response.matched_states[0].entity_id == exposed_light.entity_id


async def test_trigger_sentences(hass: HomeAssistant, init_components) -> None:
"""Test registering/unregistering/matching a few trigger sentences."""
trigger_sentences = ["It's party time", "It is time to party"]
trigger_response = "Cowabunga!"

agent = await conversation._get_agent_manager(hass).async_get_agent(
conversation.HOME_ASSISTANT_AGENT
)
assert isinstance(agent, conversation.DefaultAgent)

callback = AsyncMock(return_value=trigger_response)
unregister = agent.register_trigger(trigger_sentences, callback)

result = await conversation.async_converse(hass, "Not the trigger", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR

# Using different case and including punctuation
test_sentences = ["it's party time!", "IT IS TIME TO PARTY."]
for sentence in test_sentences:
callback.reset_mock()
result = await conversation.async_converse(hass, sentence, None, Context())
callback.assert_called_once_with(sentence)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), sentence
assert result.response.speech == {
"plain": {"speech": trigger_response, "extra_data": None}
}

unregister()

# Should produce errors now
callback.reset_mock()
for sentence in test_sentences:
result = await conversation.async_converse(hass, sentence, None, Context())
assert (
result.response.response_type == intent.IntentResponseType.ERROR
), sentence

assert len(callback.mock_calls) == 0
Loading

0 comments on commit d811fa0

Please sign in to comment.