diff --git a/apps/channels/tests/test_web_channel.py b/apps/channels/tests/test_web_channel.py index 1c2aaad95..9a6261461 100644 --- a/apps/channels/tests/test_web_channel.py +++ b/apps/channels/tests/test_web_channel.py @@ -4,6 +4,7 @@ from apps.channels.models import ChannelPlatform from apps.chat.channels import WebChannel +from apps.chat.models import Chat @pytest.mark.django_db() @@ -51,3 +52,15 @@ def test_start_new_session_uses_default_version(self, check_and_process_seed_mes _session_used, experiment_used = check_and_process_seed_message.call_args[0] assert experiment_used == new_version assert session.experiment == experiment + assert session.chat.metadata.get(Chat.MetadataKeys.EXPERIMENT_VERSION) == "default" + + @patch("apps.events.tasks.enqueue_static_triggers", Mock()) + @patch("apps.chat.channels.WebChannel.check_and_process_seed_message") + def test_start_new_session_uses_specified_version(self, check_and_process_seed_message, experiment): + new_version = experiment.create_new_version() + session = WebChannel.start_new_session(experiment, "jack@titanic.com", version=1) + + _session_used, experiment_used = check_and_process_seed_message.call_args[0] + assert experiment_used == new_version + assert session.experiment == experiment + assert session.chat.metadata.get(Chat.MetadataKeys.EXPERIMENT_VERSION) == 1 diff --git a/apps/experiments/tests/test_models.py b/apps/experiments/tests/test_models.py index 623c021b1..a257873d3 100644 --- a/apps/experiments/tests/test_models.py +++ b/apps/experiments/tests/test_models.py @@ -6,6 +6,7 @@ from django.utils import timezone from freezegun import freeze_time +from apps.chat.models import Chat from apps.events.actions import ScheduleTriggerAction from apps.events.models import EventActionType, ScheduledMessage, TimePeriod from apps.experiments.models import Experiment, ExperimentRoute, ParticipantData, SafetyLayer, SyntheticVoice @@ -98,6 +99,7 @@ def test_get_for_team_do_not_include_other_team_exclusive_voices(self): assert voice1 not in voices_queryset +@pytest.mark.django_db() class TestExperimentSession: def _construct_event_action(self, time_period: TimePeriod, experiment_id: int, frequency=1, repetitions=1) -> tuple: params = self._get_params(experiment_id, time_period, frequency, repetitions) @@ -113,7 +115,6 @@ def _get_params(self, experiment_id: int, time_period: TimePeriod = TimePeriod.D "experiment_id": experiment_id, } - @pytest.mark.django_db() @freeze_time("2024-01-01") def test_get_participant_scheduled_messages_custom_params(self): session = ExperimentSessionFactory() @@ -168,7 +169,6 @@ def _make_expected_dict(external_id): ] assert participant.get_schedules_for_experiment(experiment, as_dict=True) == expected_dict_version - @pytest.mark.django_db() @pytest.mark.parametrize( ("repetitions", "total_triggers", "expected_triggers_remaining"), [ @@ -202,7 +202,6 @@ def test_get_schedules_for_experiment_as_dict(self, repetitions, total_triggers, assert schedule["total_triggers"] == total_triggers assert schedule["triggers_remaining"] == expected_triggers_remaining - @pytest.mark.django_db() @freeze_time("2024-01-01") @pytest.mark.parametrize( ("time_period", "repetitions", "total_triggers", "expected"), @@ -281,7 +280,6 @@ def test_get_schedules_for_experiment_as_string(self, time_period, repetitions, next_trigger = "Next trigger is at Monday, 01 January 2024 00:00:00 UTC." assert schedule == expected.format(message=message, next_trigger=next_trigger) - @pytest.mark.django_db() def test_get_participant_scheduled_messages_includes_child_experiments(self): session = ExperimentSessionFactory() team = session.team @@ -297,7 +295,6 @@ def test_get_participant_scheduled_messages_includes_child_experiments(self): assert len(participant.get_schedules_for_experiment(session2.experiment)) == 1 assert len(participant.get_schedules_for_experiment(session.experiment)) == 2 - @pytest.mark.django_db() @pytest.mark.parametrize("use_custom_experiment", [False, True]) def test_scheduled_message_experiment(self, use_custom_experiment): """ScheduledMessages should use the experiment specified in the linked action's params""" @@ -343,7 +340,6 @@ def test_should_mark_complete(self, repetitions, total_triggers, end_date, expec ) assert scheduled_message._should_mark_complete() == expected - @pytest.mark.django_db() def test_get_participant_data_name(self): participant = ParticipantFactory() session = ExperimentSessionFactory(participant=participant, team=participant.team) @@ -369,7 +365,6 @@ def test_get_participant_data_name(self): "first_name": "Jimmy", } - @pytest.mark.django_db() @freeze_time("2022-01-01 08:00:00") @pytest.mark.parametrize("use_participant_tz", [False, True]) def test_get_participant_data_timezone(self, use_participant_tz): @@ -400,7 +395,6 @@ def test_get_participant_data_timezone(self, use_participant_tz): participant_data.pop("scheduled_messages") assert participant_data == expected_data - @pytest.mark.django_db() @pytest.mark.parametrize("fail_silently", [True, False]) @patch("apps.chat.channels.ChannelBase.from_experiment_session") @patch("apps.chat.bots.TopicBot.process_input") @@ -425,6 +419,17 @@ def _test(): else: _test() + @pytest.mark.parametrize( + ("chat_metadata_version", "expected_display_val"), + [ + ("default", "Default version"), + ("1", "v1"), + ], + ) + def test_experiment_version_for_display(self, chat_metadata_version, expected_display_val, experiment_session): + experiment_session.chat.set_metadata(Chat.MetadataKeys.EXPERIMENT_VERSION, chat_metadata_version) + assert experiment_session.experiment_version_for_display == expected_display_val + class TestParticipant: @pytest.mark.django_db()