Add test coverage for ask_chatbot, ask_discusser, and `ask_appraiser`

Refactor `chatbot` methods to safely handle missing context
NeonDaniel committed Nov 22, 2024
1 parent 3abd732 commit aff20cb
Showing 2 changed files with 119 additions and 21 deletions.
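
The core change replaces direct `context['prompt_id']` lookups with a guard that tolerates a missing or None context dict. A minimal, self-contained sketch of that pattern follows; the standalone function and its name are illustrative only, and only the `context.get('prompt_id') if context else None` expression comes from the diff below:

from typing import Optional


def resolve_prompt_id(context: Optional[dict]) -> Optional[str]:
    # Safely read 'prompt_id' whether or not a context dict was provided
    return context.get('prompt_id') if context else None


assert resolve_prompt_id({'prompt_id': 'abc'}) == 'abc'
assert resolve_prompt_id({}) is None
assert resolve_prompt_id(None) is None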
25 changes: 17 additions & 8 deletions neon_llm_core/chatbot.py
@@ -63,7 +63,7 @@ def ask_chatbot(self, user: str, shout: str, timestamp: str,
         :param timestamp: formatted timestamp of shout
         :param context: message context
         """
-        prompt_id = context.get('prompt_id')
+        prompt_id = context.get('prompt_id') if context else None
         if prompt_id:
             self.prompt_id_to_shout[prompt_id] = shout
         LOG.debug(f"Getting response to {shout}")
@@ -78,7 +78,10 @@ def ask_discusser(self, options: dict, context: dict = None) -> str:
         :param context: message context
         """
         options = {k: v for k, v in options.items() if k != self.service_name}
-        prompt_sentence = self.prompt_id_to_shout.get(context['prompt_id'], '')
+        prompt_id = context.get('prompt_id') if context else None
+        prompt_sentence = None
+        if prompt_id:
+            prompt_sentence = self.prompt_id_to_shout.get(prompt_id)
         LOG.info(f'prompt_sentence={prompt_sentence}, options={options}')
         opinion = self._get_llm_api_opinion(prompt=prompt_sentence,
                                             options=options)
@@ -90,15 +93,21 @@ def ask_appraiser(self, options: dict, context: dict = None) -> str:
         :param options: proposed responses (botname: response)
         :param context: message context
         """
+        # Determine the relevant prompt
+        prompt_id = context.get('prompt_id') if context else None
+        prompt_sentence = None
+        if prompt_id:
+            prompt_sentence = self.prompt_id_to_shout.get(prompt_id)
+
+        # Remove self answer from available options
+        options = {k: v for k, v in options.items()
+                   if k != self.service_name}
+
         if options:
-            # Remove self answer from available options
-            options = {k: v for k, v in options.items()
-                       if k != self.service_name}
             bots = list(options)
             bot_responses = list(options.values())
             LOG.info(f'bots={bots}, answers={bot_responses}')
-            prompt = self.prompt_id_to_shout.pop(context['prompt_id'], '')
-            answer_data = self._get_llm_api_choice(prompt=prompt,
+            answer_data = self._get_llm_api_choice(prompt=prompt_sentence,
                                                    responses=bot_responses)
             LOG.info(f'Received answer_data={answer_data}')
             if answer_data and answer_data.sorted_answer_indexes:
@@ -108,7 +117,7 @@ def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]:
     def _get_llm_api_response(self, shout: str) -> Optional[LLMProposeResponse]:
         """
         Requests LLM API for response on provided shout
-        :param shout: provided should string
+        :param shout: Input prompt to respond to
         :returns response from LLM API
         """
         queue = self.mq_queue_config.ask_response_queue
115 changes: 102 additions & 13 deletions tests/test_chatbot.py
@@ -23,13 +23,13 @@
 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+from datetime import datetime
 from unittest import TestCase
 from unittest.mock import patch

-from neon_data_models.models.api import LLMPersona, LLMProposeRequest, LLMProposeResponse, LLMDiscussRequest, \
-    LLMDiscussResponse, LLMVoteRequest, LLMVoteResponse
-from pydantic import ValidationError
+from neon_data_models.models.api import LLMPersona, LLMProposeRequest, \
+    LLMProposeResponse, LLMDiscussRequest, LLMDiscussResponse, LLMVoteRequest, \
+    LLMVoteResponse

 from neon_llm_core.chatbot import LLMBot
 from neon_llm_core.utils.config import LLMMQConfig
@@ -55,17 +55,106 @@ def test_00_init(self):
         self.assertIsInstance(self.mock_chatbot.persona, LLMPersona)
         self.assertIsInstance(self.mock_chatbot.mq_queue_config, LLMMQConfig)

-    def test_ask_chatbot(self):
-        # TODO
-        pass
+    @patch.object(mock_chatbot, '_get_llm_api_response')
+    def test_ask_chatbot(self, get_api_response):
+        get_api_response.return_value = LLMProposeResponse(message_id="",
+                                                           response="test_resp")
+        valid_prompt_id = "test_prompt_id_ask"
+        valid_user = "test_user"
+        valid_shout = "test_shout"
+        valid_timestamp = datetime.now().isoformat()
+        valid_context = {"prompt_id": valid_prompt_id}
+
+        # Valid Request
+        resp = self.mock_chatbot.ask_chatbot(valid_user, valid_shout,
+                                             valid_timestamp, valid_context)
+        get_api_response.assert_called_with(shout=valid_shout)
+        self.assertEqual(resp, "test_resp")
+        self.assertEqual(self.mock_chatbot.prompt_id_to_shout,
+                         {valid_prompt_id: valid_shout})
+
+        # Valid without context
+        resp = self.mock_chatbot.ask_chatbot(valid_user, valid_shout,
+                                             valid_timestamp)
+        get_api_response.assert_called_with(shout=valid_shout)
+        self.assertEqual(resp, "test_resp")
+        self.assertEqual(self.mock_chatbot.prompt_id_to_shout,
+                         {valid_prompt_id: valid_shout})

-    def test_ask_discusser(self):
-        # TODO
-        pass
+        # Invalid request
+        self.assertIsInstance(self.mock_chatbot.ask_chatbot(valid_user,
+                                                            None,
+                                                            valid_timestamp),
+                              str)
+        get_api_response.assert_called_with(shout=None)

-    def test_ask_appraiser(self):
-        # TODO
-        pass
+        # Invalid response
+        get_api_response.return_value = None
+        self.assertIsInstance(self.mock_chatbot.ask_chatbot(valid_user,
+                                                            valid_shout,
+                                                            valid_timestamp,
+                                                            valid_context), str)
+        get_api_response.assert_called_with(shout=valid_shout)
+
+    @patch.object(mock_chatbot, '_get_llm_api_opinion')
+    def test_ask_discusser(self, get_api_opinion):
+        get_api_opinion.return_value = LLMDiscussResponse(message_id="",
+                                                          opinion="test_resp")
+        valid_prompt_id = "test_prompt_id_disc"
+        valid_prompt = "test prompt"
+        valid_options = {"bot 1": "response 1", "bot 2": "response 2"}
+        valid_context = {"prompt_id": valid_prompt_id}
+
+        self.mock_chatbot.prompt_id_to_shout[valid_prompt_id] = valid_prompt
+
+        # Valid request
+        resp = self.mock_chatbot.ask_discusser(valid_options, valid_context)
+        get_api_opinion.assert_called_with(prompt=valid_prompt,
+                                           options=valid_options)
+        self.assertEqual(resp, "test_resp")
+
+        # Invalid response
+        get_api_opinion.return_value = None
+        self.assertIsInstance(self.mock_chatbot.ask_discusser(valid_options,
+                                                              valid_context),
+                              str)
+        get_api_opinion.assert_called_with(prompt=valid_prompt,
+                                           options=valid_options)
+
+    @patch.object(mock_chatbot, '_get_llm_api_choice')
+    def test_ask_appraiser(self, get_api_choice):
+        get_api_choice.return_value = LLMVoteResponse(
+            message_id="", sorted_answer_indexes=[2, 0, 1])
+        valid_prompt_id = "test_prompt_id_vote"
+        valid_prompt = "test prompt"
+        options = {"bot 0": "response 0",
+                   "bot 1": "response 1",
+                   "bot 2": "response 2",
+                   self.mock_chatbot.service_name: "Self response"}
+        valid_options = ["response 0", "response 1", "response 2"]
+        valid_context = {"prompt_id": valid_prompt_id}
+
+        self.mock_chatbot.prompt_id_to_shout[valid_prompt_id] = valid_prompt
+
+        # Valid request
+        resp = self.mock_chatbot.ask_appraiser(options, valid_context)
+        get_api_choice.assert_called_with(prompt=valid_prompt,
+                                          responses=valid_options)
+        self.assertEqual(resp, "bot 2")
+
+        # Invalid no valid options
+        resp = self.mock_chatbot.ask_appraiser(
+            {self.mock_chatbot.service_name: "Self response"},
+            valid_context)
+        self.assertIn("abstain", resp.lower())
+
+        # Invalid API response
+        get_api_choice.reset_mock()
+        get_api_choice.return_value = None
+        resp = self.mock_chatbot.ask_appraiser(options, valid_context)
+        get_api_choice.assert_called_with(prompt=valid_prompt,
+                                          responses=valid_options)
+        self.assertIn("abstain", resp.lower())

     @patch('neon_llm_core.chatbot.send_mq_request')
     def test_get_llm_api_response(self, mq_request):
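
The new tests stub the MQ-backed helpers (`_get_llm_api_response`, `_get_llm_api_opinion`, `_get_llm_api_choice`) with `unittest.mock.patch.object` on the shared `mock_chatbot` instance, so no message bus is required to run them. A small, self-contained illustration of that mocking technique; the Greeter class below is hypothetical and not part of this repository:

from unittest import TestCase
from unittest.mock import patch


class Greeter:
    def fetch_greeting(self) -> str:
        # Stand-in for a call that would normally hit a remote service
        raise RuntimeError("no network in tests")

    def greet(self, name: str) -> str:
        return f"{self.fetch_greeting()}, {name}!"


greeter = Greeter()


class TestGreeter(TestCase):
    @patch.object(greeter, 'fetch_greeting')
    def test_greet(self, fetch_greeting):
        # The patched attribute is injected as a MagicMock
        fetch_greeting.return_value = "Hello"
        self.assertEqual(greeter.greet("world"), "Hello, world!")
        fetch_greeting.assert_called_once_with()

With the MQ calls mocked this way, the new tests should run with a standard runner such as pytest or `python -m unittest`, assuming the package and its test dependency `neon_data_models` are installed.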
