diff --git a/neon_minerva/intent_services/common_query.py b/neon_minerva/intent_services/common_query.py new file mode 100644 index 0000000..290550d --- /dev/null +++ b/neon_minerva/intent_services/common_query.py @@ -0,0 +1,179 @@ + +import time +from dataclasses import dataclass +from threading import Event +from typing import Dict + +from ovos_utils import flatten_list +from ovos_utils.log import LOG +from neon_minerva.intent_services import IntentMatch + + +EXTENSION_TIME = 15 +MIN_RESPONSE_WAIT = 3 + + +@dataclass +class Query: + session_id: str + query: str + replies: list = None + extensions: list = None + query_time: float = time.time() + timeout_time: float = time.time() + 1 + responses_gathered: Event = Event() + completed: Event = Event() + answered: bool = False + + +class CommonQuery: + def __init__(self, bus): + self.bus = bus + self.skill_id = "common_query.test" # fake skill + self.active_queries: Dict[str, Query] = dict() + self._vocabs = {} + self.bus.on('question:query.response', self.handle_query_response) + self.bus.on('common_query.question', self.handle_question) + # TODO: Register available CommonQuery skills + + def is_question_like(self, utterance, lang): + # skip utterances with less than 3 words + if len(utterance.split(" ")) < 3: + return False + return True + + def match(self, utterances, lang, message): + """Send common query request and select best response + + Args: + utterances (list): List of tuples, + utterances and normalized version + lang (str): Language code + message: Message for session context + Returns: + IntentMatch or None + """ + # we call flatten in case someone is sending the old style list of tuples + utterances = flatten_list(utterances) + match = None + for utterance in utterances: + if self.is_question_like(utterance, lang): + message.data["lang"] = lang # only used for speak + message.data["utterance"] = utterance + answered = self.handle_question(message) + if answered: + match = IntentMatch('CommonQuery', None, {}, None, + utterance) + break + return match + + def handle_question(self, message): + """ + Send the phrase to the CommonQuerySkills and prepare for handling + the replies. + """ + utt = message.data.get('utterance') + sid = "test_session" + # TODO: Why are defaults not creating new objects on init? + query = Query(session_id=sid, query=utt, replies=[], extensions=[], + query_time=time.time(), timeout_time=time.time() + 1, + responses_gathered=Event(), completed=Event(), + answered=False) + assert query.responses_gathered.is_set() is False + assert query.completed.is_set() is False + self.active_queries[sid] = query + + LOG.info(f'Searching for {utt}') + # Send the query to anyone listening for them + msg = message.reply('question:query', data={'phrase': utt}) + if "skill_id" not in msg.context: + msg.context["skill_id"] = self.skill_id + self.bus.emit(msg) + + query.timeout_time = time.time() + 1 + timeout = False + while not query.responses_gathered.wait(EXTENSION_TIME): + if time.time() > query.timeout_time + 1: + LOG.debug(f"Timeout gathering responses ({query.session_id})") + timeout = True + break + + # forcefully timeout if search is still going + if timeout: + LOG.warning(f"Timed out getting responses for: {query.query}") + self._query_timeout(message) + if not query.completed.wait(10): + raise TimeoutError("Timed out processing responses") + answered = bool(query.answered) + self.active_queries.pop(sid) + LOG.debug(f"answered={answered}|" + f"remaining active_queries={len(self.active_queries)}") + return answered + + def handle_query_response(self, message): + search_phrase = message.data['phrase'] + skill_id = message.data['skill_id'] + searching = message.data.get('searching') + answer = message.data.get('answer') + + query = self.active_queries.get("test_session") + if not query: + LOG.warning(f"No active query for: {search_phrase}") + # Manage requests for time to complete searches + if searching: + LOG.debug(f"{skill_id} is searching") + # request extending the timeout by EXTENSION_TIME + query.timeout_time = time.time() + EXTENSION_TIME + # TODO: Perhaps block multiple extensions? + if skill_id not in query.extensions: + query.extensions.append(skill_id) + else: + # Search complete, don't wait on this skill any longer + if answer: + LOG.info(f'Answer from {skill_id}') + query.replies.append(message.data) + + # Remove the skill from list of timeout extensions + if skill_id in query.extensions: + LOG.debug(f"Done waiting for {skill_id}") + query.extensions.remove(skill_id) + + time_to_wait = query.query_time + MIN_RESPONSE_WAIT - time.time() + if time_to_wait > 0: + LOG.debug(f"Waiting {time_to_wait}s before checking extensions") + query.responses_gathered.wait(time_to_wait) + # not waiting for any more skills + if not query.extensions: + LOG.debug(f"No more skills to wait for ({query.session_id})") + query.responses_gathered.set() + + def _query_timeout(self, message): + query = self.active_queries.get("test_session") + LOG.info(f'Check responses with {len(query.replies)} replies') + search_phrase = message.data.get('phrase', "") + if query.extensions: + query.extensions = [] + + # Look at any replies that arrived before the timeout + # Find response(s) with the highest confidence + best = None + ties = [] + for response in query.replies: + if not best or response['conf'] > best['conf']: + best = response + ties = [] + elif response['conf'] == best['conf']: + ties.append(response) + + if best: + # invoke best match + LOG.info('Handling with: ' + str(best['skill_id'])) + cb = best.get('callback_data') or {} + self.bus.emit(message.forward('question:action', + data={'skill_id': best['skill_id'], + 'phrase': search_phrase, + 'callback_data': cb})) + query.answered = True + else: + query.answered = False + query.completed.set() diff --git a/neon_minerva/tests/test_skill_intents.py b/neon_minerva/tests/test_skill_intents.py index 0a36513..60977dc 100644 --- a/neon_minerva/tests/test_skill_intents.py +++ b/neon_minerva/tests/test_skill_intents.py @@ -30,6 +30,9 @@ from os import getenv from os.path import join, exists +from unittest.mock import Mock + +from ovos_bus_client import Message from ovos_utils.messagebus import FakeBus from ovos_utils.log import LOG @@ -38,6 +41,7 @@ from neon_minerva.intent_services.padatious import PadatiousContainer, TestPadatiousMatcher from neon_minerva.intent_services.adapt import AdaptContainer from neon_minerva.intent_services.padacioso import PadaciosoContainer +from neon_minerva.intent_services.common_query import CommonQuery from neon_minerva.intent_services import IntentMatch @@ -81,6 +85,9 @@ class TestSkillIntentMatching(unittest.TestCase): bus) adapt_services[lang] = AdaptContainer(lang, bus) + if common_query: + common_query_service = CommonQuery(bus) + skill = get_skill_object(skill_entrypoint=skill_entrypoint, skill_id=test_skill_id, bus=bus, config_patch=core_config_patch) @@ -148,8 +155,50 @@ def test_negative_intents(self): padatious.test_intent(utt) def test_common_query(self): - # TODO - pass + if not self.common_query: + return + + qa_callback = Mock() + qa_response = Mock() + self.skill.events.add('question:action', qa_callback) + self.skill.events.add('question:query.response', qa_response) + for lang in self.common_query.keys(): + for utt in self.common_query[lang]: + if isinstance(utt, dict): + data = list(utt.values())[0] + utt = list(utt.keys())[0] + else: + data = dict() + message = Message('test_utterance', + {"utterances": [utt], "lang": lang}) + self.common_query_service.handle_question(message) + response = qa_response.call_args[0][0] + callback = qa_response.call_args[0][0] + self.assertIsInstance(response, Message) + self.assertTrue(response.data["phrase"] in utt) + self.assertEqual(response.data["skill_id"], self.skill.skill_id) + self.assertIn("callback_data", response.data.keys()) + self.assertIsInstance(response.data["conf"], float) + self.assertIsInstance(response.data["answer"], str) + + self.assertIsInstance(callback, Message) + self.assertEqual(callback.data['skill_id'], self.skill.skill_id) + self.assertEqual(callback.data['phrase'], + response.data['phrase']) + if not data: + continue + if isinstance(data.get('callback'), dict): + self.assertEqual(callback.data['callback_data'], + data['callback']) + elif isinstance(data.get('callback'), list): + self.assertEqual(set(callback.data['callback_data'].keys()), + set(data.get('callback'))) + if data.get('min_confidence'): + self.assertGreaterEqual(response.data['conf'], + data['min_confidence']) + if data.get('max_confidence'): + self.assertLessEqual(response.data['conf'], + data['max_confidence']) def test_common_play(self): # TODO