diff --git a/api-schema.yml b/api-schema.yml index 6098864cd..74d12c375 100644 --- a/api-schema.yml +++ b/api-schema.yml @@ -368,6 +368,10 @@ components: format: uri readOnly: true title: API URL + version_number: + type: integer + maximum: 2147483647 + minimum: 0 required: - id - name diff --git a/apps/annotations/migrations/0005_alter_tag_category.py b/apps/annotations/migrations/0005_alter_tag_category.py new file mode 100644 index 000000000..120ecab6e --- /dev/null +++ b/apps/annotations/migrations/0005_alter_tag_category.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1 on 2024-09-06 10:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('annotations', '0004_default_tag_category'), + ] + + operations = [ + migrations.AlterField( + model_name='tag', + name='category', + field=models.CharField(blank=True, choices=[('bot_response', 'Bot Response'), ('experiment_version', 'Experiment Version')], default=''), + ), + ] diff --git a/apps/annotations/models.py b/apps/annotations/models.py index 1be129ba5..b9fa15603 100644 --- a/apps/annotations/models.py +++ b/apps/annotations/models.py @@ -17,6 +17,7 @@ class TagCategories(models.TextChoices): BOT_RESPONSE = "bot_response", _("Bot Response") + EXPERIMENT_VERSION = "experiment_version", _("Experiment Version") @audit_fields( diff --git a/apps/api/openai.py b/apps/api/openai.py index 6ba053e5e..4f717cfb2 100644 --- a/apps/api/openai.py +++ b/apps/api/openai.py @@ -109,7 +109,7 @@ def chat_completions(request, experiment_id: str): session = serializer.save() response_message = handle_api_message( request.user, - session.experiment, + session.experiment_version, session.experiment_channel, last_message.get("content"), session.participant.identifier, diff --git a/apps/api/serializers.py b/apps/api/serializers.py index a03627da0..0cf823242 100644 --- a/apps/api/serializers.py +++ b/apps/api/serializers.py @@ -21,7 +21,7 @@ class ExperimentSerializer(serializers.ModelSerializer): class Meta: model = Experiment - fields = ["id", "name", "url"] + fields = ["id", "name", "url", "version_number"] class ParticipantSerializer(serializers.ModelSerializer): diff --git a/apps/api/tests/test_api.py b/apps/api/tests/test_api.py index 7c0e614fa..68140354a 100644 --- a/apps/api/tests/test_api.py +++ b/apps/api/tests/test_api.py @@ -30,6 +30,7 @@ def test_list_experiments(experiment): "name": experiment.name, "id": experiment.public_id, "url": f"http://testserver/api/experiments/{experiment.public_id}/", + "version_number": 1, } ], "next": None, @@ -48,6 +49,7 @@ def test_retrieve_experiments(experiment): "id": experiment.public_id, "name": experiment.name, "url": f"http://testserver/api/experiments/{experiment.public_id}/", + "version_number": 1, } diff --git a/apps/api/tests/test_session_api.py b/apps/api/tests/test_session_api.py index 4dd7409f5..a54ed1c01 100644 --- a/apps/api/tests/test_session_api.py +++ b/apps/api/tests/test_session_api.py @@ -82,6 +82,7 @@ def get_session_json(session, expected_messages=None): "id": str(experiment.public_id), "name": experiment.name, "url": f"http://testserver/api/experiments/{experiment.public_id}/", + "version_number": 1, }, "participant": {"identifier": session.participant.identifier}, "id": str(session.external_id), diff --git a/apps/channels/tasks.py b/apps/channels/tasks.py index 5374f6524..b792e22fa 100644 --- a/apps/channels/tasks.py +++ b/apps/channels/tasks.py @@ -33,7 +33,7 @@ def handle_telegram_message(self, message_data: str, channel_external_id: uuid): return message = TelegramMessage.parse(update) - message_handler = TelegramChannel(experiment_channel.experiment, experiment_channel) + message_handler = TelegramChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, message_handler, message) with current_team(experiment_channel.team): @@ -68,7 +68,7 @@ def handle_twilio_message(self, message_data: str, request_uri: str, signature: validate_twillio_request(experiment_channel, raw_data, request_uri, signature) - message_handler = ChannelClass(experiment_channel.experiment, experiment_channel=experiment_channel) + message_handler = ChannelClass(experiment_channel.experiment.default_version, experiment_channel=experiment_channel) update_taskbadger_data(self, message_handler, message) with current_team(experiment_channel.team): @@ -106,7 +106,7 @@ def handle_sureadhere_message(self, sureadhere_tenant_id: str, message_data: dic if not experiment_channel: log.info(f"No experiment channel found for SureAdhere tenant ID: {sureadhere_tenant_id}") return - channel = SureAdhereChannel(experiment_channel.experiment, experiment_channel) + channel = SureAdhereChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, channel, message) with current_team(experiment_channel.team): channel.new_user_message(message) @@ -127,18 +127,20 @@ def handle_turn_message(self, experiment_id: uuid, message_data: dict): if not experiment_channel: log.info(f"No experiment channel found for experiment_id: {experiment_id}") return - channel = WhatsappChannel(experiment_channel.experiment, experiment_channel) + channel = WhatsappChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, channel, message) with current_team(experiment_channel.team): channel.new_user_message(message) -def handle_api_message(user, experiment, experiment_channel, message_text: str, participant_id: str, session=None): +def handle_api_message( + user, experiment_version, experiment_channel, message_text: str, participant_id: str, session=None +): """Synchronously handles the message coming from the API""" message = BaseMessage(participant_id=participant_id, message_text=message_text) channel = ApiChannel( - experiment, + experiment_version, experiment_channel, experiment_session=session, user=user, diff --git a/apps/channels/tests/test_base_channel_behavior.py b/apps/channels/tests/test_base_channel_behavior.py index b791e0739..3342bb527 100644 --- a/apps/channels/tests/test_base_channel_behavior.py +++ b/apps/channels/tests/test_base_channel_behavior.py @@ -10,6 +10,7 @@ from apps.channels.models import ChannelPlatform, ExperimentChannel from apps.chat.channels import URL_REGEX, ChannelBase, TelegramChannel, strip_urls_and_emojis +from apps.chat.exceptions import VersionedExperimentSessionsNotAllowedException from apps.chat.models import ChatMessageType from apps.experiments.models import ( ExperimentRoute, @@ -45,7 +46,7 @@ def test_incoming_message_adds_channel_info(telegram_channel): chat_id = 123123 message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) experiment_session = ExperimentSession.objects.filter( experiment=telegram_channel.experiment, participant__identifier=chat_id @@ -60,7 +61,7 @@ def test_incoming_message_adds_channel_info(telegram_channel): def test_channel_added_for_experiment_session(telegram_channel): chat_id = 123123 message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) participant = Participant.objects.get(identifier=chat_id) experiment_session = participant.experimentsession_set.first() assert experiment_session.experiment_channel is not None @@ -76,7 +77,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel): # First message message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) # Let's find the session it created experiment_sessions_count = ExperimentSession.objects.filter( @@ -88,7 +89,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel): telegram_channel._create_new_experiment_session = Mock() # Second message - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) # Assertions experiment_sessions_count = ExperimentSession.objects.filter( @@ -107,14 +108,14 @@ def test_different_sessions_created_for_different_users(telegram_channel): # First user's message user_1_message = telegram_messages.text_message(chat_id=user_1_chat_id) - _simulate_user_message(telegram_channel, user_1_message) + _send_user_message_on_channel(telegram_channel, user_1_message) # Calling new_user_message added an experiment_session, so we should remove it before reusing the instance telegram_channel.experiment_session = None # Second user's message user_2_message = telegram_messages.text_message(chat_id=user_2_chat_id) - _simulate_user_message(telegram_channel, user_2_message) + _send_user_message_on_channel(telegram_channel, user_2_message) # Assertions experiment_sessions_count = ExperimentSession.objects.count() @@ -141,8 +142,8 @@ def test_different_participants_created_for_same_user_in_different_teams(): assert experiment1.team != experiment2.team - _simulate_user_message(channel1, user_message) - _simulate_user_message(channel2, user_message) + _send_user_message_on_channel(channel1, user_message) + _send_user_message_on_channel(channel2, user_message) experiment_sessions_count = ExperimentSession.objects.count() assert experiment_sessions_count == 2 @@ -159,7 +160,7 @@ def test_reset_command_creates_new_experiment_session(_send_text_to_user_mock, t telegram_chat_id = 00000 normal_message = telegram_messages.text_message(chat_id=telegram_chat_id) - _simulate_user_message(telegram_channel, normal_message) + _send_user_message_on_channel(telegram_channel, normal_message) reset_message = telegram_messages.text_message( chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND @@ -184,10 +185,10 @@ def test_reset_conversation_does_not_create_new_session( telegram_chat_id = 00000 message1 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND) - _simulate_user_message(telegram_channel, message1) + _send_user_message_on_channel(telegram_channel, message1) message2 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND) - _simulate_user_message(telegram_channel, message2) + _send_user_message_on_channel(telegram_channel, message2) sessions = ExperimentSession.objects.for_chat_id(telegram_chat_id).all() assert len(sessions) == 1 @@ -195,7 +196,7 @@ def test_reset_conversation_does_not_create_new_session( assert sessions[0].chat.get_langchain_messages() == [] -def _simulate_user_message(channel_instance, user_message: str): +def _send_user_message_on_channel(channel_instance, user_message: str): with mock_experiment_llm(channel_instance.experiment, responses=["OK"]): channel_instance.new_user_message(user_message) @@ -229,7 +230,7 @@ def _user_message(message: str): _user_message("Hi") chat = channel.experiment_session.chat - pre_survey_link = channel.experiment_session.get_pre_survey_link() + pre_survey_link = channel.experiment_session.get_pre_survey_link(experiment) confirmation_text = pre_survey.confirmation_text expected_survey_text = confirmation_text.format(survey_link=pre_survey_link) # Let's see if the bot asked consent @@ -775,3 +776,52 @@ def test_participant_authorization( telegram_channel.new_user_message(message) send_text_to_user.assert_called() assert send_text_to_user.call_args[0][0] == "Sorry, you are not allowed to chat to this bot" + + +class TestChannel(ChannelBase): + def send_text_to_user(self): + pass + + +@pytest.mark.django_db() +class TestBaseChannelMethods: + """Unit tests for the methods of the ChannelBase class""" + + def test_participant_identifier(self): + """Fetching the participant data""" + session = ExperimentSessionFactory(participant__identifier="Alpha") + exp_channel = ExperimentChannelFactory(experiment=session.experiment) + channel_base = TestChannel(experiment=session.experiment, experiment_channel=exp_channel) + channel_base.message = telegram_messages.text_message(chat_id="Beta") + + assert channel_base.participant_identifier == "Beta" + channel_base.experiment_session = session + assert channel_base.participant_identifier == "Alpha" + + @patch("apps.chat.channels.TelegramChannel.send_text_to_user", Mock()) + @patch("apps.chat.channels.TelegramChannel._get_bot_response", Mock()) + def test_new_sessions_are_linked_to_the_working_experiment(self, experiment): + working_version = experiment + channel = ExperimentChannelFactory(experiment=working_version) + new_version = working_version.create_new_version() + + telegram = TelegramChannel(experiment=new_version, experiment_channel=channel) + telegram.telegram_bot = Mock() + telegram.new_user_message(telegram_messages.text_message()) + + # Check that the working experiment is linked to the session + assert ExperimentSession.objects.filter(experiment=working_version).exists() + assert not ExperimentSession.objects.filter(experiment=new_version).exists() + + def test_can_start_a_session_with_working_experiment(self, experiment): + assert experiment.is_versioned is False + channel = ExperimentChannelFactory(experiment=experiment) + session = ChannelBase.start_new_session(experiment, channel, participant_identifier="testy-pie") + assert session.experiment == experiment + + def test_cannot_start_a_session_with_an_experiment_version(self, experiment): + channel = ExperimentChannelFactory(experiment=experiment) + new_version = experiment.create_new_version() + assert new_version.is_versioned is True + with pytest.raises(VersionedExperimentSessionsNotAllowedException): + ChannelBase.start_new_session(new_version, channel, participant_identifier="testy-pie") diff --git a/apps/channels/tests/test_web_channel.py b/apps/channels/tests/test_web_channel.py index 757afeb15..1c2aaad95 100644 --- a/apps/channels/tests/test_web_channel.py +++ b/apps/channels/tests/test_web_channel.py @@ -12,7 +12,7 @@ @patch("apps.events.tasks.enqueue_static_triggers", Mock()) @patch("apps.chat.channels.WebChannel.new_user_message") def test_start_new_session(new_user_message, with_seed_message, experiment): - """A simple test to make sure we create""" + """A simple test to make sure we create a session and send a session message""" if with_seed_message: experiment.seed_message = "Tell a joke" experiment.save() @@ -37,4 +37,17 @@ def test_start_new_session(new_user_message, with_seed_message, experiment): assert message.attachments == [] -# TODO: Add more tests +@pytest.mark.django_db() +class TestVersioning: + @patch("apps.events.tasks.enqueue_static_triggers", Mock()) + @patch("apps.chat.channels.WebChannel.check_and_process_seed_message") + def test_start_new_session_uses_default_version(self, check_and_process_seed_message, experiment): + new_version = experiment.create_new_version() + session = WebChannel.start_new_session( + experiment, + "jack@titanic.com", + ) + + _session_used, experiment_used = check_and_process_seed_message.call_args[0] + assert experiment_used == new_version + assert session.experiment == experiment diff --git a/apps/channels/views.py b/apps/channels/views.py index 05b0bacab..fcf5fb342 100644 --- a/apps/channels/views.py +++ b/apps/channels/views.py @@ -96,9 +96,11 @@ def new_api_message(request, experiment_id: uuid): session = None if session_id := message_data.get("session"): try: + # TODO: Support ability to select a specific version + experiment = Experiment.objects.get(public_id=experiment_id) session = ExperimentSession.objects.select_related("experiment", "experiment_channel").get( external_id=session_id, - experiment__public_id=experiment_id, + experiment=experiment, team=request.team, participant__user=request.user, experiment_channel__platform=ChannelPlatform.API, @@ -115,7 +117,7 @@ def new_api_message(request, experiment_id: uuid): response = tasks.handle_api_message( request.user, - experiment, + experiment.default_version, experiment_channel, message_data["message"], participant_id, diff --git a/apps/chat/bots.py b/apps/chat/bots.py index 4f1e4847b..830db482e 100644 --- a/apps/chat/bots.py +++ b/apps/chat/bots.py @@ -60,7 +60,7 @@ class TopicBot: """ def __init__(self, session: ExperimentSession, experiment: Experiment | None = None, disable_tools: bool = False): - self.experiment = experiment or session.experiment + self.experiment = experiment or session.experiment_version self.disable_tools = disable_tools self.prompt = self.experiment.prompt_text self.input_formatter = self.experiment.input_formatter @@ -85,7 +85,6 @@ def __init__(self, session: ExperimentSession, experiment: Experiment | None = N self.trace_service = None if self.experiment.trace_provider: self.trace_service = self.experiment.trace_provider.get_service() - self._initialize() def _initialize(self): @@ -255,7 +254,7 @@ def filter_ai_messages(self) -> bool: class PipelineBot: def __init__(self, session: ExperimentSession): - self.experiment = session.experiment + self.experiment = session.experiment_version self.session = session def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None): diff --git a/apps/chat/channels.py b/apps/chat/channels.py index 989d1da15..f4a0a323d 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -15,7 +15,12 @@ from apps.channels import audio from apps.channels.models import ChannelPlatform, ExperimentChannel from apps.chat.bots import get_bot -from apps.chat.exceptions import AudioSynthesizeException, MessageHandlerException, ParticipantNotAllowedException +from apps.chat.exceptions import ( + AudioSynthesizeException, + MessageHandlerException, + ParticipantNotAllowedException, + VersionedExperimentSessionsNotAllowedException, +) from apps.chat.models import ChatMessage, ChatMessageType from apps.events.models import StaticTriggerType from apps.events.tasks import enqueue_static_triggers @@ -108,12 +113,12 @@ def __init__( self.experiment_session = experiment_session self.message = None self._user_query = None - self.bot = get_bot(experiment_session) if experiment_session else None + self.bot = get_bot(experiment_session, experiment=experiment) if experiment_session else None @classmethod def start_new_session( cls, - experiment: Experiment, + working_experiment: Experiment, experiment_channel: ExperimentChannel, participant_identifier: str, participant_user: CustomUser | None = None, @@ -122,7 +127,7 @@ def start_new_session( session_external_id: str | None = None, ): return _start_experiment_session( - experiment, + working_experiment, experiment_channel, participant_identifier, participant_user, @@ -220,7 +225,7 @@ def _add_message(self, message): raise ParticipantNotAllowedException() self._ensure_sessions_exists() - self.bot = get_bot(self.experiment_session) + self.bot = get_bot(self.experiment_session, experiment=self.experiment) def new_user_message(self, message) -> str: """Handles the message coming from the user. Call this to send bot messages to the user. @@ -316,7 +321,7 @@ def _ask_user_for_consent(self): self.send_text_to_user(bot_message) def _ask_user_to_take_survey(self): - pre_survey_link = self.experiment_session.get_pre_survey_link() + pre_survey_link = self.experiment_session.get_pre_survey_link(self.experiment) confirmation_text = self.experiment.pre_survey.confirmation_text bot_message = confirmation_text.format(survey_link=pre_survey_link) self._add_message_to_history(bot_message, ChatMessageType.AI) @@ -414,7 +419,7 @@ def _transcribe_audio(self, audio: BytesIO) -> str: return speech_service.transcribe_audio(audio) def _get_bot_response(self, message: str) -> str: - self.bot = self.bot or get_bot(self.experiment_session) + self.bot = self.bot or get_bot(self.experiment_session, experiment=self.experiment) answer = self.bot.process_input(message, attachments=self.message.attachments) return answer @@ -441,7 +446,6 @@ def _ensure_sessions_exists(self): if not self.experiment_session: self._create_new_experiment_session() - enqueue_static_triggers.delay(self.experiment_session.id, StaticTriggerType.PARTICIPANT_JOINED_EXPERIMENT) else: if self._is_reset_conversation_request() and self.experiment_session.user_already_engaged(): self._reset_session() @@ -460,7 +464,7 @@ def _ensure_sessions_exists(self): def _get_latest_session(self): return ( ExperimentSession.objects.filter( - experiment=self.experiment, + experiment=self.experiment.get_working_version(), participant__identifier=str(self.participant_identifier), ) .order_by("-created_at") @@ -477,7 +481,7 @@ def _create_new_experiment_session(self): session """ self.experiment_session = self.start_new_session( - experiment=self.experiment, + working_experiment=self.experiment.get_working_version(), experiment_channel=self.experiment_channel, participant_identifier=self.participant_identifier, participant_user=self.participant_user, @@ -513,7 +517,7 @@ def _inform_user_of_error(self): def _generate_response_for_user(self, prompt: str) -> str: """Generates a response based on the `prompt`.""" - topic_bot = self.bot or get_bot(self.experiment_session) + topic_bot = self.bot or get_bot(self.experiment_session, experiment=self.experiment) return topic_bot.process_input(user_input=prompt, save_input_to_history=False) @@ -534,26 +538,26 @@ def _ensure_sessions_exists(self): @classmethod def start_new_session( cls, - experiment: Experiment, + working_experiment: Experiment, participant_identifier: str, participant_user: CustomUser | None = None, session_status: SessionStatus = SessionStatus.ACTIVE, timezone: str | None = None, ): - experiment_channel = ExperimentChannel.objects.get_team_web_channel(experiment.team) + experiment_channel = ExperimentChannel.objects.get_team_web_channel(working_experiment.team) session = super().start_new_session( - experiment, experiment_channel, participant_identifier, participant_user, session_status, timezone + working_experiment, experiment_channel, participant_identifier, participant_user, session_status, timezone ) - WebChannel.check_and_process_seed_message(session) + WebChannel.check_and_process_seed_message(session, working_experiment.default_version) return session @classmethod - def check_and_process_seed_message(cls, session: ExperimentSession): + def check_and_process_seed_message(cls, session: ExperimentSession, experiment: Experiment): from apps.experiments.tasks import get_response_for_webchat_task - if session.experiment.seed_message: + if seed_message := experiment.seed_message: session.seed_task_id = get_response_for_webchat_task.delay( - session.id, message_text=session.experiment.seed_message, attachments=[] + experiment_session_id=session.id, experiment_id=experiment.id, message_text=seed_message, attachments=[] ).task_id session.save() return session @@ -758,7 +762,7 @@ def _ensure_sessions_exists(self): def _start_experiment_session( - experiment: Experiment, + working_experiment: Experiment, experiment_channel: ExperimentChannel, participant_identifier: str, participant_user: CustomUser | None = None, @@ -766,6 +770,12 @@ def _start_experiment_session( timezone: str | None = None, session_external_id: str | None = None, ) -> ExperimentSession: + if working_experiment.is_versioned: + raise VersionedExperimentSessionsNotAllowedException( + message="A session cannot be linked to an experiment version. " + ) + + team = working_experiment.team if not participant_identifier and not participant_user: raise ValueError("Either participant_identifier or participant_user must be specified!") @@ -775,7 +785,7 @@ def _start_experiment_session( with transaction.atomic(): participant, created = Participant.objects.get_or_create( - team=experiment.team, + team=team, identifier=participant_identifier, platform=experiment_channel.platform, defaults={"user": participant_user}, @@ -785,8 +795,8 @@ def _start_experiment_session( participant.save() session = ExperimentSession.objects.create( - team=experiment.team, - experiment=experiment, + team=team, + experiment=working_experiment, experiment_channel=experiment_channel, status=session_status, participant=participant, @@ -795,7 +805,7 @@ def _start_experiment_session( # Record the participant's timezone if timezone: - participant.update_memory(data={"timezone": timezone}, experiment=experiment) + participant.update_memory(data={"timezone": timezone}, experiment=working_experiment) if participant.experimentsession_set.count() == 1: enqueue_static_triggers.delay(session.id, StaticTriggerType.PARTICIPANT_JOINED_EXPERIMENT) diff --git a/apps/chat/exceptions.py b/apps/chat/exceptions.py index 41d15d7a3..a423b1114 100644 --- a/apps/chat/exceptions.py +++ b/apps/chat/exceptions.py @@ -21,3 +21,7 @@ def __init__(self, message): class ParticipantNotAllowedException(Exception): pass + + +class VersionedExperimentSessionsNotAllowedException(ChatException): + pass diff --git a/apps/chat/models.py b/apps/chat/models.py index 1095c470b..97d52f618 100644 --- a/apps/chat/models.py +++ b/apps/chat/models.py @@ -7,7 +7,7 @@ from django.utils.functional import classproperty from langchain.schema import BaseMessage, messages_from_dict -from apps.annotations.models import TaggedModelMixin, UserCommentsMixin +from apps.annotations.models import Tag, TagCategories, TaggedModelMixin, UserCommentsMixin from apps.files.models import File from apps.teams.models import BaseTeamModel from apps.utils.models import BaseModel @@ -188,6 +188,15 @@ def get_attached_files(self): def get_metadata(self, key: str): return self.metadata.get(key, None) + def add_system_tag(self, tag: str, tag_category: TagCategories): + tag, _ = Tag.objects.get_or_create( + name=tag, + team=self.chat.team, + is_system_tag=True, + category=tag_category, + ) + self.add_tag(tag, team=self.chat.team, added_by=None) + class ChatAttachment(BaseModel): chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name="attachments") diff --git a/apps/experiments/admin.py b/apps/experiments/admin.py index c5ec06501..a04fe9604 100644 --- a/apps/experiments/admin.py +++ b/apps/experiments/admin.py @@ -1,4 +1,6 @@ from django.contrib import admin +from django.db.models.query import QuerySet +from django.http import HttpRequest from apps.experiments import models @@ -67,12 +69,30 @@ class SurveyAdmin(admin.ModelAdmin): @admin.register(models.Experiment) class ExperimentAdmin(admin.ModelAdmin): - list_display = ("name", "team", "owner", "source_material", "llm", "llm_provider") + list_display = ( + "name", + "team", + "owner", + "source_material", + "llm", + "llm_provider", + "version_family", + "version_number", + ) list_filter = ("team", "owner", "source_material") inlines = [SafetyLayerInline] exclude = ["safety_layers"] readonly_fields = ("public_id",) + def get_queryset(self, request: HttpRequest) -> QuerySet: + return models.Experiment.objects.get_all() + + @admin.display(description="Version Family") + def version_family(self, obj): + if obj.working_version: + return obj.working_version.name + return obj.name + @admin.register(models.ExperimentRoute) class ExperimentRouteAdmin(admin.ModelAdmin): diff --git a/apps/experiments/decorators.py b/apps/experiments/decorators.py index e2d387a55..aa7aef31a 100644 --- a/apps/experiments/decorators.py +++ b/apps/experiments/decorators.py @@ -25,11 +25,14 @@ def decorator(view_func): def decorated_view(request, team_slug: str, experiment_id: str, session_id: str): request.experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) request.experiment_session = get_object_or_404( - ExperimentSession, experiment=request.experiment, external_id=session_id, team=request.team + ExperimentSession, + experiment=request.experiment, + external_id=session_id, + team=request.team, ) if allowed_states and request.experiment_session.status not in allowed_states: - return _redirect_for_state(request, request.experiment_session, team_slug) + return _redirect_for_state(request, team_slug) return view_func(request, team_slug, experiment_id, session_id) return decorated_view @@ -98,20 +101,22 @@ def _validate_access_cookie_data(experiment_session, access_data): return _get_access_cookie_data(experiment_session) == access_data -def _redirect_for_state(request, experiment_session, team_slug): - view_args = [team_slug, experiment_session.experiment.public_id, experiment_session.external_id] - if experiment_session.status in [SessionStatus.SETUP, SessionStatus.PENDING]: - return HttpResponseRedirect(reverse("experiments:start_session_from_invite", args=view_args)) - elif experiment_session.status == SessionStatus.PENDING_PRE_SURVEY: - return HttpResponseRedirect(reverse("experiments:experiment_pre_survey", args=view_args)) - elif experiment_session.status == SessionStatus.ACTIVE: - return HttpResponseRedirect(reverse("experiments:experiment_chat", args=view_args)) - elif experiment_session.status == SessionStatus.PENDING_REVIEW: - return HttpResponseRedirect(reverse("experiments:experiment_review", args=view_args)) - elif experiment_session.status == SessionStatus.COMPLETE: - return HttpResponseRedirect(reverse("experiments:experiment_complete", args=view_args)) - else: - messages.info( - request, "Session was in an unknown/unexpected state." " It may be old, or something may have gone wrong." - ) - return HttpResponseRedirect(reverse("experiments:experiment_session_view", args=view_args)) +def _redirect_for_state(request, team_slug): + view_args = [team_slug, request.experiment.public_id, request.experiment_session.external_id] + match request.experiment_session.status: + case SessionStatus.SETUP | SessionStatus.PENDING: + return HttpResponseRedirect(reverse("experiments:start_session_from_invite", args=view_args)) + case SessionStatus.PENDING_PRE_SURVEY: + return HttpResponseRedirect(reverse("experiments:experiment_pre_survey", args=view_args)) + case SessionStatus.ACTIVE: + return HttpResponseRedirect(reverse("experiments:experiment_chat", args=view_args)) + case SessionStatus.PENDING_REVIEW: + return HttpResponseRedirect(reverse("experiments:experiment_review", args=view_args)) + case SessionStatus.COMPLETE: + return HttpResponseRedirect(reverse("experiments:experiment_complete", args=view_args)) + case _: + messages.info( + request, + "Session was in an unknown/unexpected state." " It may be old, or something may have gone wrong.", + ) + return HttpResponseRedirect(reverse("experiments:experiment_session_view", args=view_args)) diff --git a/apps/experiments/email.py b/apps/experiments/email.py index 67e3261ef..644d4cb73 100644 --- a/apps/experiments/email.py +++ b/apps/experiments/email.py @@ -10,11 +10,13 @@ def send_experiment_invitation(experiment_session: ExperimentSession): if not experiment_session.participant: raise Exception("Session has no participant!") + experiment_version_name = experiment_session.experiment_version.name email_context = { "session": experiment_session, + "experiment_name": experiment_version_name, } send_mail( - subject=_("You're invited to {}!").format(experiment_session.experiment.name), + subject=_("You're invited to {}!").format(experiment_version_name), message=render_to_string("experiments/email/invitation.txt", context=email_context), from_email=settings.DEFAULT_FROM_EMAIL, recipient_list=[experiment_session.participant.email], diff --git a/apps/experiments/forms.py b/apps/experiments/forms.py index 5993e73fd..de2f3b068 100644 --- a/apps/experiments/forms.py +++ b/apps/experiments/forms.py @@ -8,7 +8,6 @@ class ConsentForm(forms.Form): identifier = forms.CharField(required=False) consent_agreement = forms.BooleanField(required=True, label="I Agree") - experiment_id = forms.IntegerField(widget=forms.HiddenInput()) participant_id = forms.IntegerField(required=False, widget=forms.HiddenInput()) def __init__(self, consent, *args, **kwargs): diff --git a/apps/experiments/migrations/0094_consentform_working_version_and_more.py b/apps/experiments/migrations/0094_consentform_working_version_and_more.py new file mode 100644 index 000000000..94133ee39 --- /dev/null +++ b/apps/experiments/migrations/0094_consentform_working_version_and_more.py @@ -0,0 +1,29 @@ +# Generated by Django 5.1 on 2024-09-10 06:30 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('experiments', '0093_participant_name'), + ] + + operations = [ + migrations.AddField( + model_name='consentform', + name='working_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='experiments.consentform'), + ), + migrations.AddField( + model_name='experiment', + name='version_description', + field=models.TextField(blank=True, default=''), + ), + migrations.AddField( + model_name='survey', + name='working_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='experiments.survey'), + ), + ] diff --git a/apps/experiments/models.py b/apps/experiments/models.py index d22fa9e3a..75bf4c29d 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -12,13 +12,13 @@ from django.contrib.postgres.fields import ArrayField from django.core.validators import MaxValueValidator, MinValueValidator, validate_email from django.db import models, transaction -from django.db.models import Count, OuterRef, Q, Subquery +from django.db.models import BooleanField, Case, Count, OuterRef, Q, Subquery, When from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext from django_cryptography.fields import encrypt from field_audit import audit_fields -from field_audit.models import AuditingManager +from field_audit.models import AuditAction, AuditingManager from apps.chat.models import Chat, ChatMessage, ChatMessageType from apps.experiments import model_audit_fields @@ -52,17 +52,57 @@ def working_versions_queryset(self): """Returns a queryset for all working experiments""" return self.get_queryset().filter(working_version=None) + def get_queryset(self): + return super().get_queryset().filter(is_archived=False) + + def get_all(self): + """A method to return all experiments whether it is deprecated or not""" + return super().get_queryset() + class SourceMaterialObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class SafetyLayerObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class ConsentFormObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class SyntheticVoiceObjectManager(AuditingManager): @@ -115,6 +155,13 @@ def get_working_version(self) -> "Experiment": return self return self.working_version + def get_working_version_id(self) -> int: + return self.working_version_id if self.working_version_id else self.id + + @property + def has_versions(self): + return self.versions.exists() + @audit_fields(*model_audit_fields.SOURCE_MATERIAL_FIELDS, audit_special_queryset_writes=True) class SourceMaterial(BaseTeamModel, VersionsMixin): @@ -181,7 +228,22 @@ def get_absolute_url(self): return reverse("experiments:safety_edit", args=[self.team.slug, self.id]) -class Survey(BaseTeamModel): +class SurveyObjectManager(models.Manager): + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) + + +class Survey(BaseTeamModel, VersionsMixin): """ A survey. """ @@ -196,6 +258,14 @@ class Survey(BaseTeamModel): " Survey link: {survey_link}" ), ) + working_version = models.ForeignKey( + "self", + on_delete=models.CASCADE, + null=True, + blank=True, + related_name="versions", + ) + objects = SurveyObjectManager() class Meta: ordering = ["name"] @@ -216,7 +286,7 @@ def get_absolute_url(self): @audit_fields(*model_audit_fields.CONSENT_FORM_FIELDS, audit_special_queryset_writes=True) -class ConsentForm(BaseTeamModel): +class ConsentForm(BaseTeamModel, VersionsMixin): """ Custom markdown consent form to be used by experiments. """ @@ -233,6 +303,13 @@ class ConsentForm(BaseTeamModel): default="Respond with '1' if you agree", help_text=("Use this text to tell the user to respond with '1' in order to give their consent"), ) + working_version = models.ForeignKey( + "self", + on_delete=models.CASCADE, + null=True, + blank=True, + related_name="versions", + ) class Meta: ordering = ["name"] @@ -470,6 +547,10 @@ class Experiment(BaseTeamModel, VersionsMixin): version_number = models.PositiveIntegerField(default=1) is_default_version = models.BooleanField(default=False) is_archived = models.BooleanField(default=False) + version_description = models.TextField( + blank=True, + default="", + ) objects = ExperimentObjectManager() class Meta: @@ -512,16 +593,17 @@ def tools_enabled(self): def event_triggers(self): return [*self.timeout_triggers.all(), *self.static_triggers.all()] - @property - def has_versions(self): - return self.versions.count() > 0 - @property def version_display(self) -> str: if self.is_working_version: return "" return f"v{self.version_number}" + @cached_property + def default_version(self) -> "Experiment": + """Returns the default experiment, or if there is none, the working experiment""" + return Experiment.objects.get_default_or_working(self) + def get_chat_model(self): service = self.get_llm_service() return service.get_chat_model(self.llm, self.temperature) @@ -536,7 +618,7 @@ def get_api_url(self): return absolute_url(reverse("api:openai-chat-completions", args=[self.public_id])) @transaction.atomic() - def create_new_version(self): + def create_new_version(self, version_description: str | None = None, make_default: bool = False): """ Creates a copy of an experiment as a new version of the original experiment. """ @@ -547,30 +629,47 @@ def create_new_version(self): # Fetch a new instance so the previous instance reference isn't simply being updated. I am not 100% sure # why simply chaing the pk, id and _state.adding wasn't enough. new_version = super().create_new_version(save=False) + new_version.version_description = version_description or "" new_version.public_id = uuid4() new_version.version_number = version_number - if self.source_material: - new_version.source_material = self.source_material.create_new_version() - if new_version.version_number == 1: + self._copy_attr_to_new_version("source_material", new_version) + self._copy_attr_to_new_version("consent_form", new_version) + self._copy_attr_to_new_version("pre_survey", new_version) + self._copy_attr_to_new_version("post_survey", new_version) + + if new_version.version_number == 1 or make_default: new_version.is_default_version = True + + if make_default: + self.versions.filter(is_default_version=True).update( + is_default_version=False, audit_action=AuditAction.AUDIT + ) + new_version.save() - self.copy_safety_layers_to_new_version(new_version) - self.copy_routes_to_new_version(new_version) - self.copy_static_triggers_to_new_version(new_version) - self.copy_timeout_triggers_to_new_version(new_version) + self._copy_safety_layers_to_new_version(new_version) + self._copy_routes_to_new_version(new_version) + self.copy_trigger_to_new_version(trigger_queryset=self.static_triggers, new_version=new_version) + self.copy_trigger_to_new_version(trigger_queryset=self.timeout_triggers, new_version=new_version) new_version.files.set(self.files.all()) return new_version - def copy_safety_layers_to_new_version(self, new_version: "Experiment"): + def _copy_attr_to_new_version(self, attr_name, new_version: "Experiment"): + """Copies the attribute `attr_name` to the new version by creating a new version of the related record and + linking that to `new_version` + """ + if instance := getattr(self, attr_name): + setattr(new_version, attr_name, instance.create_new_version()) + + def _copy_safety_layers_to_new_version(self, new_version: "Experiment"): duplicated_layers = [] for layer in self.safety_layers.all(): duplicated_layers.append(layer.create_new_version()) new_version.safety_layers.set(duplicated_layers) - def copy_routes_to_new_version(self, new_version: "Experiment"): + def _copy_routes_to_new_version(self, new_version: "Experiment"): """ This copies the experiment routes where this experiment is the parent and sets the new parent to the new version. @@ -578,13 +677,9 @@ def copy_routes_to_new_version(self, new_version: "Experiment"): for route in self.child_links.all(): route.create_new_version(new_version) - def copy_static_triggers_to_new_version(self, new_version: "Experiment"): - for static_trigger in self.static_triggers.all(): - static_trigger.create_new_version(new_experiment=new_version) - - def copy_timeout_triggers_to_new_version(self, new_version: "Experiment"): - for timeout_trigger in self.timeout_triggers.all(): - timeout_trigger.create_new_version(new_experiment=new_version) + def copy_trigger_to_new_version(self, trigger_queryset, new_version): + for trigger in trigger_queryset.all(): + trigger.create_new_version(new_experiment=new_version) @property def is_public(self) -> bool: @@ -616,7 +711,13 @@ class ExperimentRoute(BaseTeamModel, VersionsMixin): @classmethod def eligible_children(cls, team: Team, parent: Experiment | None = None): - """Returns a list of experiments: that are not parents, and are not children of the current experiment""" + """ + Returns a list of experiments that fit the following criteria: + - They are not the same as the parent + - they are not parents + - they are not not children of the current experiment + - they are not part of the current experiment's version family + """ parent_ids = cls.objects.filter(team=team).values_list("parent_id", flat=True).distinct() if parent: @@ -626,6 +727,7 @@ def eligible_children(cls, team: Team, parent: Experiment | None = None): .exclude(id__in=child_ids) .exclude(id__in=parent_ids) .exclude(id=parent.id) + .exclude(working_version_id=parent.id) ) else: eligible_experiments = Experiment.objects.filter(team=team).exclude(id__in=parent_ids) @@ -932,11 +1034,11 @@ def user_already_engaged(self) -> bool: def get_platform_name(self) -> str: return self.experiment_channel.get_platform_display() - def get_pre_survey_link(self): - return self.experiment.pre_survey.get_link(self.participant, self) + def get_pre_survey_link(self, experiment_version: Experiment): + return experiment_version.pre_survey.get_link(self.participant, self) - def get_post_survey_link(self): - return self.experiment.post_survey.get_link(self.participant, self) + def get_post_survey_link(self, experiment_version: Experiment): + return experiment_version.post_survey.get_link(self.participant, self) def is_stale(self) -> bool: """A Channel Session is considered stale if the experiment that the channel points to differs from the @@ -1026,6 +1128,16 @@ def participant_data_from_experiment(self) -> dict: except ParticipantData.DoesNotExist: return {} + @cached_property + def experiment_version(self) -> Experiment: + """Returns the default experiment, or if there is none, the working experiment""" + return self.experiment.default_version + + @cached_property + def working_experiment(self) -> Experiment: + """Returns the default experiment, or if there is none, the working experiment""" + return self.experiment.get_working_version() + def get_participant_timezone(self): participant_data = self.participant_data_from_experiment return participant_data.get("timezone") diff --git a/apps/experiments/tables.py b/apps/experiments/tables.py index 58d6eb826..fbcd5b2e3 100644 --- a/apps/experiments/tables.py +++ b/apps/experiments/tables.py @@ -161,6 +161,7 @@ class Meta: class ExperimentVersionsTable(tables.Table): version_number = columns.Column(verbose_name="Version Number", accessor="version_number") created_at = columns.Column(verbose_name="Created On", accessor="created_at") + version_description = columns.Column(verbose_name="Description", default="") is_default = columns.TemplateColumn( template_code="""{% if record.is_default_version %} diff --git a/apps/experiments/tasks.py b/apps/experiments/tasks.py index 0b092db05..a9e8924f5 100644 --- a/apps/experiments/tasks.py +++ b/apps/experiments/tasks.py @@ -8,7 +8,7 @@ from apps.channels.datamodels import Attachment, BaseMessage from apps.chat.bots import create_conversation from apps.chat.channels import WebChannel -from apps.experiments.models import ExperimentSession, PromptBuilderHistory, SourceMaterial +from apps.experiments.models import Experiment, ExperimentSession, PromptBuilderHistory, SourceMaterial from apps.service_providers.models import LlmProvider from apps.teams.utils import current_team from apps.users.models import CustomUser @@ -17,13 +17,14 @@ @shared_task(bind=True, base=TaskbadgerTask) def get_response_for_webchat_task( - self, experiment_session_id: int, message_text: str, attachments: list | None = None + self, experiment_session_id: int, experiment_id: int, message_text: str, attachments: list | None = None ) -> str: experiment_session = ExperimentSession.objects.select_related("experiment", "experiment__team").get( id=experiment_session_id ) + experiment = Experiment.objects.get(id=experiment_id) web_channel = WebChannel( - experiment_session.experiment, + experiment, experiment_session.experiment_channel, experiment_session=experiment_session, ) diff --git a/apps/experiments/tests/test_consent_views.py b/apps/experiments/tests/test_consent_views.py new file mode 100644 index 000000000..795d4867b --- /dev/null +++ b/apps/experiments/tests/test_consent_views.py @@ -0,0 +1,17 @@ +from django.test import RequestFactory +from django.urls import reverse + +from apps.experiments.views.consent import ConsentFormTableView + + +class TestConsentFormTableView: + def test_get_queryset(self, experiment): + assert experiment.consent_form is not None + experiment.create_new_version() + + request = RequestFactory().get(reverse("experiments:consent_table", args=[experiment.team.slug])) + request.team = experiment.team + view = ConsentFormTableView() + view.request = request + for consent_form in view.get_queryset().all(): + assert consent_form.is_working_version is True diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index e419546df..69a3f2584 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -20,6 +20,7 @@ ExperimentSessionFactory, ParticipantFactory, SourceMaterialFactory, + SurveyFactory, SyntheticVoiceFactory, VersionedExperimentFactory, ) @@ -498,7 +499,26 @@ def test_create_new_route_version(self, versioned): @pytest.mark.django_db() -class TestExperimentVersioning: +class TestExperimentRoute: + def test_eligible_children(self): + parent = ExperimentFactory() + experiment_version = parent.create_new_version() + experiment1 = ExperimentFactory(team=parent.team) + experiment2 = ExperimentFactory(team=parent.team) + + queryset = ExperimentRoute.eligible_children(team=parent.team, parent=parent) + assert parent not in queryset + assert experiment_version not in queryset + assert experiment1 in queryset + assert experiment2 in queryset + assert len(queryset) == 2 + + queryset = ExperimentRoute.eligible_children(team=parent.team) + assert len(queryset) == 4 + + +@pytest.mark.django_db() +class TestExperimentModel: def test_working_experiment_cannot_be_the_default_version(self): with pytest.raises(ValueError, match="A working experiment cannot be a default version"): ExperimentFactory(is_default_version=True, working_version=None) @@ -549,9 +569,16 @@ def _setup_original_experiment(self): # Setup Timeout Trigger TimeoutTriggerFactory(experiment=experiment) + + # Surveys + pre_survey = SurveyFactory(team=team) + post_survey = SurveyFactory(team=team) + experiment.pre_survey = pre_survey + experiment.post_survey = post_survey + + experiment.save() return experiment - @pytest.mark.django_db() def test_first_version_is_automatically_the_default(self): experiment = ExperimentFactory() new_version = experiment.create_new_version() @@ -562,13 +589,12 @@ def test_first_version_is_automatically_the_default(self): assert another_version.version_number == 2 assert not another_version.is_default_version - @pytest.mark.django_db() def test_create_experiment_version(self): original_experiment = self._setup_original_experiment() assert original_experiment.version_number == 1 - new_version = original_experiment.create_new_version() + new_version = original_experiment.create_new_version("tis a new version") original_experiment.refresh_from_db() assert new_version != original_experiment @@ -577,6 +603,7 @@ def test_create_experiment_version(self): assert new_version.version_number == 1 assert new_version.is_default_version is True assert new_version.working_version == original_experiment + assert new_version.version_description == "tis a new version" _compare_models( original=original_experiment, new=new_version, @@ -587,13 +614,20 @@ def test_create_experiment_version(self): "working_version_id", "version_number", "is_default_version", + "consent_form_id", + "pre_survey_id", + "post_survey_id", + "version_description", ], ) self._assert_safety_layers_are_duplicated(original_experiment, new_version) - self._assert_source_material_is_duplicated(original_experiment, new_version) self._assert_files_are_duplicated(original_experiment, new_version) self._assert_triggers_are_duplicated("static", original_experiment, new_version) self._assert_triggers_are_duplicated("timeout", original_experiment, new_version) + self._assert_attribute_duplicated("source_material", original_experiment, new_version) + self._assert_attribute_duplicated("consent_form", original_experiment, new_version) + self._assert_attribute_duplicated("pre_survey", original_experiment, new_version) + self._assert_attribute_duplicated("post_survey", original_experiment, new_version) another_new_version = original_experiment.create_new_version() original_experiment.refresh_from_db() @@ -641,6 +675,13 @@ def _assert_triggers_are_duplicated(self, trigger_type, original_experiment, new expected_changed_fields=["id", "action_id", "working_version_id", "experiment_id"], ) + def _assert_attribute_duplicated(self, attr_name, original_experiment, new_version): + _compare_models( + original=getattr(original_experiment, attr_name), + new=getattr(new_version, attr_name), + expected_changed_fields=["id", "working_version_id"], + ) + @pytest.mark.django_db() class TestExperimentObjectManager: @@ -669,6 +710,18 @@ def test_working_versions_queryset(self): # All experiments in this queryset should have versions assert working_version.has_versions is True + def test_archived_experiments_are_filtered_out(self): + """Default queries should exclude archived experiments""" + experiment = ExperimentFactory() + new_version = experiment.create_new_version() + assert Experiment.objects.count() == 2 + new_version.is_archived = True + new_version.save() + assert Experiment.objects.count() == 1 + + # To get all experiment,s use the dedicated object method + assert Experiment.objects.get_all().count() == 2 + def _compare_models(original, new, expected_changed_fields: list) -> set: """ @@ -688,4 +741,7 @@ def _compare_models(original, new, expected_changed_fields: list) -> set: if field_value != new_dict[field_name]: changed_fields.add(field_name) - assert changed_fields.difference(set(expected_changed_fields)) == set() + field_difference = changed_fields.difference(set(expected_changed_fields)) + assert ( + field_difference == set() + ), f"These fields differ between the experiment versions, but should not: {field_difference}" diff --git a/apps/experiments/tests/test_source_material_views.py b/apps/experiments/tests/test_source_material_views.py new file mode 100644 index 000000000..d373a30ce --- /dev/null +++ b/apps/experiments/tests/test_source_material_views.py @@ -0,0 +1,21 @@ +from django.test import RequestFactory +from django.urls import reverse + +from apps.experiments.models import SourceMaterial +from apps.experiments.views.source_material import SourceMaterialTableView + + +class TestSourceMaterialTableView: + def test_get_queryset(self, experiment): + experiment.source_material = SourceMaterial.objects.create( + team=experiment.team, owner=experiment.owner, topic="Testing", description="descripto", material="Meh" + ) + experiment.save() + experiment.create_new_version() + assert SourceMaterial.objects.count() == 2 + + request = RequestFactory().get(reverse("experiments:source_material_table", args=[experiment.team.slug])) + request.team = experiment.team + view = SourceMaterialTableView() + view.request = request + assert list(view.get_queryset().all()) == [experiment.source_material] diff --git a/apps/experiments/tests/test_survey_views.py b/apps/experiments/tests/test_survey_views.py new file mode 100644 index 000000000..3a41dcaa7 --- /dev/null +++ b/apps/experiments/tests/test_survey_views.py @@ -0,0 +1,18 @@ +from django.test import RequestFactory +from django.urls import reverse + +from apps.experiments.models import Survey +from apps.experiments.views.survey import SurveyTableView + + +class TestSurveyTableView: + def test_get_queryset(self, experiment): + assert experiment.pre_survey is not None + experiment.create_new_version() + assert Survey.objects.count() == 2 + + request = RequestFactory().get(reverse("experiments:survey_table", args=[experiment.team.slug])) + request.team = experiment.team + view = SurveyTableView() + view.request = request + assert list(view.get_queryset().all()) == [experiment.pre_survey] diff --git a/apps/experiments/tests/test_views.py b/apps/experiments/tests/test_views.py index 3fbca9275..e8a3b9c59 100644 --- a/apps/experiments/tests/test_views.py +++ b/apps/experiments/tests/test_views.py @@ -5,6 +5,7 @@ import pytest from django.conf import settings from django.core.exceptions import ValidationError +from django.test import RequestFactory from django.urls import reverse from waffle.testutils import override_flag @@ -17,7 +18,7 @@ ParticipantData, VoiceResponseBehaviours, ) -from apps.experiments.views.experiment import ExperimentForm, validate_prompt_variables +from apps.experiments.views.experiment import ExperimentForm, ExperimentTableView, validate_prompt_variables from apps.teams.backends import add_user_to_team from apps.utils.factories.assistants import OpenAiAssistantFactory from apps.utils.factories.experiment import ( @@ -377,3 +378,21 @@ def test_experiment_session_message_view_creates_files(delay_mock, experiment, c assert ci_resource.files.filter(name="ci.text").exists() fs_resource = session.chat.attachments.get(tool_type="file_search") assert fs_resource.files.filter(name="fs.text").exists() + + +class TestExperimentTableView: + def test_get_queryset(self, experiment): + team = experiment.team + experiment.create_new_version() + archived_working = ExperimentFactory(team=team) + archived_version = archived_working.create_new_version() + archived_version.is_archived = archived_working.is_archived = True + archived_version.save() + archived_working.save() + assert Experiment.objects.get_all().count() == 4 + + request = RequestFactory().get(reverse("experiments:table", args=[team.slug])) + request.team = team + view = ExperimentTableView() + view.request = request + assert list(view.get_queryset().all()) == [experiment] diff --git a/apps/experiments/views/consent.py b/apps/experiments/views/consent.py index 5b3981824..afb9d376c 100644 --- a/apps/experiments/views/consent.py +++ b/apps/experiments/views/consent.py @@ -32,7 +32,7 @@ class ConsentFormTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return ConsentForm.objects.filter(team=self.request.team) + return ConsentForm.objects.filter(team=self.request.team, is_version=False) class CreateConsentForm(CreateView): diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index c03073ade..f402c5640 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -104,7 +104,7 @@ class ExperimentTableView(SingleTableView, PermissionRequiredMixin): permission_required = "experiments.view_experiment" def get_queryset(self): - query_set = Experiment.objects.filter(team=self.request.team, working_version__isnull=True) + query_set = Experiment.objects.filter(team=self.request.team, working_version__isnull=True, is_archived=False) search = self.request.GET.get("search") if search: search_vector = SearchVector("name", weight="A") + SearchVector("description", weight="B") @@ -240,11 +240,11 @@ def __init__(self, request, *args, **kwargs): self.fields["voice_provider"].queryset = team.voiceprovider_set.exclude( syntheticvoice__service__in=exclude_services ) - self.fields["safety_layers"].queryset = team.safetylayer_set - self.fields["source_material"].queryset = team.sourcematerial_set - self.fields["pre_survey"].queryset = team.survey_set - self.fields["post_survey"].queryset = team.survey_set - self.fields["consent_form"].queryset = team.consentform_set + self.fields["safety_layers"].queryset = team.safetylayer_set.exclude(is_version=True) + self.fields["source_material"].queryset = team.sourcematerial_set.exclude(is_version=True) + self.fields["pre_survey"].queryset = team.survey_set.exclude(is_version=True) + self.fields["post_survey"].queryset = team.survey_set.exclude(is_version=True) + self.fields["consent_form"].queryset = team.consentform_set.exclude(is_version=True) self.fields["synthetic_voice"].queryset = SyntheticVoice.get_for_team(team, exclude_services) self.fields["trace_provider"].queryset = team.traceprovider_set @@ -479,10 +479,8 @@ class DeleteFileFromExperiment(BaseDeleteFileView): class ExperimentVersionForm(forms.ModelForm): class Meta: model = Experiment - fields = ["is_default_version"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + fields = ["version_description", "is_default_version"] + help_texts = {"version_description": "A description of this version, or what changed from the previous version"} class CreateExperimentVersion(LoginAndTeamRequiredMixin, CreateView): @@ -496,7 +494,9 @@ class CreateExperimentVersion(LoginAndTeamRequiredMixin, CreateView): def form_valid(self, form): working_experiment = self.get_object() - working_experiment.create_new_version() + description = form.cleaned_data["version_description"] + is_default = form.cleaned_data["is_default_version"] + working_experiment.create_new_version(version_description=description, make_default=is_default) return HttpResponseRedirect(self.get_success_url()) def get_success_url(self): @@ -693,10 +693,11 @@ def update_delete_channel(request, team_slug: str, experiment_id: int, channel_i @require_POST @login_and_team_required def start_authed_web_session(request, team_slug: str, experiment_id: int): + """Start an authed web session with the chosen experiment, be it a specific version or not""" experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) session = WebChannel.start_new_session( - experiment, + working_experiment=experiment, participant_user=request.user, participant_identifier=request.user.email, timezone=request.session.get("detected_tz", None), @@ -712,14 +713,15 @@ def experiment_chat_session(request, team_slug: str, experiment_id: int, session session = get_object_or_404( ExperimentSession, participant__user=request.user, experiment_id=experiment_id, id=session_id ) + experiment_version = experiment.default_version + version_specific_vars = { + "assistant": experiment_version.assistant, + "experiment_name": experiment_version.name, + } return TemplateResponse( request, "experiments/experiment_chat.html", - { - "experiment": experiment, - "session": session, - "active_tab": "experiments", - }, + {"experiment": experiment, "session": session, "active_tab": "experiments", **version_specific_vars}, ) @@ -749,7 +751,17 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess tool_resource.files.add(*created_files) - result = get_response_for_webchat_task.delay(session.id, message_text, attachments=attachments) + experiment_version = experiment.default_version + result = get_response_for_webchat_task.delay( + experiment_session_id=session.id, + experiment_id=experiment_version.id, + message_text=message_text, + attachments=attachments, + ) + + version_specific_vars = { + "assistant": experiment_version.assistant, + } return TemplateResponse( request, "experiments/chat/experiment_response_htmx.html", @@ -759,6 +771,7 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess "message_text": message_text, "task_id": result.task_id, "created_files": created_files, + **version_specific_vars, }, ) @@ -827,10 +840,11 @@ def start_session_public(request, team_slug: str, experiment_id: str): # old links dont have uuids raise Http404 - if not experiment.is_public: + experiment_version = experiment.default_version + if not experiment_version.is_public: raise Http404 - consent = experiment.consent_form + consent = experiment_version.consent_form user = get_real_user_or_none(request.user) if request.method == "POST": form = ConsentForm(consent, request.POST, initial={"identifier": user.email if user else None}) @@ -842,7 +856,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): identifier = user.email if user else str(uuid.uuid4()) session = WebChannel.start_new_session( - experiment, + working_experiment=experiment, participant_user=user, participant_identifier=identifier, timezone=request.session.get("detected_tz", None), @@ -853,12 +867,16 @@ def start_session_public(request, team_slug: str, experiment_id: str): form = ConsentForm( consent, initial={ - "experiment_id": experiment.id, + "experiment_id": experiment_version.id, "identifier": user.email if user else None, }, ) consent_notice = consent.get_rendered_content() + version_specific_vars = { + "experiment_name": experiment_version.name, + "experiment_description": experiment_version.description, + } return TemplateResponse( request, "experiments/start_experiment_session.html", @@ -867,6 +885,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): "experiment": experiment, "consent_notice": mark_safe(consent_notice), "form": form, + **version_specific_vars, }, ) @@ -875,17 +894,18 @@ def start_session_public(request, team_slug: str, experiment_id: str): @permission_required("experiments.invite_participants", raise_exception=True) def experiment_invitations(request, team_slug: str, experiment_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) + experiment_version = experiment.default_version sessions = experiment.sessions.order_by("-created_at").filter( status__in=["setup", "pending"], participant__isnull=False, ) - form = ExperimentInvitationForm(initial={"experiment_id": experiment.id}) + form = ExperimentInvitationForm(initial={"experiment_id": experiment_id}) if request.method == "POST": post_form = ExperimentInvitationForm(request.POST) if post_form.is_valid(): if ExperimentSession.objects.filter( team=request.team, - experiment=experiment, + experiment_id=experiment_id, status__in=["setup", "pending"], participant__identifier=post_form.cleaned_data["email"], ).exists(): @@ -894,7 +914,7 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: with transaction.atomic(): session = WebChannel.start_new_session( - experiment=experiment, + experiment, participant_identifier=post_form.cleaned_data["email"], session_status=SessionStatus.SETUP, timezone=request.session.get("detected_tz", None), @@ -904,14 +924,14 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: form = post_form + version_specific_vars = { + "experiment_name": experiment_version.name, + "experiment_description": experiment_version.description, + } return TemplateResponse( request, "experiments/experiment_invitations.html", - { - "invitation_form": form, - "experiment": experiment, - "sessions": sessions, - }, + {"invitation_form": form, "experiment": experiment, "sessions": sessions, **version_specific_vars}, ) @@ -947,16 +967,16 @@ def send_invitation(request, team_slug: str, experiment_id: str, session_id: str def _record_consent_and_redirect(request, team_slug: str, experiment_session: ExperimentSession): # record consent, update status experiment_session.consent_date = timezone.now() - if experiment_session.experiment.pre_survey: + if experiment_session.experiment_version.pre_survey: experiment_session.status = SessionStatus.PENDING_PRE_SURVEY - redirct_url_name = "experiments:experiment_pre_survey" + redirect_url_name = "experiments:experiment_pre_survey" else: experiment_session.status = SessionStatus.ACTIVE - redirct_url_name = "experiments:experiment_chat" + redirect_url_name = "experiments:experiment_chat" experiment_session.save() response = HttpResponseRedirect( reverse( - redirct_url_name, + redirect_url_name, args=[team_slug, experiment_session.experiment.public_id, experiment_session.external_id], ) ) @@ -967,35 +987,38 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex def start_session_from_invite(request, team_slug: str, experiment_id: str, session_id: str): experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) experiment_session = get_object_or_404(ExperimentSession, experiment=experiment, external_id=session_id) - consent = experiment.consent_form + default_version = experiment.default_version + consent = default_version.consent_form initial = { - "experiment_id": experiment.id, + "participant_id": experiment_session.participant.id, + "identifier": experiment_session.participant.identifier, } if not experiment_session.participant: raise Http404() - initial["participant_id"] = experiment_session.participant.id - initial["identifier"] = experiment_session.participant.identifier - if request.method == "POST": form = ConsentForm(consent, request.POST, initial=initial) if form.is_valid(): - WebChannel.check_and_process_seed_message(experiment_session) return _record_consent_and_redirect(request, team_slug, experiment_session) else: form = ConsentForm(consent, initial=initial) consent_notice = consent.get_rendered_content() + version_specific_vars = { + "experiment_name": default_version.name, + "experiment_description": default_version.description, + } return TemplateResponse( request, "experiments/start_experiment_session.html", { "active_tab": "experiments", - "experiment": experiment, + "experiment": default_version, "consent_notice": mark_safe(consent_notice), "form": form, + **version_specific_vars, }, ) @@ -1016,6 +1039,14 @@ def experiment_pre_survey(request, team_slug: str, experiment_id: str, session_i ) else: form = SurveyCompletedForm() + + default_version = request.experiment.default_version + experiment_session = request.experiment_session + version_specific_vars = { + "experiment_name": default_version.name, + "experiment_description": default_version.description, + "pre_survey_link": experiment_session.get_pre_survey_link(default_version), + } return TemplateResponse( request, "experiments/pre_survey.html", @@ -1023,7 +1054,8 @@ def experiment_pre_survey(request, team_slug: str, experiment_id: str, session_i "active_tab": "experiments", "form": form, "experiment": request.experiment, - "experiment_session": request.experiment_session, + "experiment_session": experiment_session, + **version_specific_vars, }, ) @@ -1058,6 +1090,7 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s form = None survey_link = None survey_text = None + experiment_version = request.experiment.default_version if request.method == "POST": # no validation needed request.experiment_session.status = SessionStatus.COMPLETE @@ -1066,11 +1099,17 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s return HttpResponseRedirect( reverse("experiments:experiment_complete", args=[team_slug, experiment_id, session_id]) ) - elif request.experiment.post_survey: + elif experiment_version.post_survey: form = SurveyCompletedForm() - survey_link = request.experiment_session.get_post_survey_link() - survey_text = request.experiment.post_survey.confirmation_text.format(survey_link=survey_link) - + survey_link = request.experiment_session.get_post_survey_link(experiment_version) + survey_text = experiment_version.post_survey.confirmation_text.format(survey_link=survey_link) + + version_specific_vars = { + "experiment.post_survey": experiment_version.post_survey, + "survey_link": survey_link, + "survey_text": survey_text, + "experiment_name": experiment_version.name, + } return TemplateResponse( request, "experiments/experiment_review.html", @@ -1078,10 +1117,9 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s "experiment": request.experiment, "experiment_session": request.experiment_session, "active_tab": "experiments", - "survey_link": survey_link, - "survey_text": survey_text, "form": form, "available_tags": [t.name for t in Tag.objects.filter(team__slug=team_slug, is_system_tag=False).all()], + **version_specific_vars, }, ) diff --git a/apps/experiments/views/safety.py b/apps/experiments/views/safety.py index 9abcca6da..ce4492b2d 100644 --- a/apps/experiments/views/safety.py +++ b/apps/experiments/views/safety.py @@ -30,7 +30,7 @@ class SafetyLayerTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return SafetyLayer.objects.filter(team=self.request.team) + return SafetyLayer.objects.filter(team=self.request.team, is_version=False) class CreateSafetyLayer(CreateView): diff --git a/apps/experiments/views/source_material.py b/apps/experiments/views/source_material.py index a550730d6..7e0c92825 100644 --- a/apps/experiments/views/source_material.py +++ b/apps/experiments/views/source_material.py @@ -33,7 +33,7 @@ class SourceMaterialTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - query_set = SourceMaterial.objects.filter(team=self.request.team) + query_set = SourceMaterial.objects.filter(team=self.request.team, is_version=False) search = self.request.GET.get("search") if search: search_vector = SearchVector("topic", weight="A") + SearchVector("description", weight="B") diff --git a/apps/experiments/views/survey.py b/apps/experiments/views/survey.py index 85f18e2a7..ea9aa589e 100644 --- a/apps/experiments/views/survey.py +++ b/apps/experiments/views/survey.py @@ -33,7 +33,7 @@ class SurveyTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return Survey.objects.filter(team=self.request.team) + return Survey.objects.filter(team=self.request.team, is_version=False) class CreateSurvey(CreateView): diff --git a/apps/service_providers/llm_service/state.py b/apps/service_providers/llm_service/state.py index 6ab97c16c..38a4c5ed4 100644 --- a/apps/service_providers/llm_service/state.py +++ b/apps/service_providers/llm_service/state.py @@ -147,6 +147,11 @@ def save_message_to_history(self, message: str, type_: ChatMessageType, experime ) chat_message.add_tag(tag, team=self.session.team, added_by=None) + if type_ == ChatMessageType.AI and not self.experiment.is_working_version: + chat_message.add_system_tag( + tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION + ) + def check_cancellation(self): self.session.chat.refresh_from_db(fields=["metadata"]) # temporary mechanism to cancel the chat @@ -262,6 +267,12 @@ def save_message_to_history( category=TagCategories.BOT_RESPONSE, ) chat_message.add_tag(tag, team=self.session.team, added_by=None) + + if type_ == ChatMessageType.AI and not self.experiment.is_working_version: + chat_message.add_system_tag( + tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION + ) + return chat_message def get_tools(self): diff --git a/apps/service_providers/tests/test_runnables.py b/apps/service_providers/tests/test_runnables.py index 51d3cf27e..0692cee86 100644 --- a/apps/service_providers/tests/test_runnables.py +++ b/apps/service_providers/tests/test_runnables.py @@ -6,6 +6,7 @@ import pytest from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from apps.annotations.models import TagCategories from apps.chat.models import Chat, ChatMessage, ChatMessageType from apps.experiments.models import AgentTools, SourceMaterial from apps.service_providers.llm_service.runnables import ( @@ -78,6 +79,19 @@ def test_runnable(runnable, session, fake_llm_service): assert "tools" not in fake_llm_service.llm.get_calls()[0].kwargs +@pytest.mark.django_db() +@freezegun.freeze_time("2024-02-08 13:00:08.877096+00:00") +def test_bot_message_is_tagged_with_experiment_version(runnable, session, fake_llm_service): + experiment_version = session.experiment.create_new_version() + experiment_version.get_llm_service = lambda: fake_llm_service + chain = runnable.build(state=ChatExperimentState(experiment_version, session)) + chain.invoke("hi") + ai_message = session.chat.messages.get(message_type=ChatMessageType.AI) + tag = ai_message.tags.first() + assert tag.name == "v1" + assert tag.category == TagCategories.EXPERIMENT_VERSION + + @pytest.mark.django_db() @freezegun.freeze_time("2024-02-08 13:00:08.877096+00:00") def test_runnable_with_source_material(runnable, session, fake_llm_service): diff --git a/apps/slack/slack_listeners.py b/apps/slack/slack_listeners.py index 649e04350..ab409f78f 100644 --- a/apps/slack/slack_listeners.py +++ b/apps/slack/slack_listeners.py @@ -52,11 +52,13 @@ def respond_to_message(event, context: BoltContext, session=None): channel_id = event.get("channel") thread_ts = event.get("thread_ts", None) or event["ts"] experiment_channel = get_experiment_channel(channel_id) + if not experiment_channel: context.say("There are no bots associated with this channel.", thread_ts=thread_ts) return - if session and session.team_id != experiment_channel.experiment.team_id: + experiment = experiment_channel.experiment + if session and session.team_id != experiment.team_id: raise TeamAccessException("Session and Channel teams do not match") slack_user = event.get("user") @@ -64,7 +66,10 @@ def respond_to_message(event, context: BoltContext, session=None): if not session: external_id = make_session_external_id(channel_id, thread_ts) session = SlackChannel.start_new_session( - experiment_channel.experiment, experiment_channel, slack_user, session_external_id=external_id + working_experiment=experiment, + experiment_channel=experiment_channel, + participant_identifier=slack_user, + session_external_id=external_id, ) # strip out the mention @@ -75,7 +80,7 @@ def respond_to_message(event, context: BoltContext, session=None): # Set `send_response_to_user` to `False` to prevent it sending the message since we're going to send # it here using the already authenticated client. - ocs_channel = SlackChannel(experiment_channel.experiment, experiment_channel, session, send_response_to_user=False) + ocs_channel = SlackChannel(experiment.default_version, experiment_channel, session, send_response_to_user=False) response = ocs_channel.new_user_message(message) context.say(response, thread_ts=thread_ts) diff --git a/templates/experiments/chat/chat_ui.html b/templates/experiments/chat/chat_ui.html index be5d6b734..5042e7245 100644 --- a/templates/experiments/chat/chat_ui.html +++ b/templates/experiments/chat/chat_ui.html @@ -8,7 +8,7 @@
{% include "experiments/chat/components/system_icon.html" %}
-

Hello, you can ask me anything you want about {{ experiment.name }}.

+

Hello, you can ask me anything you want about {{ experiment_name }}.

{% endif %} diff --git a/templates/experiments/chat/input_bar.html b/templates/experiments/chat/input_bar.html index e87be5069..1e54c91d0 100644 --- a/templates/experiments/chat/input_bar.html +++ b/templates/experiments/chat/input_bar.html @@ -13,12 +13,12 @@ hx-post="{% url 'experiments:experiment_session_message' request.team.slug experiment.id session.id %}" hx-swap="outerHTML" hx-indicator="#message-submit" - {% if session.experiment.assistant %}enctype="multipart/form-data"{% endif %} + {% if assistant %}enctype="multipart/form-data"{% endif %} > {% csrf_token %}
- {% if session.experiment.assistant %} + {% if assistant %}
-{% if session.experiment.assistant %} +{% if assistant %}