diff --git a/neon_llm_core/chatbot.py b/neon_llm_core/chatbot.py index 2208b8a..0dca2fd 100644 --- a/neon_llm_core/chatbot.py +++ b/neon_llm_core/chatbot.py @@ -24,10 +24,11 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import List +from typing import List, Optional from chatbot_core.v2 import ChatBot from neon_data_models.models.api.mq import (LLMProposeRequest, - LLMDiscussRequest, LLMVoteRequest) + LLMDiscussRequest, LLMVoteRequest, LLMProposeResponse, LLMDiscussResponse, + LLMVoteResponse) from neon_mq_connector.utils.client_utils import send_mq_request from neon_utils.logger import LOG from neon_data_models.models.api.llm import LLMPersona @@ -66,9 +67,8 @@ def ask_chatbot(self, user: str, shout: str, timestamp: str, if prompt_id: self.prompt_id_to_shout[prompt_id] = shout LOG.debug(f"Getting response to {shout}") - response = self._get_llm_api_response( - shout=shout).get("response", "I have nothing to say here...") - return response + response = self._get_llm_api_response(shout=shout) + return response.response if response else "I have nothing to say here..." def ask_discusser(self, options: dict, context: dict = None) -> str: """ @@ -81,8 +81,8 @@ def ask_discusser(self, options: dict, context: dict = None) -> str: prompt_sentence = self.prompt_id_to_shout.get(context['prompt_id'], '') LOG.info(f'prompt_sentence={prompt_sentence}, options={options}') opinion = self._get_llm_api_opinion(prompt=prompt_sentence, - options=options).get('opinion', '') - return opinion + options=options) + return opinion.opinion if opinion else "I have nothing to say here..." def ask_appraiser(self, options: dict, context: dict = None) -> str: """ @@ -101,16 +101,15 @@ def ask_appraiser(self, options: dict, context: dict = None) -> str: answer_data = self._get_llm_api_choice(prompt=prompt, responses=bot_responses) LOG.info(f'Received answer_data={answer_data}') - sorted_answer_indexes = answer_data.get('sorted_answer_indexes') - if sorted_answer_indexes: - return bots[sorted_answer_indexes[0]] + if answer_data and answer_data.sorted_answer_indexes: + return bots[answer_data.sorted_answer_indexes[0]] return "abstain" - def _get_llm_api_response(self, shout: str) -> dict: + def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]: """ Requests LLM API for response on provided shout :param shout: provided should string - :returns response string from LLM API + :returns response from LLM API """ queue = self.mq_queue_config.ask_response_queue LOG.info(f"Sending to {self.mq_queue_config.vhost}/{queue}") @@ -120,18 +119,17 @@ def _get_llm_api_response(self, shout: str) -> dict: query=shout, history=[], message_id="") - return send_mq_request(vhost=self.mq_queue_config.vhost, - request_data=request_data.model_dump(), - target_queue=queue, - response_queue=f"{queue}.response") + resp_data = send_mq_request(vhost=self.mq_queue_config.vhost, + request_data=request_data.model_dump(), + target_queue=queue, + response_queue=f"{queue}.response") + return LLMProposeResponse(**resp_data) except Exception as e: LOG.exception(f"Failed to get response on " - f"{self.mq_queue_config.vhost}/" - f"{self.mq_queue_config.ask_response_queue}: " - f"{e}") - return dict() + f"{self.mq_queue_config.vhost}/{queue}: {e}") - def _get_llm_api_opinion(self, prompt: str, options: dict) -> dict: + def _get_llm_api_opinion(self, prompt: str, + options: dict) -> Optional[LLMDiscussResponse]: """ Requests LLM API for discussion of provided submind responses :param prompt: incoming prompt text @@ -139,35 +137,47 @@ def _get_llm_api_opinion(self, prompt: str, options: dict) -> dict: :returns response data from LLM API """ queue = self.mq_queue_config.ask_discusser_queue - request_data = LLMDiscussRequest(model=self.base_llm, - persona=self.persona, - query=prompt, - options=options, - history=[], - message_id="") - return send_mq_request(vhost=self.mq_queue_config.vhost, - request_data=request_data.model_dump(), - target_queue=queue, - response_queue=f"{queue}.response") + try: + request_data = LLMDiscussRequest(model=self.base_llm, + persona=self.persona, + query=prompt, + options=options, + history=[], + message_id="") + resp_data = send_mq_request(vhost=self.mq_queue_config.vhost, + request_data=request_data.model_dump(), + target_queue=queue, + response_queue=f"{queue}.response") + return LLMDiscussResponse(**resp_data) + except Exception as e: + LOG.exception(f"Failed to get response on " + f"{self.mq_queue_config.vhost}/{queue}: {e}") - def _get_llm_api_choice(self, prompt: str, responses: List[str]) -> dict: + def _get_llm_api_choice(self, prompt: str, + responses: List[str]) -> Optional[LLMVoteResponse]: """ Requests LLM API for choice among provided message list :param prompt: incoming prompt text :param responses: list of answers to select from :returns response data from LLM API """ - request_data = LLMVoteRequest(model=self.base_llm, - persona=self.persona, - query=prompt, - responses=responses, - history=[], - message_id="") queue = self.mq_queue_config.ask_appraiser_queue - return send_mq_request(vhost=self.mq_queue_config.vhost, - request_data=request_data.model_dump(), - target_queue=queue, - response_queue=f"{queue}.response") + + try: + request_data = LLMVoteRequest(model=self.base_llm, + persona=self.persona, + query=prompt, + responses=responses, + history=[], + message_id="") + resp_data = send_mq_request(vhost=self.mq_queue_config.vhost, + request_data=request_data.model_dump(), + target_queue=queue, + response_queue=f"{queue}.response") + return LLMVoteResponse(**resp_data) + except Exception as e: + LOG.exception(f"Failed to get response on " + f"{self.mq_queue_config.vhost}/{queue}: {e}") @staticmethod def get_llm_mq_config(llm_name: str) -> LLMMQConfig: diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index e53f313..3f6e7c8 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -25,8 +25,11 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from unittest import TestCase +from unittest.mock import patch -from neon_data_models.models.api import LLMPersona +from neon_data_models.models.api import LLMPersona, LLMProposeRequest, LLMProposeResponse, LLMDiscussRequest, \ + LLMDiscussResponse, LLMVoteRequest, LLMVoteResponse +from pydantic import ValidationError from neon_llm_core.chatbot import LLMBot from neon_llm_core.utils.config import LLMMQConfig @@ -64,14 +67,84 @@ def test_ask_appraiser(self): # TODO pass - def test_get_llm_api_response(self): - # TODO - pass + @patch('neon_llm_core.chatbot.send_mq_request') + def test_get_llm_api_response(self, mq_request): + mq_request.return_value = {"response": "test", + "message_id": ""} - def test_get_llm_api_opinion(self): - # TODO - pass + # Valid Request + resp = self.mock_chatbot._get_llm_api_response("input") + request_data = mq_request.call_args.kwargs['request_data'] + req = LLMProposeRequest(**request_data) + self.assertIsInstance(req, LLMProposeRequest) + self.assertEqual(req.query, "input") + self.assertEqual(req.model, self.mock_chatbot.base_llm) + self.assertEqual(req.persona, self.mock_chatbot.persona) + self.assertIsInstance(resp, LLMProposeResponse) + self.assertEqual(resp.response, mq_request.return_value['response']) - def test_get_llm_api_choice(self): - # TODO - pass + # Invalid request + self.assertIsNone(self.mock_chatbot._get_llm_api_response(None)) + + # Invalid response + mq_request.return_value = {} + self.assertIsNone(self.mock_chatbot._get_llm_api_response("input")) + + @patch('neon_llm_core.chatbot.send_mq_request') + def test_get_llm_api_opinion(self, mq_request): + mq_request.return_value = {"opinion": "test", + "message_id": ""} + prompt = "test prompt" + options = {"bot 1": "resp 1", "bot 2": "resp 2"} + + # Valid Request + resp = self.mock_chatbot._get_llm_api_opinion(prompt, options) + request_data = mq_request.call_args.kwargs['request_data'] + req = LLMDiscussRequest(**request_data) + self.assertIsInstance(req, LLMDiscussRequest) + self.assertEqual(req.query, prompt) + self.assertEqual(req.options, options) + self.assertEqual(req.model, self.mock_chatbot.base_llm) + self.assertEqual(req.persona, self.mock_chatbot.persona) + self.assertIsInstance(resp, LLMDiscussResponse) + self.assertEqual(resp.opinion, mq_request.return_value['opinion']) + + # Invalid request + self.assertIsNone(self.mock_chatbot._get_llm_api_opinion(prompt, + prompt)) + + # Invalid response + mq_request.return_value = {} + self.assertIsNone(self.mock_chatbot._get_llm_api_opinion(prompt, + options)) + + @patch('neon_llm_core.chatbot.send_mq_request') + def test_get_llm_api_choice(self, mq_request): + mq_request.return_value = {"sorted_answer_indexes": [2, 0, 1], + "message_id": ""} + prompt = "test prompt" + responses = ["one", "two", "three"] + + # Valid Request + resp = self.mock_chatbot._get_llm_api_choice(prompt, responses) + request_data = mq_request.call_args.kwargs['request_data'] + + req = LLMVoteRequest(**request_data) + self.assertIsInstance(req, LLMVoteRequest) + self.assertEqual(req.query, prompt) + self.assertEqual(req.responses, responses) + self.assertEqual(req.model, self.mock_chatbot.base_llm) + self.assertEqual(req.persona, self.mock_chatbot.persona) + self.assertIsInstance(resp, LLMVoteResponse) + self.assertEqual(resp.sorted_answer_indexes, + mq_request.return_value['sorted_answer_indexes']) + + # Invalid request + self.assertIsNone(self.mock_chatbot._get_llm_api_choice(prompt, + [1, 2, 3])) + + # Invalid response + mq_request.return_value["sorted_answer_indexes"] = ["one", "two", + "three"] + self.assertIsNone(self.mock_chatbot._get_llm_api_choice(prompt, + responses)) diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..8a6a92f --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,32 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2024 NeonGecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from unittest import TestCase + + +class TestNeonLLM(TestCase): + # TODO + pass