Skip to content

Commit

Permalink
Implement CommonQuery tests from shared GH action with test class (#16)
Browse files Browse the repository at this point in the history
adapted from NeonCore

Co-authored-by: Daniel McKnight <daniel@neon.ai>
  • Loading branch information
NeonDaniel and NeonDaniel authored Jan 15, 2024
1 parent 573d7c2 commit 074ad7e
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 2 deletions.
179 changes: 179 additions & 0 deletions neon_minerva/intent_services/common_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@

import time
from dataclasses import dataclass
from threading import Event
from typing import Dict

from ovos_utils import flatten_list
from ovos_utils.log import LOG
from neon_minerva.intent_services import IntentMatch


EXTENSION_TIME = 15
MIN_RESPONSE_WAIT = 3


@dataclass
class Query:
session_id: str
query: str
replies: list = None
extensions: list = None
query_time: float = time.time()
timeout_time: float = time.time() + 1
responses_gathered: Event = Event()
completed: Event = Event()
answered: bool = False


class CommonQuery:
def __init__(self, bus):
self.bus = bus
self.skill_id = "common_query.test" # fake skill
self.active_queries: Dict[str, Query] = dict()
self._vocabs = {}
self.bus.on('question:query.response', self.handle_query_response)
self.bus.on('common_query.question', self.handle_question)
# TODO: Register available CommonQuery skills

def is_question_like(self, utterance, lang):
# skip utterances with less than 3 words
if len(utterance.split(" ")) < 3:
return False
return True

def match(self, utterances, lang, message):
"""Send common query request and select best response
Args:
utterances (list): List of tuples,
utterances and normalized version
lang (str): Language code
message: Message for session context
Returns:
IntentMatch or None
"""
# we call flatten in case someone is sending the old style list of tuples
utterances = flatten_list(utterances)
match = None
for utterance in utterances:
if self.is_question_like(utterance, lang):
message.data["lang"] = lang # only used for speak
message.data["utterance"] = utterance
answered = self.handle_question(message)
if answered:
match = IntentMatch('CommonQuery', None, {}, None,
utterance)
break
return match

def handle_question(self, message):
"""
Send the phrase to the CommonQuerySkills and prepare for handling
the replies.
"""
utt = message.data.get('utterance')
sid = "test_session"
# TODO: Why are defaults not creating new objects on init?
query = Query(session_id=sid, query=utt, replies=[], extensions=[],
query_time=time.time(), timeout_time=time.time() + 1,
responses_gathered=Event(), completed=Event(),
answered=False)
assert query.responses_gathered.is_set() is False
assert query.completed.is_set() is False
self.active_queries[sid] = query

LOG.info(f'Searching for {utt}')
# Send the query to anyone listening for them
msg = message.reply('question:query', data={'phrase': utt})
if "skill_id" not in msg.context:
msg.context["skill_id"] = self.skill_id
self.bus.emit(msg)

query.timeout_time = time.time() + 1
timeout = False
while not query.responses_gathered.wait(EXTENSION_TIME):
if time.time() > query.timeout_time + 1:
LOG.debug(f"Timeout gathering responses ({query.session_id})")
timeout = True
break

# forcefully timeout if search is still going
if timeout:
LOG.warning(f"Timed out getting responses for: {query.query}")
self._query_timeout(message)
if not query.completed.wait(10):
raise TimeoutError("Timed out processing responses")
answered = bool(query.answered)
self.active_queries.pop(sid)
LOG.debug(f"answered={answered}|"
f"remaining active_queries={len(self.active_queries)}")
return answered

def handle_query_response(self, message):
search_phrase = message.data['phrase']
skill_id = message.data['skill_id']
searching = message.data.get('searching')
answer = message.data.get('answer')

query = self.active_queries.get("test_session")
if not query:
LOG.warning(f"No active query for: {search_phrase}")
# Manage requests for time to complete searches
if searching:
LOG.debug(f"{skill_id} is searching")
# request extending the timeout by EXTENSION_TIME
query.timeout_time = time.time() + EXTENSION_TIME
# TODO: Perhaps block multiple extensions?
if skill_id not in query.extensions:
query.extensions.append(skill_id)
else:
# Search complete, don't wait on this skill any longer
if answer:
LOG.info(f'Answer from {skill_id}')
query.replies.append(message.data)

# Remove the skill from list of timeout extensions
if skill_id in query.extensions:
LOG.debug(f"Done waiting for {skill_id}")
query.extensions.remove(skill_id)

time_to_wait = query.query_time + MIN_RESPONSE_WAIT - time.time()
if time_to_wait > 0:
LOG.debug(f"Waiting {time_to_wait}s before checking extensions")
query.responses_gathered.wait(time_to_wait)
# not waiting for any more skills
if not query.extensions:
LOG.debug(f"No more skills to wait for ({query.session_id})")
query.responses_gathered.set()

def _query_timeout(self, message):
query = self.active_queries.get("test_session")
LOG.info(f'Check responses with {len(query.replies)} replies')
search_phrase = message.data.get('phrase', "")
if query.extensions:
query.extensions = []

# Look at any replies that arrived before the timeout
# Find response(s) with the highest confidence
best = None
ties = []
for response in query.replies:
if not best or response['conf'] > best['conf']:
best = response
ties = []
elif response['conf'] == best['conf']:
ties.append(response)

if best:
# invoke best match
LOG.info('Handling with: ' + str(best['skill_id']))
cb = best.get('callback_data') or {}
self.bus.emit(message.forward('question:action',
data={'skill_id': best['skill_id'],
'phrase': search_phrase,
'callback_data': cb}))
query.answered = True
else:
query.answered = False
query.completed.set()
53 changes: 51 additions & 2 deletions neon_minerva/tests/test_skill_intents.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@

from os import getenv
from os.path import join, exists
from unittest.mock import Mock

from ovos_bus_client import Message
from ovos_utils.messagebus import FakeBus
from ovos_utils.log import LOG

Expand All @@ -38,6 +41,7 @@
from neon_minerva.intent_services.padatious import PadatiousContainer, TestPadatiousMatcher
from neon_minerva.intent_services.adapt import AdaptContainer
from neon_minerva.intent_services.padacioso import PadaciosoContainer
from neon_minerva.intent_services.common_query import CommonQuery
from neon_minerva.intent_services import IntentMatch


Expand Down Expand Up @@ -81,6 +85,9 @@ class TestSkillIntentMatching(unittest.TestCase):
bus)
adapt_services[lang] = AdaptContainer(lang, bus)

if common_query:
common_query_service = CommonQuery(bus)

skill = get_skill_object(skill_entrypoint=skill_entrypoint,
skill_id=test_skill_id, bus=bus,
config_patch=core_config_patch)
Expand Down Expand Up @@ -148,8 +155,50 @@ def test_negative_intents(self):
padatious.test_intent(utt)

def test_common_query(self):
# TODO
pass
if not self.common_query:
return

qa_callback = Mock()
qa_response = Mock()
self.skill.events.add('question:action', qa_callback)
self.skill.events.add('question:query.response', qa_response)
for lang in self.common_query.keys():
for utt in self.common_query[lang]:
if isinstance(utt, dict):
data = list(utt.values())[0]
utt = list(utt.keys())[0]
else:
data = dict()
message = Message('test_utterance',
{"utterances": [utt], "lang": lang})
self.common_query_service.handle_question(message)
response = qa_response.call_args[0][0]
callback = qa_response.call_args[0][0]
self.assertIsInstance(response, Message)
self.assertTrue(response.data["phrase"] in utt)
self.assertEqual(response.data["skill_id"], self.skill.skill_id)
self.assertIn("callback_data", response.data.keys())
self.assertIsInstance(response.data["conf"], float)
self.assertIsInstance(response.data["answer"], str)

self.assertIsInstance(callback, Message)
self.assertEqual(callback.data['skill_id'], self.skill.skill_id)
self.assertEqual(callback.data['phrase'],
response.data['phrase'])
if not data:
continue
if isinstance(data.get('callback'), dict):
self.assertEqual(callback.data['callback_data'],
data['callback'])
elif isinstance(data.get('callback'), list):
self.assertEqual(set(callback.data['callback_data'].keys()),
set(data.get('callback')))
if data.get('min_confidence'):
self.assertGreaterEqual(response.data['conf'],
data['min_confidence'])
if data.get('max_confidence'):
self.assertLessEqual(response.data['conf'],
data['max_confidence'])

def test_common_play(self):
# TODO
Expand Down

0 comments on commit 074ad7e

Please sign in to comment.