Skip to content

Commit

Permalink
Merge pull request #702 from dimagi/cs/show_routing_tags_in_convo
Browse files Browse the repository at this point in the history
Show routing tag in web UI when "debug mode" is enabled
  • Loading branch information
SmittieC authored Oct 3, 2024
2 parents 733b7b8 + 5e9e398 commit 6a90f23
Show file tree
Hide file tree
Showing 19 changed files with 205 additions and 35 deletions.
18 changes: 18 additions & 0 deletions apps/annotations/migrations/0006_alter_tag_category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.1 on 2024-10-02 13:18

from django.db import migrations, models


class Migration(migrations.Migration):
    """Extend Tag.category choices with the new safety-layer-response value."""

    dependencies = [
        ("annotations", "0005_alter_tag_category"),
    ]

    operations = [
        migrations.AlterField(
            model_name="tag",
            name="category",
            field=models.CharField(
                blank=True,
                default="",
                choices=[
                    ("bot_response", "Bot Response"),
                    ("safety_layer_response", "Safety Layer Response"),
                    ("experiment_version", "Experiment Version"),
                ],
            ),
        ),
    ]
1 change: 1 addition & 0 deletions apps/annotations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

class TagCategories(models.TextChoices):
    """Categories for system-applied tags on chat messages."""

    # Names the (router child) bot that generated an AI message — read back by
    # ChatMessage.get_processor_bot_tag_name for the debug-mode UI badge.
    BOT_RESPONSE = "bot_response", _("Bot Response")
    # Names the safety layer that produced/overrode a response — applied in
    # TopicBot._get_safe_response.
    SAFETY_LAYER_RESPONSE = "safety_layer_response", _("Safety Layer Response")
    # Records which experiment version generated the message.
    EXPERIMENT_VERSION = "experiment_version", _("Experiment Version")


Expand Down
5 changes: 4 additions & 1 deletion apps/channels/tests/test_web_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
@override_settings(CELERY_TASK_ALWAYS_EAGER=True)
@pytest.mark.parametrize("with_seed_message", [True, False])
@patch("apps.events.tasks.enqueue_static_triggers", Mock())
@patch("apps.chat.bots.TopicBot.get_ai_message_id")
@patch("apps.chat.channels.WebChannel.new_user_message")
def test_start_new_session(new_user_message, with_seed_message, experiment):
def test_start_new_session(new_user_message, get_ai_message_id, with_seed_message, experiment):
"""A simple test to make sure we create a session and send a session message"""
get_ai_message_id.return_value = 1

if with_seed_message:
experiment.seed_message = "Tell a joke"
experiment.save()
Expand Down
29 changes: 25 additions & 4 deletions apps/chat/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from langchain_core.runnables import chain
from pydantic import ValidationError

from apps.annotations.models import TagCategories
from apps.chat.conversation import BasicConversation, Conversation
from apps.chat.exceptions import ChatException
from apps.chat.models import ChatMessageType
from apps.events.models import StaticTriggerType
from apps.events.tasks import enqueue_static_triggers
from apps.experiments.models import Experiment, ExperimentRoute, ExperimentSession, SafetyLayer
Expand Down Expand Up @@ -85,6 +87,9 @@ def __init__(self, session: ExperimentSession, experiment: Experiment | None = N
self.trace_service = None
if self.experiment.trace_provider:
self.trace_service = self.experiment.trace_provider.get_service()

# The chain that generated the AI message
self.generator_chain = None
self._initialize()

def _initialize(self):
Expand Down Expand Up @@ -132,7 +137,8 @@ def _call_predict(self, input_str, save_input_to_history=True, attachments: list
)

if self.terminal_chain:
result = self.terminal_chain.invoke(
chain = self.terminal_chain
result = chain.invoke(
result.output,
config={
"run_name": "terminal_chain",
Expand All @@ -144,6 +150,8 @@ def _call_predict(self, input_str, save_input_to_history=True, attachments: list
},
)

self.generator_chain = chain

enqueue_static_triggers.delay(self.session.id, StaticTriggerType.NEW_BOT_MESSAGE)
self.input_tokens = self.input_tokens + result.prompt_tokens
self.output_tokens = self.output_tokens + result.completion_tokens
Expand Down Expand Up @@ -208,13 +216,26 @@ def main_bot_chain(user_input):
}
return main_bot_chain.invoke(user_input, config=config)

def get_ai_message_id(self) -> int | None:
    """Return the database ID of the most recently generated AI message.

    Callers can use this ID to fetch the full ChatMessage record (e.g. to
    read its tags). Returns None when no chain has produced and saved an
    AI message yet.
    """
    chain = self.generator_chain
    if not chain:
        return None
    ai_message = chain.state.ai_message
    return ai_message.id if ai_message else None

def _get_safe_response(self, safety_layer: SafetyLayer):
    """Produce the bot's reply for input flagged by a safety layer.

    If the layer defines ``prompt_to_bot``, the reply is generated via
    ``_call_predict`` (which saves the message to history and sets
    ``self.generator_chain`` as part of the normal flow). Otherwise the
    layer's canned default response (or a generic apology) is used; in that
    case the message must be saved to history manually since no chain
    generated it. Either way, the resulting AI message is tagged with the
    safety layer's name so the UI can show which layer was triggered.
    """
    # NOTE: the rendered diff interleaved pre- and post-change lines here;
    # the early returns from the old version made the tagging code below
    # unreachable. This is the coherent post-change implementation.
    if safety_layer.prompt_to_bot:
        bot_response = self._call_predict(safety_layer.prompt_to_bot, save_input_to_history=False)
    else:
        no_answer = "Sorry, I can't answer that. Please try something else."
        bot_response = safety_layer.default_response_to_user or no_answer
        # This is a bit of a hack to store the bot's response, since it didn't really
        # generate it, but we still need to save it.
        self.chain.state.save_message_to_history(bot_response, type_=ChatMessageType.AI)
        self.generator_chain = self.chain

    self.generator_chain.state.ai_message.add_system_tag(
        safety_layer.name, tag_category=TagCategories.SAFETY_LAYER_RESPONSE
    )
    return bot_response


class SafetyBot:
Expand Down
3 changes: 1 addition & 2 deletions apps/chat/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ def _transcribe_audio(self, audio: BytesIO) -> str:

def _get_bot_response(self, message: str) -> str:
    """Run the user's message through the bot and return the reply text.

    Lazily constructs the bot on first use and caches it on ``self.bot``
    so later messages in the same channel instance reuse it.
    """
    # The diff rendering duplicated the old two-line form (`answer = ...;
    # return answer`) next to its replacement; keep only the new version.
    self.bot = self.bot or get_bot(self.experiment_session, experiment=self.experiment)
    return self.bot.process_input(message, attachments=self.message.attachments)

def _add_message_to_history(self, message: str, message_type: ChatMessageType):
"""Use this to update the chat history when not using the normal bot flow"""
Expand Down
12 changes: 12 additions & 0 deletions apps/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ def add_system_tag(self, tag: str, tag_category: TagCategories):
)
self.add_tag(tag, team=self.chat.team, added_by=None)

def get_processor_bot_tag_name(self) -> str | None:
    """Return the name of the bot-response tag on this message, if any.

    Only AI messages carry a processor-bot tag; for every other message
    type the result is always None.
    """
    if self.message_type != ChatMessageType.AI:
        return None
    bot_tag = self.tags.filter(category=TagCategories.BOT_RESPONSE).first()
    return bot_tag.name if bot_tag else None

def get_safety_layer_tag_name(self) -> str | None:
    """Return the name of this message's safety-layer tag, or None if absent."""
    layer_tag = self.tags.filter(category=TagCategories.SAFETY_LAYER_RESPONSE).first()
    if layer_tag is None:
        return None
    return layer_tag.name


class ChatAttachment(BaseModel):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name="attachments")
Expand Down
19 changes: 18 additions & 1 deletion apps/chat/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from apps.chat.models import ChatMessage
from apps.annotations.models import TagCategories
from apps.chat.models import ChatMessage, ChatMessageType
from apps.utils.factories.assistants import OpenAiAssistantFactory
from apps.utils.factories.experiment import ExperimentSessionFactory
from apps.utils.factories.files import FileFactory
Expand Down Expand Up @@ -32,3 +33,19 @@ def test_get_attached_files():
assert chat_file1 in files
assert assistant_file1 not in files
assert assistant_file2 not in files


@pytest.mark.django_db()
class TestChatMessage:
    def test_get_processor_bot_tag_name(self):
        """Only AI messages carrying a bot-response tag report a tag name."""
        session = ExperimentSessionFactory()
        chat = session.chat

        untagged_human = ChatMessage.objects.create(chat=chat, message_type=ChatMessageType.HUMAN, content="Hi")
        untagged_ai = ChatMessage.objects.create(chat=chat, message_type=ChatMessageType.AI, content="Hi")
        tagged_ai = ChatMessage.objects.create(chat=chat, message_type=ChatMessageType.AI, content="Hi")
        tagged_ai.add_system_tag(tag="some-bot", tag_category=TagCategories.BOT_RESPONSE)

        assert untagged_human.get_processor_bot_tag_name() is None
        assert untagged_ai.get_processor_bot_tag_name() is None
        assert tagged_ai.get_processor_bot_tag_name() == "some-bot"
15 changes: 15 additions & 0 deletions apps/chat/tests/test_topic_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pytest

from apps.annotations.models import TagCategories
from apps.chat.bots import TopicBot
from apps.chat.models import ChatMessage, ChatMessageType
from apps.experiments.models import ExperimentRoute, ExperimentRouteType, SafetyLayer
from apps.utils.factories.experiment import ExperimentFactory, ExperimentSessionFactory
from apps.utils.langchain import mock_experiment_llm
Expand Down Expand Up @@ -51,3 +53,16 @@ def test_bot_with_terminal_bot(get_output_check_cancellation):
assert session.chat.messages.count() == 2
assert session.chat.messages.get(message_type="human").content == "What are we going to do?"
assert session.chat.messages.get(message_type="ai").content == "kom ons braai!"


@pytest.mark.django_db()
def test_get_safe_response_creates_ai_message_for_default_messages():
    """The default safety response is persisted and tagged even though no LLM generated it."""
    session = ExperimentSessionFactory()
    safety_layer = SafetyLayer.objects.create(prompt_text="Is this message safe?", team=session.experiment.team)
    session.experiment.safety_layers.add(safety_layer)

    TopicBot(session)._get_safe_response(safety_layer)

    ai_message = ChatMessage.objects.get(message_type=ChatMessageType.AI)
    assert ai_message.content == "Sorry, I can't answer that. Please try something else."
    assert ai_message.tags.get(category=TagCategories.SAFETY_LAYER_RESPONSE) is not None
18 changes: 18 additions & 0 deletions apps/experiments/migrations/0095_experiment_debug_mode_enabled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.1 on 2024-10-02 07:51

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the Experiment.debug_mode_enabled flag (off by default)."""

    dependencies = [
        ("experiments", "0094_consentform_working_version_and_more"),
    ]

    operations = [
        migrations.AddField(
            model_name="experiment",
            name="debug_mode_enabled",
            field=models.BooleanField(default=False),
        ),
    ]
1 change: 1 addition & 0 deletions apps/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ class Experiment(BaseTeamModel, VersionsMixin):
blank=True,
default="",
)
debug_mode_enabled = models.BooleanField(default=False)
objects = ExperimentObjectManager()

class Meta:
Expand Down
7 changes: 5 additions & 2 deletions apps/experiments/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@shared_task(bind=True, base=TaskbadgerTask)
def get_response_for_webchat_task(
self, experiment_session_id: int, experiment_id: int, message_text: str, attachments: list | None = None
) -> str:
) -> dict:
experiment_session = ExperimentSession.objects.select_related("experiment", "experiment__team").get(
id=experiment_session_id
)
Expand All @@ -40,7 +40,10 @@ def get_response_for_webchat_task(
)
update_taskbadger_data(self, web_channel, message)
with current_team(experiment_session.team):
return web_channel.new_user_message(message)
return {
"response": web_channel.new_user_message(message),
"message_id": web_channel.bot.get_ai_message_id(),
}


@shared_task
Expand Down
14 changes: 13 additions & 1 deletion apps/experiments/views/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class Meta:
"use_processor_bot_voice",
"trace_provider",
"participant_allowlist",
"debug_mode_enabled",
]
labels = {"source_material": "Inline Source Material", "participant_allowlist": "Participant allowlist"}
help_texts = {
Expand All @@ -224,6 +225,10 @@ class Meta:
"participant_allowlist": (
"Separate identifiers with a comma. Phone numbers should be in E164 format e.g. +27123456789"
),
"debug_mode_enabled": (
"Enabling this tags each AI message in the web UI with the bot responsible for generating it. "
"This is applicable only for router bots."
),
}

def __init__(self, request, *args, **kwargs):
Expand Down Expand Up @@ -788,14 +793,21 @@ def get_message_response(request, team_slug: str, experiment_id: int, session_id
# don't render empty messages
skip_render = progress["complete"] and progress["success"] and not progress["result"]

message_details = {"message": None, "error": False, "complete": progress["complete"]}
if progress["complete"] and progress["success"]:
result = progress["result"]
message_details["message"] = ChatMessage.objects.get(id=result["message_id"])
elif progress["complete"]:
message_details["error"] = True

return TemplateResponse(
request,
"experiments/chat/chat_message_response.html",
{
"experiment": experiment,
"session": session,
"task_id": task_id,
"progress": progress,
"message_details": message_details,
"skip_render": skip_render,
"last_message_datetime": last_message and quote(last_message.created_at.isoformat()),
},
Expand Down
26 changes: 17 additions & 9 deletions apps/service_providers/llm_service/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@


class RunnableState(metaclass=ABCMeta):
ai_message: ChatMessage | None = None

@abstractmethod
def get_llm_service(self):
pass
Expand Down Expand Up @@ -147,10 +149,12 @@ def save_message_to_history(self, message: str, type_: ChatMessageType, experime
)
chat_message.add_tag(tag, team=self.session.team, added_by=None)

if type_ == ChatMessageType.AI and not self.experiment.is_working_version:
chat_message.add_system_tag(
tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION
)
if type_ == ChatMessageType.AI:
self.ai_message = chat_message
if not self.experiment.is_working_version:
chat_message.add_system_tag(
tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION
)

def check_cancellation(self):
self.session.chat.refresh_from_db(fields=["metadata"])
Expand All @@ -177,7 +181,9 @@ def set_metadata(self, key: Chat.MetadataKeys, value):
pass

@abstractmethod
def save_message_to_history(self, message: str, type_: ChatMessageType, resource_file_ids: dict | None = None):
def save_message_to_history(
self, message: str, type_: ChatMessageType, resource_file_ids: dict | None = None
) -> ChatMessage:
pass

@abstractmethod
Expand Down Expand Up @@ -268,10 +274,12 @@ def save_message_to_history(
)
chat_message.add_tag(tag, team=self.session.team, added_by=None)

if type_ == ChatMessageType.AI and not self.experiment.is_working_version:
chat_message.add_system_tag(
tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION
)
if type_ == ChatMessageType.AI:
self.ai_message = chat_message
if not self.experiment.is_working_version:
chat_message.add_system_tag(
tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION
)

return chat_message

Expand Down
Empty file.
31 changes: 31 additions & 0 deletions apps/service_providers/tests/llm_service/test_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from apps.chat.models import ChatMessage, ChatMessageType
from apps.service_providers.llm_service.state import AssistantExperimentState, ChatExperimentState
from apps.utils.factories.experiment import ExperimentSessionFactory


@pytest.fixture()
def session():
    """Provide a fresh factory-built experiment session for each test."""
    experiment_session = ExperimentSessionFactory()
    return experiment_session


# TODO: I need more tests please!
@pytest.mark.django_db()
class TestChatExperimentState:
    def test_save_message_to_history_stores_ai_message_on_state(self, session):
        """Saving an AI message records it on state.ai_message; human messages do not."""
        state = ChatExperimentState(session=session, experiment=session.experiment)

        state.save_message_to_history(message="hi", type_=ChatMessageType.HUMAN)
        assert state.ai_message is None

        state.save_message_to_history(message="hi human", type_=ChatMessageType.AI)
        persisted = ChatMessage.objects.get(message_type=ChatMessageType.AI)
        assert state.ai_message == persisted


@pytest.mark.django_db()
class TestAssistantExperimentState:
    def test_save_message_to_history_stores_ai_message_on_state(self, session):
        """Assistant state mirrors chat state: only AI saves populate ai_message."""
        state = AssistantExperimentState(session=session, experiment=session.experiment)

        state.save_message_to_history(message="hi", type_=ChatMessageType.HUMAN)
        assert state.ai_message is None

        state.save_message_to_history(message="hi human", type_=ChatMessageType.AI)
        persisted = ChatMessage.objects.get(message_type=ChatMessageType.AI)
        assert state.ai_message == persisted
22 changes: 17 additions & 5 deletions templates/experiments/chat/ai_message.html
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
{% load chat_tags %}
<div class="chat-message-system flex flex-col" data-last-message-datetime="{{ created_at_datetime|safe }}">
<div class="flex flex-row">
{% include "experiments/chat/components/system_icon.html" %}
<div class="message-contents">
<p>{{ message_text|render_markdown }}</p>
<div class="flex flex-col gap-1">
{% if experiment.debug_mode_enabled %}
<div>
{% if message.get_safety_layer_tag_name %}
<div class="badge badge-sm badge-error">Safety Layer '{{ message.get_safety_layer_tag_name }}' Triggered</div>
{% endif %}
{% if experiment.debug_mode_enabled and message.get_processor_bot_tag_name %}
<div class="badge badge-sm badge-warning">Routed to {{ message.get_processor_bot_tag_name }}</div>
{% endif %}
</div>
{% endif %}
<div class="chat-message-system flex flex-col" data-last-message-datetime="{{ created_at_datetime|safe }}">
<div class="flex flex-row">
{% include "experiments/chat/components/system_icon.html" %}
<div class="message-contents">
<p>{{ message.content|render_markdown }}</p>
</div>
</div>
</div>
</div>
Loading

0 comments on commit 6a90f23

Please sign in to comment.