Skip to content

Commit

Permalink
Merge pull request #671 from dimagi/cs/versioning_changes_in_code
Browse files Browse the repository at this point in the history
Experiment Versioning | Code updates
  • Loading branch information
SmittieC authored Sep 19, 2024
2 parents c8e22bd + 0a38ec6 commit a8485ff
Show file tree
Hide file tree
Showing 43 changed files with 670 additions and 187 deletions.
4 changes: 4 additions & 0 deletions api-schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ components:
format: uri
readOnly: true
title: API URL
version_number:
type: integer
maximum: 2147483647
minimum: 0
required:
- id
- name
Expand Down
18 changes: 18 additions & 0 deletions apps/annotations/migrations/0005_alter_tag_category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.1 on 2024-09-06 10:51

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('annotations', '0004_default_tag_category'),
]

operations = [
migrations.AlterField(
model_name='tag',
name='category',
field=models.CharField(blank=True, choices=[('bot_response', 'Bot Response'), ('experiment_version', 'Experiment Version')], default=''),
),
]
1 change: 1 addition & 0 deletions apps/annotations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

class TagCategories(models.TextChoices):
BOT_RESPONSE = "bot_response", _("Bot Response")
EXPERIMENT_VERSION = "experiment_version", _("Experiment Version")


@audit_fields(
Expand Down
2 changes: 1 addition & 1 deletion apps/api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def chat_completions(request, experiment_id: str):
session = serializer.save()
response_message = handle_api_message(
request.user,
session.experiment,
session.experiment_version,
session.experiment_channel,
last_message.get("content"),
session.participant.identifier,
Expand Down
2 changes: 1 addition & 1 deletion apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ExperimentSerializer(serializers.ModelSerializer):

class Meta:
model = Experiment
fields = ["id", "name", "url"]
fields = ["id", "name", "url", "version_number"]


class ParticipantSerializer(serializers.ModelSerializer):
Expand Down
2 changes: 2 additions & 0 deletions apps/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_list_experiments(experiment):
"name": experiment.name,
"id": experiment.public_id,
"url": f"http://testserver/api/experiments/{experiment.public_id}/",
"version_number": 1,
}
],
"next": None,
Expand All @@ -48,6 +49,7 @@ def test_retrieve_experiments(experiment):
"id": experiment.public_id,
"name": experiment.name,
"url": f"http://testserver/api/experiments/{experiment.public_id}/",
"version_number": 1,
}


Expand Down
1 change: 1 addition & 0 deletions apps/api/tests/test_session_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def get_session_json(session, expected_messages=None):
"id": str(experiment.public_id),
"name": experiment.name,
"url": f"http://testserver/api/experiments/{experiment.public_id}/",
"version_number": 1,
},
"participant": {"identifier": session.participant.identifier},
"id": str(session.external_id),
Expand Down
14 changes: 8 additions & 6 deletions apps/channels/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def handle_telegram_message(self, message_data: str, channel_external_id: uuid):
return

message = TelegramMessage.parse(update)
message_handler = TelegramChannel(experiment_channel.experiment, experiment_channel)
message_handler = TelegramChannel(experiment_channel.experiment.default_version, experiment_channel)
update_taskbadger_data(self, message_handler, message)

with current_team(experiment_channel.team):
Expand Down Expand Up @@ -68,7 +68,7 @@ def handle_twilio_message(self, message_data: str, request_uri: str, signature:

validate_twillio_request(experiment_channel, raw_data, request_uri, signature)

message_handler = ChannelClass(experiment_channel.experiment, experiment_channel=experiment_channel)
message_handler = ChannelClass(experiment_channel.experiment.default_version, experiment_channel=experiment_channel)
update_taskbadger_data(self, message_handler, message)

with current_team(experiment_channel.team):
Expand Down Expand Up @@ -106,7 +106,7 @@ def handle_sureadhere_message(self, sureadhere_tenant_id: str, message_data: dic
if not experiment_channel:
log.info(f"No experiment channel found for SureAdhere tenant ID: {sureadhere_tenant_id}")
return
channel = SureAdhereChannel(experiment_channel.experiment, experiment_channel)
channel = SureAdhereChannel(experiment_channel.experiment.default_version, experiment_channel)
update_taskbadger_data(self, channel, message)
with current_team(experiment_channel.team):
channel.new_user_message(message)
Expand All @@ -127,18 +127,20 @@ def handle_turn_message(self, experiment_id: uuid, message_data: dict):
if not experiment_channel:
log.info(f"No experiment channel found for experiment_id: {experiment_id}")
return
channel = WhatsappChannel(experiment_channel.experiment, experiment_channel)
channel = WhatsappChannel(experiment_channel.experiment.default_version, experiment_channel)
update_taskbadger_data(self, channel, message)

with current_team(experiment_channel.team):
channel.new_user_message(message)


def handle_api_message(user, experiment, experiment_channel, message_text: str, participant_id: str, session=None):
def handle_api_message(
user, experiment_version, experiment_channel, message_text: str, participant_id: str, session=None
):
"""Synchronously handles the message coming from the API"""
message = BaseMessage(participant_id=participant_id, message_text=message_text)
channel = ApiChannel(
experiment,
experiment_version,
experiment_channel,
experiment_session=session,
user=user,
Expand Down
76 changes: 63 additions & 13 deletions apps/channels/tests/test_base_channel_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from apps.channels.models import ChannelPlatform, ExperimentChannel
from apps.chat.channels import URL_REGEX, ChannelBase, TelegramChannel, strip_urls_and_emojis
from apps.chat.exceptions import VersionedExperimentSessionsNotAllowedException
from apps.chat.models import ChatMessageType
from apps.experiments.models import (
ExperimentRoute,
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_incoming_message_adds_channel_info(telegram_channel):

chat_id = 123123
message = telegram_messages.text_message(chat_id=chat_id)
_simulate_user_message(telegram_channel, message)
_send_user_message_on_channel(telegram_channel, message)

experiment_session = ExperimentSession.objects.filter(
experiment=telegram_channel.experiment, participant__identifier=chat_id
Expand All @@ -60,7 +61,7 @@ def test_incoming_message_adds_channel_info(telegram_channel):
def test_channel_added_for_experiment_session(telegram_channel):
chat_id = 123123
message = telegram_messages.text_message(chat_id=chat_id)
_simulate_user_message(telegram_channel, message)
_send_user_message_on_channel(telegram_channel, message)
participant = Participant.objects.get(identifier=chat_id)
experiment_session = participant.experimentsession_set.first()
assert experiment_session.experiment_channel is not None
Expand All @@ -76,7 +77,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel):

# First message
message = telegram_messages.text_message(chat_id=chat_id)
_simulate_user_message(telegram_channel, message)
_send_user_message_on_channel(telegram_channel, message)

# Let's find the session it created
experiment_sessions_count = ExperimentSession.objects.filter(
Expand All @@ -88,7 +89,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel):
telegram_channel._create_new_experiment_session = Mock()

# Second message
_simulate_user_message(telegram_channel, message)
_send_user_message_on_channel(telegram_channel, message)

# Assertions
experiment_sessions_count = ExperimentSession.objects.filter(
Expand All @@ -107,14 +108,14 @@ def test_different_sessions_created_for_different_users(telegram_channel):

# First user's message
user_1_message = telegram_messages.text_message(chat_id=user_1_chat_id)
_simulate_user_message(telegram_channel, user_1_message)
_send_user_message_on_channel(telegram_channel, user_1_message)

# Calling new_user_message added an experiment_session, so we should remove it before reusing the instance
telegram_channel.experiment_session = None

# Second user's message
user_2_message = telegram_messages.text_message(chat_id=user_2_chat_id)
_simulate_user_message(telegram_channel, user_2_message)
_send_user_message_on_channel(telegram_channel, user_2_message)

# Assertions
experiment_sessions_count = ExperimentSession.objects.count()
Expand All @@ -141,8 +142,8 @@ def test_different_participants_created_for_same_user_in_different_teams():

assert experiment1.team != experiment2.team

_simulate_user_message(channel1, user_message)
_simulate_user_message(channel2, user_message)
_send_user_message_on_channel(channel1, user_message)
_send_user_message_on_channel(channel2, user_message)

experiment_sessions_count = ExperimentSession.objects.count()
assert experiment_sessions_count == 2
Expand All @@ -159,7 +160,7 @@ def test_reset_command_creates_new_experiment_session(_send_text_to_user_mock, t
telegram_chat_id = 00000
normal_message = telegram_messages.text_message(chat_id=telegram_chat_id)

_simulate_user_message(telegram_channel, normal_message)
_send_user_message_on_channel(telegram_channel, normal_message)

reset_message = telegram_messages.text_message(
chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND
Expand All @@ -184,18 +185,18 @@ def test_reset_conversation_does_not_create_new_session(
telegram_chat_id = 00000

message1 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND)
_simulate_user_message(telegram_channel, message1)
_send_user_message_on_channel(telegram_channel, message1)

message2 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND)
_simulate_user_message(telegram_channel, message2)
_send_user_message_on_channel(telegram_channel, message2)

sessions = ExperimentSession.objects.for_chat_id(telegram_chat_id).all()
assert len(sessions) == 1
# The reset command should not be saved in the history
assert sessions[0].chat.get_langchain_messages() == []


def _simulate_user_message(channel_instance, user_message: str):
def _send_user_message_on_channel(channel_instance, user_message: str):
with mock_experiment_llm(channel_instance.experiment, responses=["OK"]):
channel_instance.new_user_message(user_message)

Expand Down Expand Up @@ -229,7 +230,7 @@ def _user_message(message: str):

_user_message("Hi")
chat = channel.experiment_session.chat
pre_survey_link = channel.experiment_session.get_pre_survey_link()
pre_survey_link = channel.experiment_session.get_pre_survey_link(experiment)
confirmation_text = pre_survey.confirmation_text
expected_survey_text = confirmation_text.format(survey_link=pre_survey_link)
# Let's see if the bot asked consent
Expand Down Expand Up @@ -775,3 +776,52 @@ def test_participant_authorization(
telegram_channel.new_user_message(message)
send_text_to_user.assert_called()
assert send_text_to_user.call_args[0][0] == "Sorry, you are not allowed to chat to this bot"


class TestChannel(ChannelBase):
def send_text_to_user(self):
pass


@pytest.mark.django_db()
class TestBaseChannelMethods:
"""Unit tests for the methods of the ChannelBase class"""

def test_participant_identifier(self):
"""Fetching the participant data"""
session = ExperimentSessionFactory(participant__identifier="Alpha")
exp_channel = ExperimentChannelFactory(experiment=session.experiment)
channel_base = TestChannel(experiment=session.experiment, experiment_channel=exp_channel)
channel_base.message = telegram_messages.text_message(chat_id="Beta")

assert channel_base.participant_identifier == "Beta"
channel_base.experiment_session = session
assert channel_base.participant_identifier == "Alpha"

@patch("apps.chat.channels.TelegramChannel.send_text_to_user", Mock())
@patch("apps.chat.channels.TelegramChannel._get_bot_response", Mock())
def test_new_sessions_are_linked_to_the_working_experiment(self, experiment):
working_version = experiment
channel = ExperimentChannelFactory(experiment=working_version)
new_version = working_version.create_new_version()

telegram = TelegramChannel(experiment=new_version, experiment_channel=channel)
telegram.telegram_bot = Mock()
telegram.new_user_message(telegram_messages.text_message())

# Check that the working experiment is linked to the session
assert ExperimentSession.objects.filter(experiment=working_version).exists()
assert not ExperimentSession.objects.filter(experiment=new_version).exists()

def test_can_start_a_session_with_working_experiment(self, experiment):
assert experiment.is_versioned is False
channel = ExperimentChannelFactory(experiment=experiment)
session = ChannelBase.start_new_session(experiment, channel, participant_identifier="testy-pie")
assert session.experiment == experiment

def test_cannot_start_a_session_with_an_experiment_version(self, experiment):
channel = ExperimentChannelFactory(experiment=experiment)
new_version = experiment.create_new_version()
assert new_version.is_versioned is True
with pytest.raises(VersionedExperimentSessionsNotAllowedException):
ChannelBase.start_new_session(new_version, channel, participant_identifier="testy-pie")
17 changes: 15 additions & 2 deletions apps/channels/tests/test_web_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@patch("apps.events.tasks.enqueue_static_triggers", Mock())
@patch("apps.chat.channels.WebChannel.new_user_message")
def test_start_new_session(new_user_message, with_seed_message, experiment):
"""A simple test to make sure we create"""
"""A simple test to make sure we create a session and send a session message"""
if with_seed_message:
experiment.seed_message = "Tell a joke"
experiment.save()
Expand All @@ -37,4 +37,17 @@ def test_start_new_session(new_user_message, with_seed_message, experiment):
assert message.attachments == []


# TODO: Add more tests
@pytest.mark.django_db()
class TestVersioning:
@patch("apps.events.tasks.enqueue_static_triggers", Mock())
@patch("apps.chat.channels.WebChannel.check_and_process_seed_message")
def test_start_new_session_uses_default_version(self, check_and_process_seed_message, experiment):
new_version = experiment.create_new_version()
session = WebChannel.start_new_session(
experiment,
"jack@titanic.com",
)

_session_used, experiment_used = check_and_process_seed_message.call_args[0]
assert experiment_used == new_version
assert session.experiment == experiment
6 changes: 4 additions & 2 deletions apps/channels/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ def new_api_message(request, experiment_id: uuid):
session = None
if session_id := message_data.get("session"):
try:
# TODO: Support ability to select a specific version
experiment = Experiment.objects.get(public_id=experiment_id)
session = ExperimentSession.objects.select_related("experiment", "experiment_channel").get(
external_id=session_id,
experiment__public_id=experiment_id,
experiment=experiment,
team=request.team,
participant__user=request.user,
experiment_channel__platform=ChannelPlatform.API,
Expand All @@ -115,7 +117,7 @@ def new_api_message(request, experiment_id: uuid):

response = tasks.handle_api_message(
request.user,
experiment,
experiment.default_version,
experiment_channel,
message_data["message"],
participant_id,
Expand Down
5 changes: 2 additions & 3 deletions apps/chat/bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TopicBot:
"""

def __init__(self, session: ExperimentSession, experiment: Experiment | None = None, disable_tools: bool = False):
self.experiment = experiment or session.experiment
self.experiment = experiment or session.experiment_version
self.disable_tools = disable_tools
self.prompt = self.experiment.prompt_text
self.input_formatter = self.experiment.input_formatter
Expand All @@ -85,7 +85,6 @@ def __init__(self, session: ExperimentSession, experiment: Experiment | None = N
self.trace_service = None
if self.experiment.trace_provider:
self.trace_service = self.experiment.trace_provider.get_service()

self._initialize()

def _initialize(self):
Expand Down Expand Up @@ -255,7 +254,7 @@ def filter_ai_messages(self) -> bool:

class PipelineBot:
def __init__(self, session: ExperimentSession):
self.experiment = session.experiment
self.experiment = session.experiment_version
self.session = session

def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None):
Expand Down
Loading

0 comments on commit a8485ff

Please sign in to comment.