Skip to content

Commit

Permalink
Add RMQ unit test coverage
Browse files Browse the repository at this point in the history
Include `routing_key` in LLM responses to associate inputs/responses
  • Loading branch information
NeonDaniel committed Dec 2, 2024
1 parent 99cb462 commit 415d607
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 15 deletions.
21 changes: 13 additions & 8 deletions neon_llm_core/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,17 @@ def model(self) -> NeonLLM:
pass

@create_mq_callback()
def handle_request(self, body: dict):
def handle_request(self, body: dict) -> Thread:
"""
Handles ask requests (response to prompt) from MQ to LLM
:param body: request body (dict)
"""
# Handle this asynchronously so multiple subminds can be handled
# concurrently
Thread(target=self._handle_request_async, args=(body,),
daemon=True).start()
t = Thread(target=self._handle_request_async, args=(body,),
daemon=True)
t.start()
return t

def _handle_request_async(self, request: dict):
message_id = request["message_id"]
Expand All @@ -133,7 +135,8 @@ def _handle_request_async(self, request: dict):
response = ('Sorry, but I cannot respond to your message at the '
'moment, please try again later')
api_response = LLMProposeResponse(message_id=message_id,
response=response)
response=response,
routing_key=routing_key)
LOG.info(f"Sending response: {response}")
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
Expand All @@ -154,17 +157,18 @@ def handle_score_request(self, body: dict):
persona = body.get("persona", {})

if not responses:
sorted_answer_indexes = []
sorted_answer_idx = []
else:
try:
sorted_answer_indexes = self.model.get_sorted_answer_indexes(
sorted_answer_idx = self.model.get_sorted_answer_indexes(
question=query, answers=responses, persona=persona)
except ValueError as err:
LOG.error(f'ValueError={err}')
sorted_answer_indexes = []
sorted_answer_idx = []

api_response = LLMVoteResponse(message_id=message_id,
sorted_answer_indexes=sorted_answer_indexes)
routing_key=routing_key,
sorted_answer_indexes=sorted_answer_idx)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
LOG.info(f"Handled score request for message_id={message_id}")
Expand Down Expand Up @@ -200,6 +204,7 @@ def handle_opinion_request(self, body: dict):
"an opinion on this topic")

api_response = LLMDiscussResponse(message_id=message_id,
routing_key=routing_key,
opinion=opinion)
self.send_message(request_data=api_response.model_dump(),
queue=routing_key)
Expand Down
124 changes: 117 additions & 7 deletions tests/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from unittest.mock import Mock

from mirakuru import ProcessExitedWithError
from neon_mq_connector.utils.network_utils import dict_to_b64
from port_for import get_port
from pytest_rabbitmq.factories.executor import RabbitMqExecutor
from pytest_rabbitmq.factories.process import get_config
Expand All @@ -47,20 +48,24 @@ def __init__(self, rmq_port: int):
"neon_llm_mock_mq": {"user": "test_llm_user",
"password": "test_llm_password"}}}}
NeonLLMMQConnector.__init__(self, config=config)
self._model = Mock()
self._model.ask.return_value = "Mock response"
self._model.get_sorted_answer_indexes.return_value = [0, 1]
self.send_message = Mock()
self._compose_opinion_prompt = Mock(return_value="Mock opinion prompt")

@property
def name(self):
    """Service name for this mock connector; drives queue naming.

    `test_00_init` asserts this value appears in `service_name`.
    """
    return "mock_mq"

@property
def model(self) -> NeonLLM:
return Mock()
return self._model

@staticmethod
def compose_opinion_prompt(respondent_nick: str,
def compose_opinion_prompt(self, respondent_nick: str,
question: str,
answer: str) -> str:
return "opinion prompt"
return self._compose_opinion_prompt(respondent_nick, question, answer)


@pytest.fixture(scope="class")
Expand Down Expand Up @@ -112,6 +117,8 @@ def rmq_instance(request, tmp_path_factory):

@pytest.mark.usefixtures("rmq_instance")
class TestNeonLLMMQConnector(TestCase):
mq_llm: NeonMockLlm = None
rmq_instance: RabbitMqExecutor = None

@classmethod
def tearDownClass(cls):
Expand All @@ -120,12 +127,115 @@ def tearDownClass(cls):
except ProcessExitedWithError:
pass

def test_00_init(self):
self.mq_llm = NeonMockLlm(self.rmq_instance.port)
def setUp(self):
    """Lazily create the shared mock LLM connector before the first test.

    Assign on the class (not ``self``): unittest builds a fresh TestCase
    instance per test method, so an instance attribute would not persist
    and a new ``NeonMockLlm`` would be created for every test. The class
    attribute is what ``tearDownClass`` expects to clean up.
    """
    if self.mq_llm is None:
        type(self).mq_llm = NeonMockLlm(self.rmq_instance.port)

def test_00_init(self):
self.assertIn(self.mq_llm.name, self.mq_llm.service_name)
self.assertIsInstance(self.mq_llm.ovos_config, dict)
self.assertEqual(self.mq_llm.vhost, "/llm")
self.assertIsNotNone(self.mq_llm.model)
self.assertIsNotNone(self.mq_llm.model, self.mq_llm.model)
self.assertEqual(self.mq_llm._personas_provider.service_name,
self.mq_llm.name)

def test_handle_request(self):
    """A propose request is answered on its routing key with the model's reply."""
    from neon_data_models.models.api.mq import (LLMProposeRequest,
                                                LLMProposeResponse)
    request = LLMProposeRequest(message_id="mock_message_id",
                                routing_key="mock_routing_key",
                                query="Mock Query", history=[])
    # handle_request dispatches to a worker thread; wait for it to finish
    worker = self.mq_llm.handle_request(
        None, None, None, dict_to_b64(request.model_dump()))
    worker.join()

    # The model was asked with exactly the request's contents
    self.mq_llm.model.ask.assert_called_with(message=request.query,
                                             chat_history=request.history,
                                             persona=request.persona)

    # The response was sent to the request's routing key and echoes its ids
    sent = self.mq_llm.send_message.call_args.kwargs
    self.assertEqual(sent['queue'], request.routing_key)
    parsed = LLMProposeResponse(**sent['request_data'])
    self.assertIsInstance(parsed, LLMProposeResponse)
    self.assertEqual(parsed.routing_key, request.routing_key)
    self.assertEqual(parsed.message_id, request.message_id)
    self.assertEqual(parsed.response, self.mq_llm.model.ask())

def test_handle_opinion_request(self):
    """Discuss requests yield an opinion; empty options yield a canned reply."""
    from neon_data_models.models.api.mq import (LLMDiscussRequest,
                                                LLMDiscussResponse)

    def _dispatch_and_parse(req):
        # Send the request through the handler and decode what was sent back
        self.mq_llm.handle_opinion_request(
            None, None, None, dict_to_b64(req.model_dump()))
        kwargs = self.mq_llm.send_message.call_args.kwargs
        self.assertEqual(kwargs['queue'], req.routing_key)
        parsed = LLMDiscussResponse(**kwargs['request_data'])
        self.assertIsInstance(parsed, LLMDiscussResponse)
        self.assertEqual(req.routing_key, parsed.routing_key)
        self.assertEqual(req.message_id, parsed.message_id)
        return parsed

    # Valid request with two bot responses to discuss
    request = LLMDiscussRequest(message_id="mock_message_id",
                                routing_key="mock_routing_key",
                                query="Mock Discuss", history=[],
                                options={"bot 1": "resp 1",
                                         "bot 2": "resp 2"})
    response = _dispatch_and_parse(request)
    first_bot, first_resp = next(iter(request.options.items()))
    self.mq_llm._compose_opinion_prompt.assert_called_with(
        first_bot, request.query, first_resp)
    self.assertEqual(response.opinion, self.mq_llm.model.ask())

    # No input options: the mock model's answer is not used for the opinion
    request = LLMDiscussRequest(message_id="mock_message_id1",
                                routing_key="mock_routing_key1",
                                query="Mock Discuss 1", history=[],
                                options={})
    response = _dispatch_and_parse(request)
    self.assertNotEqual(response.opinion, self.mq_llm.model.ask())

    # TODO: Test with invalid sorted answer indexes

def test_handle_score_request(self):
    """Vote requests return sorted answer indexes; empty responses -> []."""
    from neon_data_models.models.api.mq import (LLMVoteRequest,
                                                LLMVoteResponse)

    # Valid Request
    request = LLMVoteRequest(message_id="mock_message_id",
                             routing_key="mock_routing_key",
                             query="Mock Score", history=[],
                             responses=["one", "two"])
    self.mq_llm.handle_score_request(None, None, None,
                                     dict_to_b64(request.model_dump()))

    response = self.mq_llm.send_message.call_args.kwargs
    self.assertEqual(response['queue'], request.routing_key)
    response = LLMVoteResponse(**response['request_data'])
    self.assertIsInstance(response, LLMVoteResponse)
    self.assertEqual(request.routing_key, response.routing_key)
    self.assertEqual(request.message_id, response.message_id)

    self.assertEqual(response.sorted_answer_indexes,
                     self.mq_llm.model.get_sorted_answer_indexes())

    # No response options. Use distinct ids (matching the discuss test's
    # pattern) so these assertions cannot be satisfied by the stale
    # send_message call from the first request above.
    request = LLMVoteRequest(message_id="mock_message_id1",
                             routing_key="mock_routing_key1",
                             query="Mock Score 1", history=[], responses=[])
    self.mq_llm.handle_score_request(None, None, None,
                                     dict_to_b64(request.model_dump()))

    response = self.mq_llm.send_message.call_args.kwargs
    self.assertEqual(response['queue'], request.routing_key)
    response = LLMVoteResponse(**response['request_data'])
    self.assertIsInstance(response, LLMVoteResponse)
    self.assertEqual(request.routing_key, response.routing_key)
    self.assertEqual(request.message_id, response.message_id)

    self.assertEqual(response.sorted_answer_indexes, [])

0 comments on commit 415d607

Please sign in to comment.