From a8385bed56dc4b1bb05223cc0b547a6d22084604 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 12:41:44 +0200 Subject: [PATCH 01/46] Add two convenience methods on ExperimentSession --- apps/experiments/models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index c7d9827d5..d050b6975 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -1015,6 +1015,16 @@ def participant_data_from_experiment(self) -> dict: except ParticipantData.DoesNotExist: return {} + @cached_property + def experiment_version(self) -> Experiment: + """Returns the default experiment, or if there is none, the working experiment""" + return Experiment.objects.get_default_or_working(self.experiment) + + @cached_property + def working_experiment(self) -> Experiment: + """Returns the default experiment, or if there is none, the working experiment""" + return self.experiment.get_working_version() + def get_participant_timezone(self): participant_data = self.participant_data_from_experiment return participant_data.get("timezone") From 10e95a9a0ae0605c6c6f4d259012544040eab2ec Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 12:42:54 +0200 Subject: [PATCH 02/46] Update invites to use details from the experiment version. Also fix issue where seed message is generated twice --- apps/chat/channels.py | 4 ++-- apps/experiments/email.py | 2 +- apps/experiments/views/experiment.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/apps/chat/channels.py b/apps/chat/channels.py index 989d1da15..992aacac8 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -551,9 +551,9 @@ def start_new_session( def check_and_process_seed_message(cls, session: ExperimentSession): from apps.experiments.tasks import get_response_for_webchat_task - if session.experiment.seed_message: + if seed_message := session.experiment_version.seed_message: session.seed_task_id = get_response_for_webchat_task.delay( - session.id, message_text=session.experiment.seed_message, attachments=[] + session.id, message_text=seed_message, attachments=[] ).task_id session.save() return session diff --git a/apps/experiments/email.py b/apps/experiments/email.py index 67e3261ef..8322e7cb6 100644 --- a/apps/experiments/email.py +++ b/apps/experiments/email.py @@ -14,7 +14,7 @@ def send_experiment_invitation(experiment_session: ExperimentSession): "session": experiment_session, } send_mail( - subject=_("You're invited to {}!").format(experiment_session.experiment.name), + subject=_("You're invited to {}!").format(experiment_session.experiment_version.name), message=render_to_string("experiments/email/invitation.txt", context=email_context), from_email=settings.DEFAULT_FROM_EMAIL, recipient_list=[experiment_session.participant.email], diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 56c3fc301..3f6c71a17 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -943,16 +943,16 @@ def send_invitation(request, team_slug: str, experiment_id: str, session_id: str def _record_consent_and_redirect(request, team_slug: str, experiment_session: ExperimentSession): # record consent, update status experiment_session.consent_date = timezone.now() - if experiment_session.experiment.pre_survey: + if experiment_session.experiment_version.pre_survey: experiment_session.status = SessionStatus.PENDING_PRE_SURVEY - redirct_url_name = "experiments:experiment_pre_survey" + redirect_url_name = "experiments:experiment_pre_survey" else: experiment_session.status = SessionStatus.ACTIVE - redirct_url_name = "experiments:experiment_chat" + redirect_url_name = "experiments:experiment_chat" experiment_session.save() response = HttpResponseRedirect( reverse( - redirct_url_name, + redirect_url_name, args=[team_slug, experiment_session.experiment.public_id, experiment_session.external_id], ) ) @@ -963,10 +963,11 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex def start_session_from_invite(request, team_slug: str, experiment_id: str, session_id: str): experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) experiment_session = get_object_or_404(ExperimentSession, experiment=experiment, external_id=session_id) - consent = experiment.consent_form + experiment_version = experiment_session.experiment_version + consent = experiment_version.consent_form initial = { - "experiment_id": experiment.id, + "experiment_id": experiment_version.id, } if not experiment_session.participant: raise Http404() @@ -977,7 +978,6 @@ def start_session_from_invite(request, team_slug: str, experiment_id: str, sessi if request.method == "POST": form = ConsentForm(consent, request.POST, initial=initial) if form.is_valid(): - WebChannel.check_and_process_seed_message(experiment_session) return _record_consent_and_redirect(request, team_slug, experiment_session) else: From f3216e579734aec55518df93b53549249fdf7605 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 12:43:17 +0200 Subject: [PATCH 03/46] get_response_for_webchat_task to use the experiment version --- apps/experiments/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/experiments/tasks.py b/apps/experiments/tasks.py index 0b092db05..9cc2124c6 100644 --- a/apps/experiments/tasks.py +++ b/apps/experiments/tasks.py @@ -23,7 +23,7 @@ def get_response_for_webchat_task( id=experiment_session_id ) web_channel = WebChannel( - experiment_session.experiment, + experiment_session.experiment_version, experiment_session.experiment_channel, experiment_session=experiment_session, ) From 1063cf4453b1195fee97a085e1681de0f0d45cff Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 12:47:17 +0200 Subject: [PATCH 04/46] Exclude an experiment's versions from eligible_children in experiment routes --- apps/experiments/models.py | 9 ++++++++- apps/experiments/tests/test_models.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index d050b6975..4aa5e2dc8 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -616,7 +616,13 @@ class ExperimentRoute(BaseTeamModel, VersionsMixin): @classmethod def eligible_children(cls, team: Team, parent: Experiment | None = None): - """Returns a list of experiments: that are not parents, and are not children of the current experiment""" + """ + Returns a list of experiments that fit the following criteria: + - They are not the same as the parent + - they are not parents + - they are ot not children of the current experiment + - they are not part of the current experiment's version family + """ parent_ids = cls.objects.filter(team=team).values_list("parent_id", flat=True).distinct() if parent: @@ -626,6 +632,7 @@ def eligible_children(cls, team: Team, parent: Experiment | None = None): .exclude(id__in=child_ids) .exclude(id__in=parent_ids) .exclude(id=parent.id) + .exclude(id__in=parent.versions.all()) ) else: eligible_experiments = Experiment.objects.filter(team=team).exclude(id__in=parent_ids) diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index e81eef31b..fdf5d43f9 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -343,6 +343,25 @@ def test_create_new_route_version(self, versioned): _compare_models(working_route, versioned_route, expected_changed_fields=expected_difference) +@pytest.mark.django_db() +class TestExperimentRoute: + def test_eligible_children(self): + parent = ExperimentFactory() + experiment_version = parent.create_new_version() + experiment1 = ExperimentFactory(team=parent.team) + experiment2 = ExperimentFactory(team=parent.team) + + queryset = ExperimentRoute.eligible_children(team=parent.team, parent=parent) + assert parent not in queryset + assert experiment_version not in queryset + assert experiment1 in queryset + assert experiment2 in queryset + assert len(queryset) == 2 + + queryset = ExperimentRoute.eligible_children(team=parent.team) + assert len(queryset) == 4 + + @pytest.mark.django_db() class TestExperimentVersioning: def test_working_experiment_cannot_be_the_default_version(self): From 8c70e6a2230c77ee61ed5339f846d78db932634f Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 12:56:08 +0200 Subject: [PATCH 05/46] Tag the bot message with the version of the experiment that generated it --- .../migrations/0005_alter_tag_category.py | 18 ++++++++++++++++++ apps/annotations/models.py | 1 + apps/chat/models.py | 11 ++++++++++- apps/service_providers/llm_service/state.py | 11 +++++++++++ apps/service_providers/tests/test_runnables.py | 14 ++++++++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 apps/annotations/migrations/0005_alter_tag_category.py diff --git a/apps/annotations/migrations/0005_alter_tag_category.py b/apps/annotations/migrations/0005_alter_tag_category.py new file mode 100644 index 000000000..120ecab6e --- /dev/null +++ b/apps/annotations/migrations/0005_alter_tag_category.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1 on 2024-09-06 10:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('annotations', '0004_default_tag_category'), + ] + + operations = [ + migrations.AlterField( + model_name='tag', + name='category', + field=models.CharField(blank=True, choices=[('bot_response', 'Bot Response'), ('experiment_version', 'Experiment Version')], default=''), + ), + ] diff --git a/apps/annotations/models.py b/apps/annotations/models.py index 9c99e985e..4d3378243 100644 --- a/apps/annotations/models.py +++ b/apps/annotations/models.py @@ -16,6 +16,7 @@ class TagCategories(models.TextChoices): BOT_RESPONSE = "bot_response", _("Bot Response") + EXPERIMENT_VERSION = "experiment_version", _("Experiment Version") @audit_fields( diff --git a/apps/chat/models.py b/apps/chat/models.py index 1095c470b..97d52f618 100644 --- a/apps/chat/models.py +++ b/apps/chat/models.py @@ -7,7 +7,7 @@ from django.utils.functional import classproperty from langchain.schema import BaseMessage, messages_from_dict -from apps.annotations.models import TaggedModelMixin, UserCommentsMixin +from apps.annotations.models import Tag, TagCategories, TaggedModelMixin, UserCommentsMixin from apps.files.models import File from apps.teams.models import BaseTeamModel from apps.utils.models import BaseModel @@ -188,6 +188,15 @@ def get_attached_files(self): def get_metadata(self, key: str): return self.metadata.get(key, None) + def add_system_tag(self, tag: str, tag_category: TagCategories): + tag, _ = Tag.objects.get_or_create( + name=tag, + team=self.chat.team, + is_system_tag=True, + category=tag_category, + ) + self.add_tag(tag, team=self.chat.team, added_by=None) + class ChatAttachment(BaseModel): chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name="attachments") diff --git a/apps/service_providers/llm_service/state.py b/apps/service_providers/llm_service/state.py index 6ab97c16c..38a4c5ed4 100644 --- a/apps/service_providers/llm_service/state.py +++ b/apps/service_providers/llm_service/state.py @@ -147,6 +147,11 @@ def save_message_to_history(self, message: str, type_: ChatMessageType, experime ) chat_message.add_tag(tag, team=self.session.team, added_by=None) + if type_ == ChatMessageType.AI and not self.experiment.is_working_version: + chat_message.add_system_tag( + tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION + ) + def check_cancellation(self): self.session.chat.refresh_from_db(fields=["metadata"]) # temporary mechanism to cancel the chat @@ -262,6 +267,12 @@ def save_message_to_history( category=TagCategories.BOT_RESPONSE, ) chat_message.add_tag(tag, team=self.session.team, added_by=None) + + if type_ == ChatMessageType.AI and not self.experiment.is_working_version: + chat_message.add_system_tag( + tag=self.experiment.version_display, tag_category=TagCategories.EXPERIMENT_VERSION + ) + return chat_message def get_tools(self): diff --git a/apps/service_providers/tests/test_runnables.py b/apps/service_providers/tests/test_runnables.py index 51d3cf27e..0692cee86 100644 --- a/apps/service_providers/tests/test_runnables.py +++ b/apps/service_providers/tests/test_runnables.py @@ -6,6 +6,7 @@ import pytest from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from apps.annotations.models import TagCategories from apps.chat.models import Chat, ChatMessage, ChatMessageType from apps.experiments.models import AgentTools, SourceMaterial from apps.service_providers.llm_service.runnables import ( @@ -78,6 +79,19 @@ def test_runnable(runnable, session, fake_llm_service): assert "tools" not in fake_llm_service.llm.get_calls()[0].kwargs +@pytest.mark.django_db() +@freezegun.freeze_time("2024-02-08 13:00:08.877096+00:00") +def test_bot_message_is_tagged_with_experiment_version(runnable, session, fake_llm_service): + experiment_version = session.experiment.create_new_version() + experiment_version.get_llm_service = lambda: fake_llm_service + chain = runnable.build(state=ChatExperimentState(experiment_version, session)) + chain.invoke("hi") + ai_message = session.chat.messages.get(message_type=ChatMessageType.AI) + tag = ai_message.tags.first() + assert tag.name == "v1" + assert tag.category == TagCategories.EXPERIMENT_VERSION + + @pytest.mark.django_db() @freezegun.freeze_time("2024-02-08 13:00:08.877096+00:00") def test_runnable_with_source_material(runnable, session, fake_llm_service): From ec0351f96608cbca31cbfab44f0fb23133b33e18 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 13:06:33 +0200 Subject: [PATCH 06/46] Hide versions from the safety layer and source material list views --- apps/experiments/views/safety.py | 2 +- apps/experiments/views/source_material.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/experiments/views/safety.py b/apps/experiments/views/safety.py index 9abcca6da..c77146730 100644 --- a/apps/experiments/views/safety.py +++ b/apps/experiments/views/safety.py @@ -30,7 +30,7 @@ class SafetyLayerTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return SafetyLayer.objects.filter(team=self.request.team) + return SafetyLayer.objects.filter(team=self.request.team).exclude(working_version__isnull=False) class CreateSafetyLayer(CreateView): diff --git a/apps/experiments/views/source_material.py b/apps/experiments/views/source_material.py index a550730d6..d2d6a30ee 100644 --- a/apps/experiments/views/source_material.py +++ b/apps/experiments/views/source_material.py @@ -33,7 +33,7 @@ class SourceMaterialTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - query_set = SourceMaterial.objects.filter(team=self.request.team) + query_set = SourceMaterial.objects.filter(team=self.request.team).exclude(working_version__isnull=False) search = self.request.GET.get("search") if search: search_vector = SearchVector("topic", weight="A") + SearchVector("description", weight="B") From cb638ab6d49c44d1ba02ebdd07ef49d79eeeeaac Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 14:50:33 +0200 Subject: [PATCH 07/46] Update web interface to use the working experiment version where needed --- apps/chat/channels.py | 8 ++--- apps/experiments/decorators.py | 24 +++++++------ apps/experiments/models.py | 8 +++++ apps/experiments/views/experiment.py | 40 ++++++++++++++------- templates/experiments/email/invitation.html | 2 +- 5 files changed, 54 insertions(+), 28 deletions(-) diff --git a/apps/chat/channels.py b/apps/chat/channels.py index 992aacac8..b5ab2d9e4 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -544,14 +544,14 @@ def start_new_session( session = super().start_new_session( experiment, experiment_channel, participant_identifier, participant_user, session_status, timezone ) - WebChannel.check_and_process_seed_message(session) + WebChannel.check_and_process_seed_message(session, experiment) return session @classmethod - def check_and_process_seed_message(cls, session: ExperimentSession): + def check_and_process_seed_message(cls, session: ExperimentSession, experiment: Experiment): from apps.experiments.tasks import get_response_for_webchat_task - if seed_message := session.experiment_version.seed_message: + if seed_message := experiment.seed_message: session.seed_task_id = get_response_for_webchat_task.delay( session.id, message_text=seed_message, attachments=[] ).task_id @@ -786,7 +786,7 @@ def _start_experiment_session( session = ExperimentSession.objects.create( team=experiment.team, - experiment=experiment, + experiment=experiment.get_working_version(), experiment_channel=experiment_channel, status=session_status, participant=participant, diff --git a/apps/experiments/decorators.py b/apps/experiments/decorators.py index e2d387a55..fe4f424cc 100644 --- a/apps/experiments/decorators.py +++ b/apps/experiments/decorators.py @@ -25,11 +25,14 @@ def decorator(view_func): def decorated_view(request, team_slug: str, experiment_id: str, session_id: str): request.experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) request.experiment_session = get_object_or_404( - ExperimentSession, experiment=request.experiment, external_id=session_id, team=request.team + ExperimentSession, + experiment_id=request.experiment.get_working_version_id(), + external_id=session_id, + team=request.team, ) if allowed_states and request.experiment_session.status not in allowed_states: - return _redirect_for_state(request, request.experiment_session, team_slug) + return _redirect_for_state(request, team_slug) return view_func(request, team_slug, experiment_id, session_id) return decorated_view @@ -87,7 +90,7 @@ def _inner(request, *args, **kwargs): def _get_access_cookie_data(experiment_session): return { - "experiment_id": str(experiment_session.experiment.public_id), + "experiment_id": str(experiment_session.working_experiment.public_id), "session_id": str(experiment_session.external_id), "participant_id": experiment_session.participant_id, "user_id": experiment_session.participant.user_id, @@ -98,17 +101,18 @@ def _validate_access_cookie_data(experiment_session, access_data): return _get_access_cookie_data(experiment_session) == access_data -def _redirect_for_state(request, experiment_session, team_slug): - view_args = [team_slug, experiment_session.experiment.public_id, experiment_session.external_id] - if experiment_session.status in [SessionStatus.SETUP, SessionStatus.PENDING]: +def _redirect_for_state(request, team_slug): + view_args = [team_slug, request.experiment.public_id, request.experiment_session.external_id] + # TODO: Refactor using case match + if request.experiment_session.status in [SessionStatus.SETUP, SessionStatus.PENDING]: return HttpResponseRedirect(reverse("experiments:start_session_from_invite", args=view_args)) - elif experiment_session.status == SessionStatus.PENDING_PRE_SURVEY: + elif request.experiment_session.status == SessionStatus.PENDING_PRE_SURVEY: return HttpResponseRedirect(reverse("experiments:experiment_pre_survey", args=view_args)) - elif experiment_session.status == SessionStatus.ACTIVE: + elif request.experiment_session.status == SessionStatus.ACTIVE: return HttpResponseRedirect(reverse("experiments:experiment_chat", args=view_args)) - elif experiment_session.status == SessionStatus.PENDING_REVIEW: + elif request.experiment_session.status == SessionStatus.PENDING_REVIEW: return HttpResponseRedirect(reverse("experiments:experiment_review", args=view_args)) - elif experiment_session.status == SessionStatus.COMPLETE: + elif request.experiment_session.status == SessionStatus.COMPLETE: return HttpResponseRedirect(reverse("experiments:experiment_complete", args=view_args)) else: messages.info( diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 4aa5e2dc8..84308a21d 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -115,6 +115,14 @@ def get_working_version(self) -> "Experiment": return self return self.working_version + def get_working_version_id(self) -> int: + return self.working_version_id if self.working_version_id else self.id + + @cached_property + def default_version(self) -> "Experiment": + """Returns the default experiment, or if there is none, the working experiment""" + return Experiment.objects.get_default_or_working(self) + @audit_fields(*model_audit_fields.SOURCE_MATERIAL_FIELDS, audit_special_queryset_writes=True) class SourceMaterial(BaseTeamModel, VersionsMixin): diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 3f6c71a17..5412479c5 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -691,6 +691,7 @@ def update_delete_channel(request, team_slug: str, experiment_id: int, channel_i @require_POST @login_and_team_required def start_authed_web_session(request, team_slug: str, experiment_id: int): + """Start an authed web session with the chosen experiment, be it a specific version or not""" experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) session = WebChannel.start_new_session( @@ -707,8 +708,9 @@ def start_authed_web_session(request, team_slug: str, experiment_id: int): @login_and_team_required def experiment_chat_session(request, team_slug: str, experiment_id: int, session_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) + working_version = experiment.get_working_version() session = get_object_or_404( - ExperimentSession, participant__user=request.user, experiment_id=experiment_id, id=session_id + ExperimentSession, participant__user=request.user, experiment_id=working_version.id, id=session_id ) return TemplateResponse( request, @@ -725,8 +727,11 @@ def experiment_chat_session(request, team_slug: str, experiment_id: int, session def experiment_session_message(request, team_slug: str, experiment_id: int, session_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) # hack for anonymous user/teams + working_version_id = experiment.get_working_version_id() user = get_real_user_or_none(request.user) - session = get_object_or_404(ExperimentSession, participant__user=user, experiment_id=experiment_id, id=session_id) + session = get_object_or_404( + ExperimentSession, participant__user=user, experiment_id=working_version_id, id=session_id + ) message_text = request.POST["message"] uploaded_files = request.FILES @@ -764,9 +769,12 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess # @login_and_team_required def get_message_response(request, team_slug: str, experiment_id: int, session_id: int, task_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) + working_version_id = experiment.get_working_version_id() # hack for anonymous user/teams user = get_real_user_or_none(request.user) - session = get_object_or_404(ExperimentSession, participant__user=user, experiment_id=experiment_id, id=session_id) + session = get_object_or_404( + ExperimentSession, participant__user=user, experiment_id=working_version_id, id=session_id + ) last_message = ChatMessage.objects.filter(chat=session.chat).order_by("-created_at").first() progress = Progress(AsyncResult(task_id)).get_info() # don't render empty messages @@ -787,11 +795,14 @@ def get_message_response(request, team_slug: str, experiment_id: int, session_id def poll_messages(request, team_slug: str, experiment_id: int, session_id: int): + # experiment_id can be a version's ID user = get_real_user_or_none(request.user) params = request.GET.dict() since_param = params.get("since") + experiment = get_object_or_404(Experiment, id=experiment_id) + working_version_id = experiment.get_working_version_id() experiment_session = get_object_or_404( - ExperimentSession, participant__user=user, experiment_id=experiment_id, id=session_id, team=request.team + ExperimentSession, participant__user=user, experiment_id=working_version_id, id=session_id, team=request.team ) since = timezone.now() @@ -873,17 +884,18 @@ def start_session_public(request, team_slug: str, experiment_id: str): @permission_required("experiments.invite_participants", raise_exception=True) def experiment_invitations(request, team_slug: str, experiment_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) + working_version = experiment.get_working_version() sessions = experiment.sessions.order_by("-created_at").filter( status__in=["setup", "pending"], participant__isnull=False, ) - form = ExperimentInvitationForm(initial={"experiment_id": experiment.id}) + form = ExperimentInvitationForm(initial={"experiment_id": working_version.id}) if request.method == "POST": post_form = ExperimentInvitationForm(request.POST) if post_form.is_valid(): if ExperimentSession.objects.filter( team=request.team, - experiment=experiment, + experiment=working_version, status__in=["setup", "pending"], participant__identifier=post_form.cleaned_data["email"], ).exists(): @@ -892,7 +904,7 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: with transaction.atomic(): session = WebChannel.start_new_session( - experiment=experiment, + experiment=experiment.default_version, participant_identifier=post_form.cleaned_data["email"], session_status=SessionStatus.SETUP, timezone=request.session.get("detected_tz", None), @@ -953,7 +965,7 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex response = HttpResponseRedirect( reverse( redirect_url_name, - args=[team_slug, experiment_session.experiment.public_id, experiment_session.external_id], + args=[team_slug, experiment_session.experiment_version.public_id, experiment_session.external_id], ) ) return set_session_access_cookie(response, experiment_session) @@ -961,13 +973,15 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex @experiment_session_view(allowed_states=[SessionStatus.SETUP, SessionStatus.PENDING]) def start_session_from_invite(request, team_slug: str, experiment_id: str, session_id: str): + # A session from invite will (for now?) always use the default experiment version experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) - experiment_session = get_object_or_404(ExperimentSession, experiment=experiment, external_id=session_id) - experiment_version = experiment_session.experiment_version - consent = experiment_version.consent_form + working_version = experiment.get_working_version() + default_version = experiment.default_version + experiment_session = get_object_or_404(ExperimentSession, experiment=working_version, external_id=session_id) + consent = experiment.consent_form initial = { - "experiment_id": experiment_version.id, + "experiment_id": default_version.id, } if not experiment_session.participant: raise Http404() @@ -989,7 +1003,7 @@ def start_session_from_invite(request, team_slug: str, experiment_id: str, sessi "experiments/start_experiment_session.html", { "active_tab": "experiments", - "experiment": experiment, + "experiment": experiment.default_version, "consent_notice": mark_safe(consent_notice), "form": form, }, diff --git a/templates/experiments/email/invitation.html b/templates/experiments/email/invitation.html index a0e583799..37861b91f 100644 --- a/templates/experiments/email/invitation.html +++ b/templates/experiments/email/invitation.html @@ -1,7 +1,7 @@ {% extends 'account/email/email_template_base.html' %} {% load i18n %} {% block message_body %} - You've been invited to participate in a Dimagi ChatBot experiment called {{ session.experiment }}. + You've been invited to participate in a Dimagi ChatBot experiment called {{ session.experiment_version }}. {% endblock %} {% block cta_link %}{{ session.get_invite_url }}{% endblock %} {% block cta_text %}{% translate "Start Experiment" %}{% endblock %} From 99ab0a3062d6c17a47316b8dce5b896babe0f257 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 14:56:25 +0200 Subject: [PATCH 08/46] Small refactor: Use match case instead of if-elsif-elsif-elsif-elif...you get the idea --- apps/experiments/decorators.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/apps/experiments/decorators.py b/apps/experiments/decorators.py index fe4f424cc..5f7d21da5 100644 --- a/apps/experiments/decorators.py +++ b/apps/experiments/decorators.py @@ -103,19 +103,20 @@ def _validate_access_cookie_data(experiment_session, access_data): def _redirect_for_state(request, team_slug): view_args = [team_slug, request.experiment.public_id, request.experiment_session.external_id] - # TODO: Refactor using case match - if request.experiment_session.status in [SessionStatus.SETUP, SessionStatus.PENDING]: - return HttpResponseRedirect(reverse("experiments:start_session_from_invite", args=view_args)) - elif request.experiment_session.status == SessionStatus.PENDING_PRE_SURVEY: - return HttpResponseRedirect(reverse("experiments:experiment_pre_survey", args=view_args)) - elif request.experiment_session.status == SessionStatus.ACTIVE: - return HttpResponseRedirect(reverse("experiments:experiment_chat", args=view_args)) - elif request.experiment_session.status == SessionStatus.PENDING_REVIEW: - return HttpResponseRedirect(reverse("experiments:experiment_review", args=view_args)) - elif request.experiment_session.status == SessionStatus.COMPLETE: - return HttpResponseRedirect(reverse("experiments:experiment_complete", args=view_args)) - else: - messages.info( - request, "Session was in an unknown/unexpected state." " It may be old, or something may have gone wrong." - ) - return HttpResponseRedirect(reverse("experiments:experiment_session_view", args=view_args)) + match request.experiment_session.status: + case SessionStatus.SETUP | SessionStatus.PENDING: + return HttpResponseRedirect(reverse("experiments:start_session_from_invite", args=view_args)) + case SessionStatus.PENDING_PRE_SURVEY: + return HttpResponseRedirect(reverse("experiments:experiment_pre_survey", args=view_args)) + case SessionStatus.ACTIVE: + return HttpResponseRedirect(reverse("experiments:experiment_chat", args=view_args)) + case SessionStatus.PENDING_REVIEW: + return HttpResponseRedirect(reverse("experiments:experiment_review", args=view_args)) + case SessionStatus.COMPLETE: + return HttpResponseRedirect(reverse("experiments:experiment_complete", args=view_args)) + case _: + messages.info( + request, + "Session was in an unknown/unexpected state." " It may be old, or something may have gone wrong.", + ) + return HttpResponseRedirect(reverse("experiments:experiment_session_view", args=view_args)) From 82ad5ca5bab68c48a6a96e588070af8a0828b070 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:04:28 +0200 Subject: [PATCH 09/46] Rename experiment param to experiment_version to make it clearer that it can be the version --- apps/chat/channels.py | 26 ++++++++++++++------------ apps/experiments/views/experiment.py | 6 +++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/apps/chat/channels.py b/apps/chat/channels.py index b5ab2d9e4..f23664c2f 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -113,7 +113,7 @@ def __init__( @classmethod def start_new_session( cls, - experiment: Experiment, + experiment_version: Experiment, experiment_channel: ExperimentChannel, participant_identifier: str, participant_user: CustomUser | None = None, @@ -122,7 +122,7 @@ def start_new_session( session_external_id: str | None = None, ): return _start_experiment_session( - experiment, + experiment_version, experiment_channel, participant_identifier, participant_user, @@ -477,7 +477,7 @@ def _create_new_experiment_session(self): session """ self.experiment_session = self.start_new_session( - experiment=self.experiment, + experiment_version=self.experiment, experiment_channel=self.experiment_channel, participant_identifier=self.participant_identifier, participant_user=self.participant_user, @@ -534,17 +534,17 @@ def _ensure_sessions_exists(self): @classmethod def start_new_session( cls, - experiment: Experiment, + experiment_version: Experiment, participant_identifier: str, participant_user: CustomUser | None = None, session_status: SessionStatus = SessionStatus.ACTIVE, timezone: str | None = None, ): - experiment_channel = ExperimentChannel.objects.get_team_web_channel(experiment.team) + experiment_channel = ExperimentChannel.objects.get_team_web_channel(experiment_version.team) session = super().start_new_session( - experiment, experiment_channel, participant_identifier, participant_user, session_status, timezone + experiment_version, experiment_channel, participant_identifier, participant_user, session_status, timezone ) - WebChannel.check_and_process_seed_message(session, experiment) + WebChannel.check_and_process_seed_message(session, experiment_version) return session @classmethod @@ -758,7 +758,7 @@ def _ensure_sessions_exists(self): def _start_experiment_session( - experiment: Experiment, + experiment_version: Experiment, experiment_channel: ExperimentChannel, participant_identifier: str, participant_user: CustomUser | None = None, @@ -766,6 +766,8 @@ def _start_experiment_session( timezone: str | None = None, session_external_id: str | None = None, ) -> ExperimentSession: + team = experiment_version.team + working_version = experiment_version.get_working_version() if not participant_identifier and not participant_user: raise ValueError("Either participant_identifier or participant_user must be specified!") @@ -775,7 +777,7 @@ def _start_experiment_session( with transaction.atomic(): participant, created = Participant.objects.get_or_create( - team=experiment.team, + team=team, identifier=participant_identifier, platform=experiment_channel.platform, defaults={"user": participant_user}, @@ -785,8 +787,8 @@ def _start_experiment_session( participant.save() session = ExperimentSession.objects.create( - team=experiment.team, - experiment=experiment.get_working_version(), + team=team, + experiment=working_version, experiment_channel=experiment_channel, status=session_status, participant=participant, @@ -795,7 +797,7 @@ def _start_experiment_session( # Record the participant's timezone if timezone: - participant.update_memory(data={"timezone": timezone}, experiment=experiment) + participant.update_memory(data={"timezone": timezone}, experiment=working_version) if participant.experimentsession_set.count() == 1: enqueue_static_triggers.delay(session.id, StaticTriggerType.PARTICIPANT_JOINED_EXPERIMENT) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 5412479c5..539ec0d4a 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -695,7 +695,7 @@ def start_authed_web_session(request, team_slug: str, experiment_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) session = WebChannel.start_new_session( - experiment, + experiment_version=experiment, participant_user=request.user, participant_identifier=request.user.email, timezone=request.session.get("detected_tz", None), @@ -851,7 +851,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): identifier = user.email if user else str(uuid.uuid4()) session = WebChannel.start_new_session( - experiment, + experiment_version=experiment, participant_user=user, participant_identifier=identifier, timezone=request.session.get("detected_tz", None), @@ -904,7 +904,7 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: with transaction.atomic(): session = WebChannel.start_new_session( - experiment=experiment.default_version, + experiment_version=experiment.default_version, participant_identifier=post_form.cleaned_data["email"], session_status=SessionStatus.SETUP, timezone=request.session.get("detected_tz", None), From 33ebe7729a2b9022ece6621105efcac456567d1e Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:04:43 +0200 Subject: [PATCH 10/46] Update slack listener to use the default experiment version --- apps/slack/slack_listeners.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/apps/slack/slack_listeners.py b/apps/slack/slack_listeners.py index 649e04350..732a93ed2 100644 --- a/apps/slack/slack_listeners.py +++ b/apps/slack/slack_listeners.py @@ -61,10 +61,11 @@ def respond_to_message(event, context: BoltContext, session=None): slack_user = event.get("user") + experiment_version = experiment_channel.experiment.default_version if not session: external_id = make_session_external_id(channel_id, thread_ts) session = SlackChannel.start_new_session( - experiment_channel.experiment, experiment_channel, slack_user, session_external_id=external_id + experiment_version, experiment_channel, slack_user, session_external_id=external_id ) # strip out the mention @@ -75,7 +76,7 @@ def respond_to_message(event, context: BoltContext, session=None): # Set `send_response_to_user` to `False` to prevent it sending the message since we're going to send # it here using the already authenticated client. - ocs_channel = SlackChannel(experiment_channel.experiment, experiment_channel, session, send_response_to_user=False) + ocs_channel = SlackChannel(experiment_version, experiment_channel, session, send_response_to_user=False) response = ocs_channel.new_user_message(message) context.say(response, thread_ts=thread_ts) From 04194f94fde10a2f6292db1f8067feee5ccefb5b Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:15:31 +0200 Subject: [PATCH 11/46] Update get survey link methods to get the default version's links --- apps/experiments/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 84308a21d..61fea85b7 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -937,10 +937,10 @@ def get_platform_name(self) -> str: return self.experiment_channel.get_platform_display() def get_pre_survey_link(self): - return self.experiment.pre_survey.get_link(self.participant, self) + return self.experiment_version.pre_survey.get_link(self.participant, self) def get_post_survey_link(self): - return self.experiment.post_survey.get_link(self.participant, self) + return self.experiment_version.post_survey.get_link(self.participant, self) def is_stale(self) -> bool: """A Channel Session is considered stale if the experiment that the channel points to differs from the From e7275e2cb8a384884df0398092c203b583ff11ad Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:16:49 +0200 Subject: [PATCH 12/46] Update experiment admin dashboard to show experiment versions --- apps/experiments/admin.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/apps/experiments/admin.py b/apps/experiments/admin.py index c5ec06501..0487f1c6e 100644 --- a/apps/experiments/admin.py +++ b/apps/experiments/admin.py @@ -67,12 +67,27 @@ class SurveyAdmin(admin.ModelAdmin): @admin.register(models.Experiment) class ExperimentAdmin(admin.ModelAdmin): - list_display = ("name", "team", "owner", "source_material", "llm", "llm_provider") + list_display = ( + "name", + "team", + "owner", + "source_material", + "llm", + "llm_provider", + "version_family", + "version_number", + ) list_filter = ("team", "owner", "source_material") inlines = [SafetyLayerInline] exclude = ["safety_layers"] readonly_fields = ("public_id",) + @admin.display(description="Version Family") + def version_family(self, obj): + if obj.working_version: + return obj.working_version.name + return obj.name + @admin.register(models.ExperimentRoute) class ExperimentRouteAdmin(admin.ModelAdmin): From 7c9800a40b49451c6cb87d295ba44028cabc17c9 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:34:00 +0200 Subject: [PATCH 13/46] Update channel views and API to use the experiment version. API to be changed to use a specified version in the future --- apps/api/openai.py | 2 +- apps/channels/tasks.py | 14 ++++++++------ apps/channels/views.py | 7 +++++-- apps/chat/bots.py | 4 ++-- apps/chat/channels.py | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/apps/api/openai.py b/apps/api/openai.py index 6ba053e5e..4f717cfb2 100644 --- a/apps/api/openai.py +++ b/apps/api/openai.py @@ -109,7 +109,7 @@ def chat_completions(request, experiment_id: str): session = serializer.save() response_message = handle_api_message( request.user, - session.experiment, + session.experiment_version, session.experiment_channel, last_message.get("content"), session.participant.identifier, diff --git a/apps/channels/tasks.py b/apps/channels/tasks.py index 5374f6524..b792e22fa 100644 --- a/apps/channels/tasks.py +++ b/apps/channels/tasks.py @@ -33,7 +33,7 @@ def handle_telegram_message(self, message_data: str, channel_external_id: uuid): return message = TelegramMessage.parse(update) - message_handler = TelegramChannel(experiment_channel.experiment, experiment_channel) + message_handler = TelegramChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, message_handler, message) with current_team(experiment_channel.team): @@ -68,7 +68,7 @@ def handle_twilio_message(self, message_data: str, request_uri: str, signature: validate_twillio_request(experiment_channel, raw_data, request_uri, signature) - message_handler = ChannelClass(experiment_channel.experiment, experiment_channel=experiment_channel) + message_handler = ChannelClass(experiment_channel.experiment.default_version, experiment_channel=experiment_channel) update_taskbadger_data(self, message_handler, message) with current_team(experiment_channel.team): @@ -106,7 +106,7 @@ def handle_sureadhere_message(self, sureadhere_tenant_id: str, message_data: dic if not experiment_channel: log.info(f"No experiment channel found for SureAdhere tenant ID: {sureadhere_tenant_id}") return - channel = SureAdhereChannel(experiment_channel.experiment, experiment_channel) + channel = SureAdhereChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, channel, message) with current_team(experiment_channel.team): channel.new_user_message(message) @@ -127,18 +127,20 @@ def handle_turn_message(self, experiment_id: uuid, message_data: dict): if not experiment_channel: log.info(f"No experiment channel found for experiment_id: {experiment_id}") return - channel = WhatsappChannel(experiment_channel.experiment, experiment_channel) + channel = WhatsappChannel(experiment_channel.experiment.default_version, experiment_channel) update_taskbadger_data(self, channel, message) with current_team(experiment_channel.team): channel.new_user_message(message) -def handle_api_message(user, experiment, experiment_channel, message_text: str, participant_id: str, session=None): +def handle_api_message( + user, experiment_version, experiment_channel, message_text: str, participant_id: str, session=None +): """Synchronously handles the message coming from the API""" message = BaseMessage(participant_id=participant_id, message_text=message_text) channel = ApiChannel( - experiment, + experiment_version, experiment_channel, experiment_session=session, user=user, diff --git a/apps/channels/views.py b/apps/channels/views.py index 05b0bacab..4226f820a 100644 --- a/apps/channels/views.py +++ b/apps/channels/views.py @@ -96,9 +96,12 @@ def new_api_message(request, experiment_id: uuid): session = None if session_id := message_data.get("session"): try: + # TODO: Support ability to select a specific version + experiment = Experiment.objects.get(public_id=experiment_id) + working_version_id = experiment.get_working_version_id() session = ExperimentSession.objects.select_related("experiment", "experiment_channel").get( external_id=session_id, - experiment__public_id=experiment_id, + experiment__id=working_version_id, team=request.team, participant__user=request.user, experiment_channel__platform=ChannelPlatform.API, @@ -115,7 +118,7 @@ def new_api_message(request, experiment_id: uuid): response = tasks.handle_api_message( request.user, - experiment, + experiment.default_version, experiment_channel, message_data["message"], participant_id, diff --git a/apps/chat/bots.py b/apps/chat/bots.py index 4f1e4847b..a17242769 100644 --- a/apps/chat/bots.py +++ b/apps/chat/bots.py @@ -60,7 +60,7 @@ class TopicBot: """ def __init__(self, session: ExperimentSession, experiment: Experiment | None = None, disable_tools: bool = False): - self.experiment = experiment or session.experiment + self.experiment = experiment or session.experiment_version self.disable_tools = disable_tools self.prompt = self.experiment.prompt_text self.input_formatter = self.experiment.input_formatter @@ -255,7 +255,7 @@ def filter_ai_messages(self) -> bool: class PipelineBot: def __init__(self, session: ExperimentSession): - self.experiment = session.experiment + self.experiment = session.experiment_version self.session = session def process_input(self, user_input: str, save_input_to_history=True, attachments: list["Attachment"] | None = None): diff --git a/apps/chat/channels.py b/apps/chat/channels.py index f23664c2f..7a1e481c4 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -460,7 +460,7 @@ def _ensure_sessions_exists(self): def _get_latest_session(self): return ( ExperimentSession.objects.filter( - experiment=self.experiment, + experiment=self.experiment.get_working_version(), participant__identifier=str(self.participant_identifier), ) .order_by("-created_at") From 47221ae7df4cb2acc4b054a1bf153596c0d3f0c9 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Fri, 6 Sep 2024 15:47:55 +0200 Subject: [PATCH 14/46] Base channel test updates --- .../tests/test_base_channel_behavior.py | 44 ++++++++++++++----- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/apps/channels/tests/test_base_channel_behavior.py b/apps/channels/tests/test_base_channel_behavior.py index b791e0739..163467e29 100644 --- a/apps/channels/tests/test_base_channel_behavior.py +++ b/apps/channels/tests/test_base_channel_behavior.py @@ -45,7 +45,7 @@ def test_incoming_message_adds_channel_info(telegram_channel): chat_id = 123123 message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) experiment_session = ExperimentSession.objects.filter( experiment=telegram_channel.experiment, participant__identifier=chat_id @@ -60,7 +60,7 @@ def test_incoming_message_adds_channel_info(telegram_channel): def test_channel_added_for_experiment_session(telegram_channel): chat_id = 123123 message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) participant = Participant.objects.get(identifier=chat_id) experiment_session = participant.experimentsession_set.first() assert experiment_session.experiment_channel is not None @@ -76,7 +76,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel): # First message message = telegram_messages.text_message(chat_id=chat_id) - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) # Let's find the session it created experiment_sessions_count = ExperimentSession.objects.filter( @@ -88,7 +88,7 @@ def test_incoming_message_uses_existing_experiment_session(telegram_channel): telegram_channel._create_new_experiment_session = Mock() # Second message - _simulate_user_message(telegram_channel, message) + _send_user_message_on_channel(telegram_channel, message) # Assertions experiment_sessions_count = ExperimentSession.objects.filter( @@ -107,14 +107,14 @@ def test_different_sessions_created_for_different_users(telegram_channel): # First user's message user_1_message = telegram_messages.text_message(chat_id=user_1_chat_id) - _simulate_user_message(telegram_channel, user_1_message) + _send_user_message_on_channel(telegram_channel, user_1_message) # Calling new_user_message added an experiment_session, so we should remove it before reusing the instance telegram_channel.experiment_session = None # Second user's message user_2_message = telegram_messages.text_message(chat_id=user_2_chat_id) - _simulate_user_message(telegram_channel, user_2_message) + _send_user_message_on_channel(telegram_channel, user_2_message) # Assertions experiment_sessions_count = ExperimentSession.objects.count() @@ -141,8 +141,8 @@ def test_different_participants_created_for_same_user_in_different_teams(): assert experiment1.team != experiment2.team - _simulate_user_message(channel1, user_message) - _simulate_user_message(channel2, user_message) + _send_user_message_on_channel(channel1, user_message) + _send_user_message_on_channel(channel2, user_message) experiment_sessions_count = ExperimentSession.objects.count() assert experiment_sessions_count == 2 @@ -159,7 +159,7 @@ def test_reset_command_creates_new_experiment_session(_send_text_to_user_mock, t telegram_chat_id = 00000 normal_message = telegram_messages.text_message(chat_id=telegram_chat_id) - _simulate_user_message(telegram_channel, normal_message) + _send_user_message_on_channel(telegram_channel, normal_message) reset_message = telegram_messages.text_message( chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND @@ -184,10 +184,10 @@ def test_reset_conversation_does_not_create_new_session( telegram_chat_id = 00000 message1 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND) - _simulate_user_message(telegram_channel, message1) + _send_user_message_on_channel(telegram_channel, message1) message2 = telegram_messages.text_message(chat_id=telegram_chat_id, message_text=ExperimentChannel.RESET_COMMAND) - _simulate_user_message(telegram_channel, message2) + _send_user_message_on_channel(telegram_channel, message2) sessions = ExperimentSession.objects.for_chat_id(telegram_chat_id).all() assert len(sessions) == 1 @@ -195,7 +195,7 @@ def test_reset_conversation_does_not_create_new_session( assert sessions[0].chat.get_langchain_messages() == [] -def _simulate_user_message(channel_instance, user_message: str): +def _send_user_message_on_channel(channel_instance, user_message: str): with mock_experiment_llm(channel_instance.experiment, responses=["OK"]): channel_instance.new_user_message(user_message) @@ -775,3 +775,23 @@ def test_participant_authorization( telegram_channel.new_user_message(message) send_text_to_user.assert_called() assert send_text_to_user.call_args[0][0] == "Sorry, you are not allowed to chat to this bot" + + +@pytest.mark.django_db() +class TestVersioning: + """Tests relating to versioning behaviour within the ChannelBase class""" + + @patch("apps.chat.channels.TelegramChannel.send_text_to_user", Mock()) + @patch("apps.chat.channels.TelegramChannel._get_bot_response", Mock()) + def test_new_sessions_are_linked_to_the_working_experiment(self, experiment): + working_version = experiment + channel = ExperimentChannelFactory(experiment=working_version) + new_version = working_version.create_new_version() + + telegram = TelegramChannel(experiment=new_version, experiment_channel=channel) + telegram.telegram_bot = Mock() + telegram.new_user_message(telegram_messages.text_message()) + + # Check that the working experiment is linked to the session + assert ExperimentSession.objects.filter(experiment=working_version).exists() + assert not ExperimentSession.objects.filter(experiment=new_version).exists() From 0c53e44bc2a40625585d041f45b86c8e34725c7d Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Mon, 9 Sep 2024 08:15:47 +0200 Subject: [PATCH 15/46] Annotate safety layer and source material queries with is_version for easy lookup --- apps/experiments/models.py | 28 ++++++++++++++++++++--- apps/experiments/views/safety.py | 2 +- apps/experiments/views/source_material.py | 2 +- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 61fea85b7..f2185597f 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -12,7 +12,7 @@ from django.contrib.postgres.fields import ArrayField from django.core.validators import MaxValueValidator, MinValueValidator, validate_email from django.db import models, transaction -from django.db.models import Count, OuterRef, Q, Subquery +from django.db.models import BooleanField, Case, Count, OuterRef, Q, Subquery, When from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext @@ -54,11 +54,33 @@ def working_versions_queryset(self): class SourceMaterialObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class SafetyLayerObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class ConsentFormObjectManager(AuditingManager): diff --git a/apps/experiments/views/safety.py b/apps/experiments/views/safety.py index c77146730..ce4492b2d 100644 --- a/apps/experiments/views/safety.py +++ b/apps/experiments/views/safety.py @@ -30,7 +30,7 @@ class SafetyLayerTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return SafetyLayer.objects.filter(team=self.request.team).exclude(working_version__isnull=False) + return SafetyLayer.objects.filter(team=self.request.team, is_version=False) class CreateSafetyLayer(CreateView): diff --git a/apps/experiments/views/source_material.py b/apps/experiments/views/source_material.py index d2d6a30ee..7e0c92825 100644 --- a/apps/experiments/views/source_material.py +++ b/apps/experiments/views/source_material.py @@ -33,7 +33,7 @@ class SourceMaterialTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - query_set = SourceMaterial.objects.filter(team=self.request.team).exclude(working_version__isnull=False) + query_set = SourceMaterial.objects.filter(team=self.request.team, is_version=False) search = self.request.GET.get("search") if search: search_vector = SearchVector("topic", weight="A") + SearchVector("description", weight="B") From 4efb82d32f4c2076dd0d79f3a9a94482efca6b64 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Mon, 9 Sep 2024 14:42:51 +0200 Subject: [PATCH 16/46] Handle experiment deletions Archive experiment and its versions when being deleted --- apps/experiments/models.py | 16 +++++++++++++++- apps/experiments/tests/test_models.py | 27 ++++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index f2185597f..4f220f9de 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -18,7 +18,7 @@ from django.utils.translation import gettext from django_cryptography.fields import encrypt from field_audit import audit_fields -from field_audit.models import AuditingManager +from field_audit.models import AuditAction, AuditingManager from apps.chat.models import Chat, ChatMessage, ChatMessageType from apps.experiments import model_audit_fields @@ -626,6 +626,20 @@ def is_public(self) -> bool: def is_participant_allowed(self, identifier: str): return identifier in self.participant_allowlist or self.team.members.filter(email=identifier).exists() + @transaction.atomic() + def delete(self, *args, **kwargs) -> tuple: + """Deletion strategy: + - If this experiment is a version, archive it + - If this experiment is the working version and has versions, archive all versions and this one + - If this experiment is the working version and does not have versions, delete it + """ + if self.is_working_version and not self.has_versions: + return super().delete(*args, **kwargs) + + self.versions.update(is_archived=True, audit_action=AuditAction.AUDIT) + self.is_archived = True + self.save() + class ExperimentRouteType(models.TextChoices): PROCESSOR = "processor" diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index fdf5d43f9..139382725 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -363,7 +363,7 @@ def test_eligible_children(self): @pytest.mark.django_db() -class TestExperimentVersioning: +class TestExperimentModel: def test_working_experiment_cannot_be_the_default_version(self): with pytest.raises(ValueError, match="A working experiment cannot be a default version"): ExperimentFactory(is_default_version=True, working_version=None) @@ -416,7 +416,6 @@ def _setup_original_experiment(self): TimeoutTriggerFactory(experiment=experiment) return experiment - @pytest.mark.django_db() def test_first_version_is_automatically_the_default(self): experiment = ExperimentFactory() new_version = experiment.create_new_version() @@ -427,7 +426,6 @@ def test_first_version_is_automatically_the_default(self): assert another_version.version_number == 2 assert not another_version.is_default_version - @pytest.mark.django_db() def test_create_experiment_version(self): original_experiment = self._setup_original_experiment() @@ -506,6 +504,29 @@ def _assert_triggers_are_duplicated(self, trigger_type, original_experiment, new expected_changed_fields=["id", "action_id", "working_version_id", "experiment_id"], ) + def test_delete_working_experiment_without_versions(self): + working_version = ExperimentFactory() + working_version.delete() + with pytest.raises(Experiment.DoesNotExist): + working_version.refresh_from_db() + + def test_delete_working_experiment_with_versions(self): + working_version = ExperimentFactory() + working_version.create_new_version() + + working_version.delete() + working_version.refresh_from_db() + assert working_version.is_archived is True + for version in working_version.versions.all(): + assert version.is_archived is True + + def test_delete_versioned_experiment(self): + working_version = ExperimentFactory() + version = working_version.create_new_version() + version.delete() + version.refresh_from_db() + assert version.is_archived is True + @pytest.mark.django_db() class TestExperimentObjectManager: From 69a7aab9856ea428d2334b9d206619339a30fba7 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 08:11:35 +0200 Subject: [PATCH 17/46] Duplicate the consent form and surveys as well. Also do some small refactoring --- ..._working_version_survey_working_version.py | 24 +++++++ apps/experiments/models.py | 62 ++++++++++++------- apps/experiments/tests/test_models.py | 29 ++++++++- 3 files changed, 92 insertions(+), 23 deletions(-) create mode 100644 apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py diff --git a/apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py b/apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py new file mode 100644 index 000000000..6a69e5aee --- /dev/null +++ b/apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py @@ -0,0 +1,24 @@ +# Generated by Django 5.1 on 2024-09-09 14:01 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('experiments', '0092_experiment_is_archived_experiment_is_default_version_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='consentform', + name='working_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='experiments.consentform'), + ), + migrations.AddField( + model_name='survey', + name='working_version', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='experiments.survey'), + ), + ] diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 4f220f9de..a3509d8fe 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -145,6 +145,10 @@ def default_version(self) -> "Experiment": """Returns the default experiment, or if there is none, the working experiment""" return Experiment.objects.get_default_or_working(self) + @property + def has_versions(self): + return self.versions.count() > 0 + @audit_fields(*model_audit_fields.SOURCE_MATERIAL_FIELDS, audit_special_queryset_writes=True) class SourceMaterial(BaseTeamModel, VersionsMixin): @@ -211,7 +215,7 @@ def get_absolute_url(self): return reverse("experiments:safety_edit", args=[self.team.slug, self.id]) -class Survey(BaseTeamModel): +class Survey(BaseTeamModel, VersionsMixin): """ A survey. """ @@ -226,6 +230,13 @@ class Survey(BaseTeamModel): " Survey link: {survey_link}" ), ) + working_version = models.ForeignKey( + "self", + on_delete=models.CASCADE, + null=True, + blank=True, + related_name="versions", + ) class Meta: ordering = ["name"] @@ -246,7 +257,7 @@ def get_absolute_url(self): @audit_fields(*model_audit_fields.CONSENT_FORM_FIELDS, audit_special_queryset_writes=True) -class ConsentForm(BaseTeamModel): +class ConsentForm(BaseTeamModel, VersionsMixin): """ Custom markdown consent form to be used by experiments. """ @@ -263,6 +274,13 @@ class ConsentForm(BaseTeamModel): default="Respond with '1' if you agree", help_text=("Use this text to tell the user to respond with '1' in order to give their consent"), ) + working_version = models.ForeignKey( + "self", + on_delete=models.CASCADE, + null=True, + blank=True, + related_name="versions", + ) class Meta: ordering = ["name"] @@ -542,10 +560,6 @@ def tools_enabled(self): def event_triggers(self): return [*self.timeout_triggers.all(), *self.static_triggers.all()] - @property - def has_versions(self): - return self.versions.count() > 0 - @property def version_display(self) -> str: if self.is_working_version: @@ -579,28 +593,38 @@ def create_new_version(self): new_version = super().create_new_version(save=False) new_version.public_id = uuid4() new_version.version_number = version_number - if self.source_material: - new_version.source_material = self.source_material.create_new_version() + + self._copy_attr_to_new_version("source_material", new_version) + self._copy_attr_to_new_version("consent_form", new_version) + self._copy_attr_to_new_version("pre_survey", new_version) + self._copy_attr_to_new_version("post_survey", new_version) if new_version.version_number == 1: new_version.is_default_version = True new_version.save() - self.copy_safety_layers_to_new_version(new_version) - self.copy_routes_to_new_version(new_version) - self.copy_static_triggers_to_new_version(new_version) - self.copy_timeout_triggers_to_new_version(new_version) + self._copy_safety_layers_to_new_version(new_version) + self._copy_routes_to_new_version(new_version) + self.copy_trigger_to_new_version(trigger_queryset=self.static_triggers, new_version=new_version) + self.copy_trigger_to_new_version(trigger_queryset=self.timeout_triggers, new_version=new_version) new_version.files.set(self.files.all()) return new_version - def copy_safety_layers_to_new_version(self, new_version: "Experiment"): + def _copy_attr_to_new_version(self, attr_name, new_version: "Experiment"): + """Copies the attribute `attr_name` to the new version by creating a new version of the related record and + linking that to `new_version` + """ + if instance := getattr(self, attr_name): + setattr(new_version, attr_name, instance.create_new_version()) + + def _copy_safety_layers_to_new_version(self, new_version: "Experiment"): duplicated_layers = [] for layer in self.safety_layers.all(): duplicated_layers.append(layer.create_new_version()) new_version.safety_layers.set(duplicated_layers) - def copy_routes_to_new_version(self, new_version: "Experiment"): + def _copy_routes_to_new_version(self, new_version: "Experiment"): """ This copies the experiment routes where this experiment is the parent and sets the new parent to the new version. @@ -608,13 +632,9 @@ def copy_routes_to_new_version(self, new_version: "Experiment"): for route in self.child_links.all(): route.create_new_version(new_version) - def copy_static_triggers_to_new_version(self, new_version: "Experiment"): - for static_trigger in self.static_triggers.all(): - static_trigger.create_new_version(new_experiment=new_version) - - def copy_timeout_triggers_to_new_version(self, new_version: "Experiment"): - for timeout_trigger in self.timeout_triggers.all(): - timeout_trigger.create_new_version(new_experiment=new_version) + def copy_trigger_to_new_version(self, trigger_queryset, new_version): + for trigger in trigger_queryset.all(): + trigger.create_new_version(new_experiment=new_version) @property def is_public(self) -> bool: diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index 139382725..9416a1a56 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -19,6 +19,7 @@ ExperimentSessionFactory, ParticipantFactory, SourceMaterialFactory, + SurveyFactory, SyntheticVoiceFactory, VersionedExperimentFactory, ) @@ -414,6 +415,14 @@ def _setup_original_experiment(self): # Setup Timeout Trigger TimeoutTriggerFactory(experiment=experiment) + + # Surveys + pre_survey = SurveyFactory(team=team) + post_survey = SurveyFactory(team=team) + experiment.pre_survey = pre_survey + experiment.post_survey = post_survey + + experiment.save() return experiment def test_first_version_is_automatically_the_default(self): @@ -450,13 +459,19 @@ def test_create_experiment_version(self): "working_version_id", "version_number", "is_default_version", + "consent_form_id", + "pre_survey_id", + "post_survey_id", ], ) self._assert_safety_layers_are_duplicated(original_experiment, new_version) - self._assert_source_material_is_duplicated(original_experiment, new_version) self._assert_files_are_duplicated(original_experiment, new_version) self._assert_triggers_are_duplicated("static", original_experiment, new_version) self._assert_triggers_are_duplicated("timeout", original_experiment, new_version) + self._assert_attribute_duplicated("source_material", original_experiment, new_version) + self._assert_attribute_duplicated("consent_form", original_experiment, new_version) + self._assert_attribute_duplicated("pre_survey", original_experiment, new_version) + self._assert_attribute_duplicated("post_survey", original_experiment, new_version) another_new_version = original_experiment.create_new_version() original_experiment.refresh_from_db() @@ -504,6 +519,13 @@ def _assert_triggers_are_duplicated(self, trigger_type, original_experiment, new expected_changed_fields=["id", "action_id", "working_version_id", "experiment_id"], ) + def _assert_attribute_duplicated(self, attr_name, original_experiment, new_version): + _compare_models( + original=getattr(original_experiment, attr_name), + new=getattr(new_version, attr_name), + expected_changed_fields=["id", "working_version_id"], + ) + def test_delete_working_experiment_without_versions(self): working_version = ExperimentFactory() working_version.delete() @@ -574,4 +596,7 @@ def _compare_models(original, new, expected_changed_fields: list) -> set: if field_value != new_dict[field_name]: changed_fields.add(field_name) - assert changed_fields.difference(set(expected_changed_fields)) == set() + field_difference = changed_fields.difference(set(expected_changed_fields)) + assert ( + field_difference == set() + ), f"These fields differ between the experiment versions, but should not: {field_difference}" From 391493595e2e665a815201a68481a951b5d28281 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 08:30:20 +0200 Subject: [PATCH 18/46] Add version descriptions (back) --- ...ion.py => 0093_consentform_working_version_and_more.py} | 7 ++++++- apps/experiments/models.py | 7 ++++++- apps/experiments/tables.py | 1 + apps/experiments/tests/test_models.py | 4 +++- apps/experiments/views/experiment.py | 5 +++-- 5 files changed, 19 insertions(+), 5 deletions(-) rename apps/experiments/migrations/{0093_consentform_working_version_survey_working_version.py => 0093_consentform_working_version_and_more.py} (78%) diff --git a/apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py b/apps/experiments/migrations/0093_consentform_working_version_and_more.py similarity index 78% rename from apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py rename to apps/experiments/migrations/0093_consentform_working_version_and_more.py index 6a69e5aee..6791ba570 100644 --- a/apps/experiments/migrations/0093_consentform_working_version_survey_working_version.py +++ b/apps/experiments/migrations/0093_consentform_working_version_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1 on 2024-09-09 14:01 +# Generated by Django 5.1 on 2024-09-10 06:30 import django.db.models.deletion from django.db import migrations, models @@ -16,6 +16,11 @@ class Migration(migrations.Migration): name='working_version', field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='experiments.consentform'), ), + migrations.AddField( + model_name='experiment', + name='version_description', + field=models.TextField(blank=True, default=''), + ), migrations.AddField( model_name='survey', name='working_version', diff --git a/apps/experiments/models.py b/apps/experiments/models.py index a3509d8fe..0a74287e4 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -518,6 +518,10 @@ class Experiment(BaseTeamModel, VersionsMixin): version_number = models.PositiveIntegerField(default=1) is_default_version = models.BooleanField(default=False) is_archived = models.BooleanField(default=False) + version_description = models.TextField( + blank=True, + default="", + ) objects = ExperimentObjectManager() class Meta: @@ -580,7 +584,7 @@ def get_api_url(self): return absolute_url(reverse("api:openai-chat-completions", args=[self.public_id])) @transaction.atomic() - def create_new_version(self): + def create_new_version(self, version_description: str | None = None): """ Creates a copy of an experiment as a new version of the original experiment. """ @@ -591,6 +595,7 @@ def create_new_version(self): # Fetch a new instance so the previous instance reference isn't simply being updated. I am not 100% sure # why simply chaing the pk, id and _state.adding wasn't enough. new_version = super().create_new_version(save=False) + new_version.version_description = version_description new_version.public_id = uuid4() new_version.version_number = version_number diff --git a/apps/experiments/tables.py b/apps/experiments/tables.py index 9761ad61b..411f10e20 100644 --- a/apps/experiments/tables.py +++ b/apps/experiments/tables.py @@ -161,6 +161,7 @@ class Meta: class ExperimentVersionsTable(tables.Table): version_number = columns.Column(verbose_name="Version Number", accessor="version_number") created_at = columns.Column(verbose_name="Created On", accessor="created_at") + version_description = columns.Column(verbose_name="Description", default="") is_default = columns.TemplateColumn( template_code="""{% if record.is_default_version %} ✓ diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index 9416a1a56..a0d83bab7 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -440,7 +440,7 @@ def test_create_experiment_version(self): assert original_experiment.version_number == 1 - new_version = original_experiment.create_new_version() + new_version = original_experiment.create_new_version("tis a new version") original_experiment.refresh_from_db() assert new_version != original_experiment @@ -449,6 +449,7 @@ def test_create_experiment_version(self): assert new_version.version_number == 1 assert new_version.is_default_version is True assert new_version.working_version == original_experiment + assert new_version.version_description == "tis a new version" _compare_models( original=original_experiment, new=new_version, @@ -462,6 +463,7 @@ def test_create_experiment_version(self): "consent_form_id", "pre_survey_id", "post_survey_id", + "version_description", ], ) self._assert_safety_layers_are_duplicated(original_experiment, new_version) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 539ec0d4a..1be7a22f3 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -477,7 +477,8 @@ class DeleteFileFromExperiment(BaseDeleteFileView): class ExperimentVersionForm(forms.ModelForm): class Meta: model = Experiment - fields = ["is_default_version"] + fields = ["version_description", "is_default_version"] + help_texts = {"version_description": "A description of this version, or what changed from the previous version"} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -494,7 +495,7 @@ class CreateExperimentVersion(LoginAndTeamRequiredMixin, CreateView): def form_valid(self, form): working_experiment = self.get_object() - working_experiment.create_new_version() + working_experiment.create_new_version(version_description=form.cleaned_data["version_description"]) return HttpResponseRedirect(self.get_success_url()) def get_success_url(self): From dc9ee4ccd3b7fd29f46afb5e2e83decc571ac337 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 08:51:37 +0200 Subject: [PATCH 19/46] Filter related fields' versions out. Annotate the consentform and Survey object manager queryset --- apps/experiments/models.py | 29 +++++++++++++++++++++++++++- apps/experiments/views/experiment.py | 10 +++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 0a74287e4..495687fb5 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -84,7 +84,18 @@ def get_queryset(self) -> models.QuerySet: class ConsentFormObjectManager(AuditingManager): - pass + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) class SyntheticVoiceObjectManager(AuditingManager): @@ -215,6 +226,21 @@ def get_absolute_url(self): return reverse("experiments:safety_edit", args=[self.team.slug, self.id]) +class SurveyObjectManager(models.Manager): + def get_queryset(self) -> models.QuerySet: + return ( + super() + .get_queryset() + .annotate( + is_version=Case( + When(working_version_id__isnull=False, then=True), + When(working_version_id__isnull=True, then=False), + output_field=BooleanField(), + ) + ) + ) + + class Survey(BaseTeamModel, VersionsMixin): """ A survey. @@ -237,6 +263,7 @@ class Survey(BaseTeamModel, VersionsMixin): blank=True, related_name="versions", ) + objects = SurveyObjectManager() class Meta: ordering = ["name"] diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 1be7a22f3..fac03b488 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -238,11 +238,11 @@ def __init__(self, request, *args, **kwargs): self.fields["voice_provider"].queryset = team.voiceprovider_set.exclude( syntheticvoice__service__in=exclude_services ) - self.fields["safety_layers"].queryset = team.safetylayer_set - self.fields["source_material"].queryset = team.sourcematerial_set - self.fields["pre_survey"].queryset = team.survey_set - self.fields["post_survey"].queryset = team.survey_set - self.fields["consent_form"].queryset = team.consentform_set + self.fields["safety_layers"].queryset = team.safetylayer_set.exclude(is_version=True) + self.fields["source_material"].queryset = team.sourcematerial_set.exclude(is_version=True) + self.fields["pre_survey"].queryset = team.survey_set.exclude(is_version=True) + self.fields["post_survey"].queryset = team.survey_set.exclude(is_version=True) + self.fields["consent_form"].queryset = team.consentform_set.exclude(is_version=True) self.fields["synthetic_voice"].queryset = SyntheticVoice.get_for_team(team, exclude_services) self.fields["trace_provider"].queryset = team.traceprovider_set From af424dc095e9fc369f92026a8868ed77a697cf09 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 08:52:27 +0200 Subject: [PATCH 20/46] Make the new version the default if the new version form specified 'is_default_version' --- apps/experiments/models.py | 10 ++++++++-- apps/experiments/views/experiment.py | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 495687fb5..8bb1e7095 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -611,7 +611,7 @@ def get_api_url(self): return absolute_url(reverse("api:openai-chat-completions", args=[self.public_id])) @transaction.atomic() - def create_new_version(self, version_description: str | None = None): + def create_new_version(self, version_description: str | None = None, make_default: bool = False): """ Creates a copy of an experiment as a new version of the original experiment. """ @@ -631,8 +631,14 @@ def create_new_version(self, version_description: str | None = None): self._copy_attr_to_new_version("pre_survey", new_version) self._copy_attr_to_new_version("post_survey", new_version) - if new_version.version_number == 1: + if new_version.version_number == 1 or make_default: new_version.is_default_version = True + + if make_default: + self.versions.filter(is_default_version=True).update( + is_default_version=False, audit_action=AuditAction.AUDIT + ) + new_version.save() self._copy_safety_layers_to_new_version(new_version) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index fac03b488..168e9464e 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -495,7 +495,9 @@ class CreateExperimentVersion(LoginAndTeamRequiredMixin, CreateView): def form_valid(self, form): working_experiment = self.get_object() - working_experiment.create_new_version(version_description=form.cleaned_data["version_description"]) + description = form.cleaned_data["version_description"] + is_default = form.cleaned_data["is_default_version"] + working_experiment.create_new_version(version_description=description, make_default=is_default) return HttpResponseRedirect(self.get_success_url()) def get_success_url(self): From 6a194667e35918b5c82040bea4b8c4fc908553d8 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 09:11:23 +0200 Subject: [PATCH 21/46] Filter versioned consent forms and surveys out for their list view --- apps/experiments/views/consent.py | 2 +- apps/experiments/views/survey.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/experiments/views/consent.py b/apps/experiments/views/consent.py index 5b3981824..afb9d376c 100644 --- a/apps/experiments/views/consent.py +++ b/apps/experiments/views/consent.py @@ -32,7 +32,7 @@ class ConsentFormTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return ConsentForm.objects.filter(team=self.request.team) + return ConsentForm.objects.filter(team=self.request.team, is_version=False) class CreateConsentForm(CreateView): diff --git a/apps/experiments/views/survey.py b/apps/experiments/views/survey.py index 85f18e2a7..ea9aa589e 100644 --- a/apps/experiments/views/survey.py +++ b/apps/experiments/views/survey.py @@ -33,7 +33,7 @@ class SurveyTableView(SingleTableView): template_name = "table/single_table.html" def get_queryset(self): - return Survey.objects.filter(team=self.request.team) + return Survey.objects.filter(team=self.request.team, is_version=False) class CreateSurvey(CreateView): From 9a6df6de31e77c3930d65efd03652b0563d3fa38 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 09:25:50 +0200 Subject: [PATCH 22/46] Update get_response_for_webchat_task and get_bot in the channels class to use the experiment that passed in or on the class --- apps/chat/bots.py | 1 - apps/chat/channels.py | 10 +++++----- apps/experiments/tasks.py | 7 ++++--- apps/experiments/views/experiment.py | 7 ++++++- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/apps/chat/bots.py b/apps/chat/bots.py index a17242769..830db482e 100644 --- a/apps/chat/bots.py +++ b/apps/chat/bots.py @@ -85,7 +85,6 @@ def __init__(self, session: ExperimentSession, experiment: Experiment | None = N self.trace_service = None if self.experiment.trace_provider: self.trace_service = self.experiment.trace_provider.get_service() - self._initialize() def _initialize(self): diff --git a/apps/chat/channels.py b/apps/chat/channels.py index 7a1e481c4..3bdfb9b84 100644 --- a/apps/chat/channels.py +++ b/apps/chat/channels.py @@ -108,7 +108,7 @@ def __init__( self.experiment_session = experiment_session self.message = None self._user_query = None - self.bot = get_bot(experiment_session) if experiment_session else None + self.bot = get_bot(experiment_session, experiment=experiment) if experiment_session else None @classmethod def start_new_session( @@ -220,7 +220,7 @@ def _add_message(self, message): raise ParticipantNotAllowedException() self._ensure_sessions_exists() - self.bot = get_bot(self.experiment_session) + self.bot = get_bot(self.experiment_session, experiment=self.experiment) def new_user_message(self, message) -> str: """Handles the message coming from the user. Call this to send bot messages to the user. @@ -414,7 +414,7 @@ def _transcribe_audio(self, audio: BytesIO) -> str: return speech_service.transcribe_audio(audio) def _get_bot_response(self, message: str) -> str: - self.bot = self.bot or get_bot(self.experiment_session) + self.bot = self.bot or get_bot(self.experiment_session, experiment=self.experiment) answer = self.bot.process_input(message, attachments=self.message.attachments) return answer @@ -513,7 +513,7 @@ def _inform_user_of_error(self): def _generate_response_for_user(self, prompt: str) -> str: """Generates a response based on the `prompt`.""" - topic_bot = self.bot or get_bot(self.experiment_session) + topic_bot = self.bot or get_bot(self.experiment_session, experiment=self.experiment) return topic_bot.process_input(user_input=prompt, save_input_to_history=False) @@ -553,7 +553,7 @@ def check_and_process_seed_message(cls, session: ExperimentSession, experiment: if seed_message := experiment.seed_message: session.seed_task_id = get_response_for_webchat_task.delay( - session.id, message_text=seed_message, attachments=[] + experiment_session_id=session.id, experiment_id=experiment.id, message_text=seed_message, attachments=[] ).task_id session.save() return session diff --git a/apps/experiments/tasks.py b/apps/experiments/tasks.py index 9cc2124c6..a9e8924f5 100644 --- a/apps/experiments/tasks.py +++ b/apps/experiments/tasks.py @@ -8,7 +8,7 @@ from apps.channels.datamodels import Attachment, BaseMessage from apps.chat.bots import create_conversation from apps.chat.channels import WebChannel -from apps.experiments.models import ExperimentSession, PromptBuilderHistory, SourceMaterial +from apps.experiments.models import Experiment, ExperimentSession, PromptBuilderHistory, SourceMaterial from apps.service_providers.models import LlmProvider from apps.teams.utils import current_team from apps.users.models import CustomUser @@ -17,13 +17,14 @@ @shared_task(bind=True, base=TaskbadgerTask) def get_response_for_webchat_task( - self, experiment_session_id: int, message_text: str, attachments: list | None = None + self, experiment_session_id: int, experiment_id: int, message_text: str, attachments: list | None = None ) -> str: experiment_session = ExperimentSession.objects.select_related("experiment", "experiment__team").get( id=experiment_session_id ) + experiment = Experiment.objects.get(id=experiment_id) web_channel = WebChannel( - experiment_session.experiment_version, + experiment, experiment_session.experiment_channel, experiment_session=experiment_session, ) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 168e9464e..0106da105 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -755,7 +755,12 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess tool_resource.files.add(*created_files) - result = get_response_for_webchat_task.delay(session.id, message_text, attachments=attachments) + result = get_response_for_webchat_task.delay( + experiment_session_id=session.id, + experiment_id=experiment.id, + message_text=message_text, + attachments=attachments, + ) return TemplateResponse( request, "experiments/chat/experiment_response_htmx.html", From ebaac498dec2fe45ee72c80f7782421a3c55b2a1 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 16:25:30 +0200 Subject: [PATCH 23/46] Update dashboard views to assume the experiment_id url param is always the working experiment id --- apps/experiments/decorators.py | 4 +-- apps/experiments/models.py | 2 +- apps/experiments/views/experiment.py | 29 +++++++-------------- templates/experiments/email/invitation.html | 2 +- 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/apps/experiments/decorators.py b/apps/experiments/decorators.py index 5f7d21da5..3dc9134b6 100644 --- a/apps/experiments/decorators.py +++ b/apps/experiments/decorators.py @@ -26,7 +26,7 @@ def decorated_view(request, team_slug: str, experiment_id: str, session_id: str) request.experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) request.experiment_session = get_object_or_404( ExperimentSession, - experiment_id=request.experiment.get_working_version_id(), + experiment_id=experiment_id, external_id=session_id, team=request.team, ) @@ -90,7 +90,7 @@ def _inner(request, *args, **kwargs): def _get_access_cookie_data(experiment_session): return { - "experiment_id": str(experiment_session.working_experiment.public_id), + "experiment_id": str(experiment_session.experiment.public_id), "session_id": str(experiment_session.external_id), "participant_id": experiment_session.participant_id, "user_id": experiment_session.participant.user_id, diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 8bb1e7095..19e29b609 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -734,7 +734,7 @@ def eligible_children(cls, team: Team, parent: Experiment | None = None): .exclude(id__in=child_ids) .exclude(id__in=parent_ids) .exclude(id=parent.id) - .exclude(id__in=parent.versions.all()) + .exclude(working_version_id=parent.id) ) else: eligible_experiments = Experiment.objects.filter(team=team).exclude(id__in=parent_ids) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 0106da105..b5446b4fb 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -698,7 +698,7 @@ def start_authed_web_session(request, team_slug: str, experiment_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) session = WebChannel.start_new_session( - experiment_version=experiment, + experiment_version=experiment.default_version, participant_user=request.user, participant_identifier=request.user.email, timezone=request.session.get("detected_tz", None), @@ -711,9 +711,8 @@ def start_authed_web_session(request, team_slug: str, experiment_id: int): @login_and_team_required def experiment_chat_session(request, team_slug: str, experiment_id: int, session_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) - working_version = experiment.get_working_version() session = get_object_or_404( - ExperimentSession, participant__user=request.user, experiment_id=working_version.id, id=session_id + ExperimentSession, participant__user=request.user, experiment_id=experiment_id, id=session_id ) return TemplateResponse( request, @@ -730,11 +729,8 @@ def experiment_chat_session(request, team_slug: str, experiment_id: int, session def experiment_session_message(request, team_slug: str, experiment_id: int, session_id: int): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) # hack for anonymous user/teams - working_version_id = experiment.get_working_version_id() user = get_real_user_or_none(request.user) - session = get_object_or_404( - ExperimentSession, participant__user=user, experiment_id=working_version_id, id=session_id - ) + session = get_object_or_404(ExperimentSession, participant__user=user, experiment_id=experiment_id, id=session_id) message_text = request.POST["message"] uploaded_files = request.FILES @@ -777,12 +773,9 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess # @login_and_team_required def get_message_response(request, team_slug: str, experiment_id: int, session_id: int, task_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) - working_version_id = experiment.get_working_version_id() # hack for anonymous user/teams user = get_real_user_or_none(request.user) - session = get_object_or_404( - ExperimentSession, participant__user=user, experiment_id=working_version_id, id=session_id - ) + session = get_object_or_404(ExperimentSession, participant__user=user, experiment_id=experiment_id, id=session_id) last_message = ChatMessage.objects.filter(chat=session.chat).order_by("-created_at").first() progress = Progress(AsyncResult(task_id)).get_info() # don't render empty messages @@ -840,6 +833,7 @@ def poll_messages(request, team_slug: str, experiment_id: int, session_id: int): def start_session_public(request, team_slug: str, experiment_id: str): try: experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) + experiment = experiment.default_version except ValidationError: # old links dont have uuids raise Http404 @@ -892,18 +886,17 @@ def start_session_public(request, team_slug: str, experiment_id: str): @permission_required("experiments.invite_participants", raise_exception=True) def experiment_invitations(request, team_slug: str, experiment_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) - working_version = experiment.get_working_version() sessions = experiment.sessions.order_by("-created_at").filter( status__in=["setup", "pending"], participant__isnull=False, ) - form = ExperimentInvitationForm(initial={"experiment_id": working_version.id}) + form = ExperimentInvitationForm(initial={"experiment_id": experiment_id}) if request.method == "POST": post_form = ExperimentInvitationForm(request.POST) if post_form.is_valid(): if ExperimentSession.objects.filter( team=request.team, - experiment=working_version, + experiment_id=experiment_id, status__in=["setup", "pending"], participant__identifier=post_form.cleaned_data["email"], ).exists(): @@ -981,12 +974,10 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex @experiment_session_view(allowed_states=[SessionStatus.SETUP, SessionStatus.PENDING]) def start_session_from_invite(request, team_slug: str, experiment_id: str, session_id: str): - # A session from invite will (for now?) always use the default experiment version experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) - working_version = experiment.get_working_version() + experiment_session = get_object_or_404(ExperimentSession, experiment_id=experiment_id, external_id=session_id) default_version = experiment.default_version - experiment_session = get_object_or_404(ExperimentSession, experiment=working_version, external_id=session_id) - consent = experiment.consent_form + consent = default_version.consent_form initial = { "experiment_id": default_version.id, @@ -1011,7 +1002,7 @@ def start_session_from_invite(request, team_slug: str, experiment_id: str, sessi "experiments/start_experiment_session.html", { "active_tab": "experiments", - "experiment": experiment.default_version, + "experiment": default_version, "consent_notice": mark_safe(consent_notice), "form": form, }, diff --git a/templates/experiments/email/invitation.html b/templates/experiments/email/invitation.html index 37861b91f..a0e583799 100644 --- a/templates/experiments/email/invitation.html +++ b/templates/experiments/email/invitation.html @@ -1,7 +1,7 @@ {% extends 'account/email/email_template_base.html' %} {% load i18n %} {% block message_body %} - You've been invited to participate in a Dimagi ChatBot experiment called {{ session.experiment_version }}. + You've been invited to participate in a Dimagi ChatBot experiment called {{ session.experiment }}. {% endblock %} {% block cta_link %}{{ session.get_invite_url }}{% endblock %} {% block cta_text %}{% translate "Start Experiment" %}{% endblock %} From 90bf8589b361961f0eb3786ae1903fed1200c79f Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 16:31:46 +0200 Subject: [PATCH 24/46] fix spelling --- apps/experiments/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 19e29b609..75db2d179 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -722,7 +722,7 @@ def eligible_children(cls, team: Team, parent: Experiment | None = None): Returns a list of experiments that fit the following criteria: - They are not the same as the parent - they are not parents - - they are ot not children of the current experiment + - they are not not children of the current experiment - they are not part of the current experiment's version family """ parent_ids = cls.objects.filter(team=team).values_list("parent_id", flat=True).distinct() From 970e16f880f11ac14506e6b749b519ab7246e655 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 16:45:24 +0200 Subject: [PATCH 25/46] Small fix --- apps/experiments/decorators.py | 2 +- apps/experiments/views/experiment.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/experiments/decorators.py b/apps/experiments/decorators.py index 3dc9134b6..aa7aef31a 100644 --- a/apps/experiments/decorators.py +++ b/apps/experiments/decorators.py @@ -26,7 +26,7 @@ def decorated_view(request, team_slug: str, experiment_id: str, session_id: str) request.experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) request.experiment_session = get_object_or_404( ExperimentSession, - experiment_id=experiment_id, + experiment=request.experiment, external_id=session_id, team=request.team, ) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index b5446b4fb..a19781c49 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -966,7 +966,7 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex response = HttpResponseRedirect( reverse( redirect_url_name, - args=[team_slug, experiment_session.experiment_version.public_id, experiment_session.external_id], + args=[team_slug, experiment_session.experiment.public_id, experiment_session.external_id], ) ) return set_session_access_cookie(response, experiment_session) @@ -975,7 +975,7 @@ def _record_consent_and_redirect(request, team_slug: str, experiment_session: Ex @experiment_session_view(allowed_states=[SessionStatus.SETUP, SessionStatus.PENDING]) def start_session_from_invite(request, team_slug: str, experiment_id: str, session_id: str): experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) - experiment_session = get_object_or_404(ExperimentSession, experiment_id=experiment_id, external_id=session_id) + experiment_session = get_object_or_404(ExperimentSession, experiment=experiment, external_id=session_id) default_version = experiment.default_version consent = default_version.consent_form From b3e3c2ee1aa1a3fef7f8bf0b067686022395ecc9 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Tue, 10 Sep 2024 16:47:51 +0200 Subject: [PATCH 26/46] Do not show archived experiments in experiments list --- apps/experiments/views/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index a19781c49..6791c4ee9 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -104,7 +104,7 @@ class ExperimentTableView(SingleTableView, PermissionRequiredMixin): permission_required = "experiments.view_experiment" def get_queryset(self): - query_set = Experiment.objects.filter(team=self.request.team, working_version__isnull=True) + query_set = Experiment.objects.filter(team=self.request.team, working_version__isnull=True, is_archived=False) search = self.request.GET.get("search") if search: search_vector = SearchVector("name", weight="A") + SearchVector("description", weight="B") From 4994299a9bda34581c9304ccac7a187733f5b9f0 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Wed, 11 Sep 2024 09:14:55 +0200 Subject: [PATCH 27/46] Filter archived experiments out by default and add an object manager method to retrieve all experiments --- apps/experiments/admin.py | 5 +++++ apps/experiments/models.py | 7 +++++++ apps/experiments/tests/test_models.py | 11 +++++++++++ 3 files changed, 23 insertions(+) diff --git a/apps/experiments/admin.py b/apps/experiments/admin.py index 0487f1c6e..a04fe9604 100644 --- a/apps/experiments/admin.py +++ b/apps/experiments/admin.py @@ -1,4 +1,6 @@ from django.contrib import admin +from django.db.models.query import QuerySet +from django.http import HttpRequest from apps.experiments import models @@ -82,6 +84,9 @@ class ExperimentAdmin(admin.ModelAdmin): exclude = ["safety_layers"] readonly_fields = ("public_id",) + def get_queryset(self, request: HttpRequest) -> QuerySet: + return models.Experiment.objects.get_all() + @admin.display(description="Version Family") def version_family(self, obj): if obj.working_version: diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 75db2d179..8a603e1bb 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -52,6 +52,13 @@ def working_versions_queryset(self): """Returns a queryset for all working experiments""" return self.get_queryset().filter(working_version=None) + def get_queryset(self): + return super().get_queryset().filter(is_archived=False) + + def get_all(self): + """A method to return all experiments whether it is deprecated or not""" + return super().get_queryset() + class SourceMaterialObjectManager(AuditingManager): def get_queryset(self) -> models.QuerySet: diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index a0d83bab7..42bfdd228 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -579,6 +579,17 @@ def test_working_versions_queryset(self): # All experiments in this queryset should have versions assert working_version.has_versions is True + def test_archived_experiments_are_filtered_out(self): + """Default queries should exclude archived experiments""" + experiment = ExperimentFactory() + experiment.create_new_version() + assert Experiment.objects.count() == 2 + experiment.delete() + assert Experiment.objects.count() == 0 + + # To get all experiment,s use the dedicated object method + assert Experiment.objects.get_all().count() == 2 + def _compare_models(original, new, expected_changed_fields: list) -> set: """ From 57819394fcf7a784249793d49c41b9877d844fac Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Wed, 11 Sep 2024 09:25:24 +0200 Subject: [PATCH 28/46] mark experiment channel as deleted when we archive an experiment with it linked --- apps/experiments/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 8a603e1bb..36115374a 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -698,6 +698,7 @@ def delete(self, *args, **kwargs) -> tuple: - If this experiment is the working version and has versions, archive all versions and this one - If this experiment is the working version and does not have versions, delete it """ + self.experimentchannel_set.update(deleted=True, audit_action=AuditAction.AUDIT) if self.is_working_version and not self.has_versions: return super().delete(*args, **kwargs) From 4e369ef8e171f45bb86bb8aee1de47ba3dcd58d5 Mon Sep 17 00:00:00 2001 From: Chris Smit Date: Wed, 11 Sep 2024 10:46:21 +0200 Subject: [PATCH 29/46] Refactor: Extract version specific variables to the view to centralize version decisions --- apps/experiments/email.py | 4 +- apps/experiments/models.py | 8 +- apps/experiments/views/experiment.py | 81 +++++++++++++------ templates/experiments/chat/chat_ui.html | 2 +- templates/experiments/chat/input_bar.html | 10 +-- templates/experiments/email/invitation.html | 2 +- .../experiments/experiment_invitations.html | 4 +- templates/experiments/experiment_review.html | 2 +- templates/experiments/pre_survey.html | 6 +- .../experiments/start_experiment_session.html | 4 +- 10 files changed, 79 insertions(+), 44 deletions(-) diff --git a/apps/experiments/email.py b/apps/experiments/email.py index 8322e7cb6..644d4cb73 100644 --- a/apps/experiments/email.py +++ b/apps/experiments/email.py @@ -10,11 +10,13 @@ def send_experiment_invitation(experiment_session: ExperimentSession): if not experiment_session.participant: raise Exception("Session has no participant!") + experiment_version_name = experiment_session.experiment_version.name email_context = { "session": experiment_session, + "experiment_name": experiment_version_name, } send_mail( - subject=_("You're invited to {}!").format(experiment_session.experiment_version.name), + subject=_("You're invited to {}!").format(experiment_version_name), message=render_to_string("experiments/email/invitation.txt", context=email_context), from_email=settings.DEFAULT_FROM_EMAIL, recipient_list=[experiment_session.participant.email], diff --git a/apps/experiments/models.py b/apps/experiments/models.py index 36115374a..90a215d8a 100644 --- a/apps/experiments/models.py +++ b/apps/experiments/models.py @@ -1038,11 +1038,11 @@ def user_already_engaged(self) -> bool: def get_platform_name(self) -> str: return self.experiment_channel.get_platform_display() - def get_pre_survey_link(self): - return self.experiment_version.pre_survey.get_link(self.participant, self) + def get_pre_survey_link(self, experiment_version: Experiment): + return experiment_version.pre_survey.get_link(self.participant, self) - def get_post_survey_link(self): - return self.experiment_version.post_survey.get_link(self.participant, self) + def get_post_survey_link(self, experiment_version: Experiment): + return experiment_version.post_survey.get_link(self.participant, self) def is_stale(self) -> bool: """A Channel Session is considered stale if the experiment that the channel points to differs from the diff --git a/apps/experiments/views/experiment.py b/apps/experiments/views/experiment.py index 6791c4ee9..52c5994f5 100644 --- a/apps/experiments/views/experiment.py +++ b/apps/experiments/views/experiment.py @@ -714,14 +714,15 @@ def experiment_chat_session(request, team_slug: str, experiment_id: int, session session = get_object_or_404( ExperimentSession, participant__user=request.user, experiment_id=experiment_id, id=session_id ) + experiment_version = experiment.default_version + version_specific_vars = { + "assistant": experiment_version.assistant, + "experiment_name": experiment_version.name, + } return TemplateResponse( request, "experiments/experiment_chat.html", - { - "experiment": experiment, - "session": session, - "active_tab": "experiments", - }, + {"experiment": experiment, "session": session, "active_tab": "experiments", **version_specific_vars}, ) @@ -751,12 +752,17 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess tool_resource.files.add(*created_files) + experiment_version = experiment.default_version result = get_response_for_webchat_task.delay( experiment_session_id=session.id, - experiment_id=experiment.id, + experiment_id=experiment_version.id, message_text=message_text, attachments=attachments, ) + + version_specific_vars = { + "assistant": experiment_version.assistant, + } return TemplateResponse( request, "experiments/chat/experiment_response_htmx.html", @@ -766,6 +772,7 @@ def experiment_session_message(request, team_slug: str, experiment_id: int, sess "message_text": message_text, "task_id": result.task_id, "created_files": created_files, + **version_specific_vars, }, ) @@ -833,15 +840,15 @@ def poll_messages(request, team_slug: str, experiment_id: int, session_id: int): def start_session_public(request, team_slug: str, experiment_id: str): try: experiment = get_object_or_404(Experiment, public_id=experiment_id, team=request.team) - experiment = experiment.default_version + experiment_version = experiment.default_version except ValidationError: # old links dont have uuids raise Http404 - if not experiment.is_public: + if not experiment_version.is_public: raise Http404 - consent = experiment.consent_form + consent = experiment_version.consent_form user = get_real_user_or_none(request.user) if request.method == "POST": form = ConsentForm(consent, request.POST, initial={"identifier": user.email if user else None}) @@ -853,7 +860,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): identifier = user.email if user else str(uuid.uuid4()) session = WebChannel.start_new_session( - experiment_version=experiment, + experiment_version=experiment_version, participant_user=user, participant_identifier=identifier, timezone=request.session.get("detected_tz", None), @@ -864,12 +871,16 @@ def start_session_public(request, team_slug: str, experiment_id: str): form = ConsentForm( consent, initial={ - "experiment_id": experiment.id, + "experiment_id": experiment_version.id, "identifier": user.email if user else None, }, ) consent_notice = consent.get_rendered_content() + version_specific_vars = { + "experiment_name": experiment_version.name, + "experiment_description": experiment_version.description, + } return TemplateResponse( request, "experiments/start_experiment_session.html", @@ -878,6 +889,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): "experiment": experiment, "consent_notice": mark_safe(consent_notice), "form": form, + **version_specific_vars, }, ) @@ -886,6 +898,7 @@ def start_session_public(request, team_slug: str, experiment_id: str): @permission_required("experiments.invite_participants", raise_exception=True) def experiment_invitations(request, team_slug: str, experiment_id: str): experiment = get_object_or_404(Experiment, id=experiment_id, team=request.team) + experiment_version = experiment.default_version sessions = experiment.sessions.order_by("-created_at").filter( status__in=["setup", "pending"], participant__isnull=False, @@ -905,7 +918,7 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: with transaction.atomic(): session = WebChannel.start_new_session( - experiment_version=experiment.default_version, + experiment_version=experiment_version, participant_identifier=post_form.cleaned_data["email"], session_status=SessionStatus.SETUP, timezone=request.session.get("detected_tz", None), @@ -915,14 +928,14 @@ def experiment_invitations(request, team_slug: str, experiment_id: str): else: form = post_form + version_specific_vars = { + "experiment_name": experiment_version.name, + "experiment_description": experiment_version.description, + } return TemplateResponse( request, "experiments/experiment_invitations.html", - { - "invitation_form": form, - "experiment": experiment, - "sessions": sessions, - }, + {"invitation_form": form, "experiment": experiment, "sessions": sessions, **version_specific_vars}, ) @@ -997,6 +1010,10 @@ def start_session_from_invite(request, team_slug: str, experiment_id: str, sessi form = ConsentForm(consent, initial=initial) consent_notice = consent.get_rendered_content() + version_specific_vars = { + "experiment_name": default_version.name, + "experiment_description": default_version.description, + } return TemplateResponse( request, "experiments/start_experiment_session.html", @@ -1005,6 +1022,7 @@ def start_session_from_invite(request, team_slug: str, experiment_id: str, sessi "experiment": default_version, "consent_notice": mark_safe(consent_notice), "form": form, + **version_specific_vars, }, ) @@ -1025,6 +1043,14 @@ def experiment_pre_survey(request, team_slug: str, experiment_id: str, session_i ) else: form = SurveyCompletedForm() + + default_version = request.experiment.default_version + experiment_session = request.experiment_session + version_specific_vars = { + "experiment_name": default_version.name, + "experiment_description": default_version.description, + "pre_survey_link": experiment_session.get_pre_survey_link(default_version), + } return TemplateResponse( request, "experiments/pre_survey.html", @@ -1032,7 +1058,8 @@ def experiment_pre_survey(request, team_slug: str, experiment_id: str, session_i "active_tab": "experiments", "form": form, "experiment": request.experiment, - "experiment_session": request.experiment_session, + "experiment_session": experiment_session, + **version_specific_vars, }, ) @@ -1067,6 +1094,7 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s form = None survey_link = None survey_text = None + experiment_version = request.experiment.default_version if request.method == "POST": # no validation needed request.experiment_session.status = SessionStatus.COMPLETE @@ -1075,11 +1103,17 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s return HttpResponseRedirect( reverse("experiments:experiment_complete", args=[team_slug, experiment_id, session_id]) ) - elif request.experiment.post_survey: + elif experiment_version.post_survey: form = SurveyCompletedForm() - survey_link = request.experiment_session.get_post_survey_link() - survey_text = request.experiment.post_survey.confirmation_text.format(survey_link=survey_link) - + survey_link = request.experiment_session.get_post_survey_link(experiment_version) + survey_text = experiment_version.post_survey.confirmation_text.format(survey_link=survey_link) + + version_specific_vars = { + "experiment.post_survey": experiment_version.post_survey, + "survey_link": survey_link, + "survey_text": survey_text, + "experiment_name": experiment_version.name, + } return TemplateResponse( request, "experiments/experiment_review.html", @@ -1087,10 +1121,9 @@ def experiment_review(request, team_slug: str, experiment_id: str, session_id: s "experiment": request.experiment, "experiment_session": request.experiment_session, "active_tab": "experiments", - "survey_link": survey_link, - "survey_text": survey_text, "form": form, "available_tags": [t.name for t in Tag.objects.filter(team__slug=team_slug, is_system_tag=False).all()], + **version_specific_vars, }, ) diff --git a/templates/experiments/chat/chat_ui.html b/templates/experiments/chat/chat_ui.html index be5d6b734..5042e7245 100644 --- a/templates/experiments/chat/chat_ui.html +++ b/templates/experiments/chat/chat_ui.html @@ -8,7 +8,7 @@
{% include "experiments/chat/components/system_icon.html" %}
-

Hello, you can ask me anything you want about {{ experiment.name }}.

+

Hello, you can ask me anything you want about {{ experiment_name }}.

{% endif %} diff --git a/templates/experiments/chat/input_bar.html b/templates/experiments/chat/input_bar.html index e87be5069..a09e02dd1 100644 --- a/templates/experiments/chat/input_bar.html +++ b/templates/experiments/chat/input_bar.html @@ -13,12 +13,12 @@ hx-post="{% url 'experiments:experiment_session_message' request.team.slug experiment.id session.id %}" hx-swap="outerHTML" hx-indicator="#message-submit" - {% if session.experiment.assistant %}enctype="multipart/form-data"{% endif %} + {% if experiment.assistant %}enctype="multipart/form-data"{% endif %} > {% csrf_token %}
- {% if session.experiment.assistant %} + {% if assistant %}
-{% if session.experiment.assistant %} +{% if assistant %}