Commit 3abd732

Outline test_llm TestCase
Add tests of MQ request/response handling in `chatbot` module
Update `chatbot` module to use Pydantic models in place of `dict` objects for MQ message validation
NeonDaniel committed Nov 22, 2024
1 parent 7762472 commit 3abd732
Showing 3 changed files with 167 additions and 52 deletions.
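
The core change replaces raw `dict` handling of MQ responses with Pydantic model validation. As a rough illustration of the pattern (not part of the commit itself, and assuming `LLMProposeResponse` exposes the `response` and `message_id` fields exercised in the tests below), parsing a response payload looks roughly like this:

```python
# Illustrative sketch only, not taken from the commit: parse an MQ response
# payload into a Pydantic model so malformed payloads surface as validation
# errors instead of silently falling back to dict.get() defaults.
from typing import Optional

from neon_data_models.models.api.mq import LLMProposeResponse
from pydantic import ValidationError


def parse_propose_response(resp_data: dict) -> Optional[LLMProposeResponse]:
    try:
        # Raises ValidationError if required fields (e.g. `response`) are missing
        return LLMProposeResponse(**resp_data)
    except ValidationError:
        return None
```

In the diff itself, the chatbot's `_get_llm_api_*` helpers wrap the whole request-and-parse step in a try/except and return `None` on any failure, which the public `ask_*` methods then turn into fallback replies.
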
94 changes: 52 additions & 42 deletions neon_llm_core/chatbot.py
@@ -24,10 +24,11 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List
from typing import List, Optional
from chatbot_core.v2 import ChatBot
from neon_data_models.models.api.mq import (LLMProposeRequest,
LLMDiscussRequest, LLMVoteRequest)
LLMDiscussRequest, LLMVoteRequest, LLMProposeResponse, LLMDiscussResponse,
LLMVoteResponse)
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_utils.logger import LOG
from neon_data_models.models.api.llm import LLMPersona
@@ -66,9 +67,8 @@ def ask_chatbot(self, user: str, shout: str, timestamp: str,
if prompt_id:
self.prompt_id_to_shout[prompt_id] = shout
LOG.debug(f"Getting response to {shout}")
response = self._get_llm_api_response(
shout=shout).get("response", "I have nothing to say here...")
return response
response = self._get_llm_api_response(shout=shout)
return response.response if response else "I have nothing to say here..."

def ask_discusser(self, options: dict, context: dict = None) -> str:
"""
@@ -81,8 +81,8 @@ def ask_discusser(self, options: dict, context: dict = None) -> str:
prompt_sentence = self.prompt_id_to_shout.get(context['prompt_id'], '')
LOG.info(f'prompt_sentence={prompt_sentence}, options={options}')
opinion = self._get_llm_api_opinion(prompt=prompt_sentence,
options=options).get('opinion', '')
return opinion
options=options)
return opinion.opinion if opinion else "I have nothing to say here..."

def ask_appraiser(self, options: dict, context: dict = None) -> str:
"""
@@ -101,16 +101,15 @@ def ask_appraiser(self, options: dict, context: dict = None) -> str:
answer_data = self._get_llm_api_choice(prompt=prompt,
responses=bot_responses)
LOG.info(f'Received answer_data={answer_data}')
sorted_answer_indexes = answer_data.get('sorted_answer_indexes')
if sorted_answer_indexes:
return bots[sorted_answer_indexes[0]]
if answer_data and answer_data.sorted_answer_indexes:
return bots[answer_data.sorted_answer_indexes[0]]
return "abstain"

def _get_llm_api_response(self, shout: str) -> dict:
def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]:
"""
Requests LLM API for response on provided shout
:param shout: provided should string
:returns response string from LLM API
:returns response from LLM API
"""
queue = self.mq_queue_config.ask_response_queue
LOG.info(f"Sending to {self.mq_queue_config.vhost}/{queue}")
@@ -120,54 +119,65 @@ def _get_llm_api_response(self, shout: str) -> dict:
query=shout,
history=[],
message_id="")
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMProposeResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/"
f"{self.mq_queue_config.ask_response_queue}: "
f"{e}")
return dict()
f"{self.mq_queue_config.vhost}/{queue}: {e}")

def _get_llm_api_opinion(self, prompt: str, options: dict) -> dict:
def _get_llm_api_opinion(self, prompt: str,
options: dict) -> Optional[LLMDiscussResponse]:
"""
Requests LLM API for discussion of provided submind responses
:param prompt: incoming prompt text
:param options: proposed responses (botname: response)
:returns response data from LLM API
"""
queue = self.mq_queue_config.ask_discusser_queue
request_data = LLMDiscussRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
options=options,
history=[],
message_id="")
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
try:
request_data = LLMDiscussRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
options=options,
history=[],
message_id="")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMDiscussResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/{queue}: {e}")

def _get_llm_api_choice(self, prompt: str, responses: List[str]) -> dict:
def _get_llm_api_choice(self, prompt: str,
responses: List[str]) -> Optional[LLMVoteResponse]:
"""
Requests LLM API for choice among provided message list
:param prompt: incoming prompt text
:param responses: list of answers to select from
:returns response data from LLM API
"""
request_data = LLMVoteRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
responses=responses,
history=[],
message_id="")
queue = self.mq_queue_config.ask_appraiser_queue
return send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")

try:
request_data = LLMVoteRequest(model=self.base_llm,
persona=self.persona,
query=prompt,
responses=responses,
history=[],
message_id="")
resp_data = send_mq_request(vhost=self.mq_queue_config.vhost,
request_data=request_data.model_dump(),
target_queue=queue,
response_queue=f"{queue}.response")
return LLMVoteResponse(**resp_data)
except Exception as e:
LOG.exception(f"Failed to get response on "
f"{self.mq_queue_config.vhost}/{queue}: {e}")

@staticmethod
def get_llm_mq_config(llm_name: str) -> LLMMQConfig:
93 changes: 83 additions & 10 deletions tests/test_chatbot.py
@@ -25,8 +25,11 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from unittest import TestCase
from unittest.mock import patch

from neon_data_models.models.api import LLMPersona
from neon_data_models.models.api import LLMPersona, LLMProposeRequest, LLMProposeResponse, LLMDiscussRequest, \
LLMDiscussResponse, LLMVoteRequest, LLMVoteResponse
from pydantic import ValidationError

from neon_llm_core.chatbot import LLMBot
from neon_llm_core.utils.config import LLMMQConfig
@@ -64,14 +67,84 @@ def test_ask_appraiser(self):
# TODO
pass

def test_get_llm_api_response(self):
# TODO
pass
@patch('neon_llm_core.chatbot.send_mq_request')
def test_get_llm_api_response(self, mq_request):
mq_request.return_value = {"response": "test",
"message_id": ""}

def test_get_llm_api_opinion(self):
# TODO
pass
# Valid Request
resp = self.mock_chatbot._get_llm_api_response("input")
request_data = mq_request.call_args.kwargs['request_data']
req = LLMProposeRequest(**request_data)
self.assertIsInstance(req, LLMProposeRequest)
self.assertEqual(req.query, "input")
self.assertEqual(req.model, self.mock_chatbot.base_llm)
self.assertEqual(req.persona, self.mock_chatbot.persona)
self.assertIsInstance(resp, LLMProposeResponse)
self.assertEqual(resp.response, mq_request.return_value['response'])

def test_get_llm_api_choice(self):
# TODO
pass
# Invalid request
self.assertIsNone(self.mock_chatbot._get_llm_api_response(None))

# Invalid response
mq_request.return_value = {}
self.assertIsNone(self.mock_chatbot._get_llm_api_response("input"))

@patch('neon_llm_core.chatbot.send_mq_request')
def test_get_llm_api_opinion(self, mq_request):
mq_request.return_value = {"opinion": "test",
"message_id": ""}
prompt = "test prompt"
options = {"bot 1": "resp 1", "bot 2": "resp 2"}

# Valid Request
resp = self.mock_chatbot._get_llm_api_opinion(prompt, options)
request_data = mq_request.call_args.kwargs['request_data']
req = LLMDiscussRequest(**request_data)
self.assertIsInstance(req, LLMDiscussRequest)
self.assertEqual(req.query, prompt)
self.assertEqual(req.options, options)
self.assertEqual(req.model, self.mock_chatbot.base_llm)
self.assertEqual(req.persona, self.mock_chatbot.persona)
self.assertIsInstance(resp, LLMDiscussResponse)
self.assertEqual(resp.opinion, mq_request.return_value['opinion'])

# Invalid request
self.assertIsNone(self.mock_chatbot._get_llm_api_opinion(prompt,
prompt))

# Invalid response
mq_request.return_value = {}
self.assertIsNone(self.mock_chatbot._get_llm_api_opinion(prompt,
options))

@patch('neon_llm_core.chatbot.send_mq_request')
def test_get_llm_api_choice(self, mq_request):
mq_request.return_value = {"sorted_answer_indexes": [2, 0, 1],
"message_id": ""}
prompt = "test prompt"
responses = ["one", "two", "three"]

# Valid Request
resp = self.mock_chatbot._get_llm_api_choice(prompt, responses)
request_data = mq_request.call_args.kwargs['request_data']

req = LLMVoteRequest(**request_data)
self.assertIsInstance(req, LLMVoteRequest)
self.assertEqual(req.query, prompt)
self.assertEqual(req.responses, responses)
self.assertEqual(req.model, self.mock_chatbot.base_llm)
self.assertEqual(req.persona, self.mock_chatbot.persona)
self.assertIsInstance(resp, LLMVoteResponse)
self.assertEqual(resp.sorted_answer_indexes,
mq_request.return_value['sorted_answer_indexes'])

# Invalid request
self.assertIsNone(self.mock_chatbot._get_llm_api_choice(prompt,
[1, 2, 3]))

# Invalid response
mq_request.return_value["sorted_answer_indexes"] = ["one", "two",
"three"]
self.assertIsNone(self.mock_chatbot._get_llm_api_choice(prompt,
responses))
32 changes: 32 additions & 0 deletions tests/test_llm.py
@@ -0,0 +1,32 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2024 NeonGecko.com Inc.
# BSD-3
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from unittest import TestCase


class TestNeonLLM(TestCase):
# TODO
pass
