diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index cc46a82ae..c5d331153 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2,10 +2,11 @@ # Licensed under the MIT license. import abc +import copy from pathlib import Path -from typing import MutableSequence, Optional, Sequence -from uuid import uuid4 +from typing import MutableSequence, Sequence +import uuid from pyrit.common.path import RESULTS_PATH from pyrit.models import ( @@ -170,13 +171,7 @@ def get_prompt_ids_by_orchestrator(self, *, orchestrator_id: str) -> list[str]: return prompt_ids - def duplicate_conversation_for_new_orchestrator( - self, - *, - new_orchestrator_id: str, - conversation_id: str, - new_conversation_id: Optional[str] = None, - ) -> None: + def duplicate_conversation_for_new_orchestrator(self, *, new_orchestrator_id: str, conversation_id: str) -> str: """ Duplicates a conversation from one orchestrator to another. @@ -187,21 +182,21 @@ def duplicate_conversation_for_new_orchestrator( Args: new_orchestrator_id (str): The new orchestrator ID to assign to the duplicated conversations. conversation_id (str): The conversation ID with existing conversations. - new_conversation_id (str): The new conversation ID to assign to the duplicated conversations. - If no new_conversation_id is provided, a new one will be generated. + Returns: + The uuid for the new conversation. """ - new_conversation_id = new_conversation_id or str(uuid4()) - if conversation_id == new_conversation_id: - raise ValueError("The new conversation ID must be different from the existing conversation ID.") - prompt_pieces = self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id) + new_conversation_id = str(uuid.uuid4()) + # Deep copy objects to prevent any mutability-related issues that could arise due to in-memory databases. + prompt_pieces = copy.deepcopy(self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id)) for piece in prompt_pieces: - piece.id = uuid4() + piece.id = uuid.uuid4() if piece.orchestrator_identifier["id"] == new_orchestrator_id: raise ValueError("The new orchestrator ID must be different from the existing orchestrator ID.") piece.orchestrator_identifier["id"] = new_orchestrator_id piece.conversation_id = new_conversation_id self.add_request_pieces_to_memory(request_pieces=prompt_pieces) + return new_conversation_id def export_conversation_by_orchestrator_id( self, *, orchestrator_id: str, file_path: Path = None, export_type: str = "json" diff --git a/tests/memory/test_memory_interface.py b/tests/memory/test_memory_interface.py index 212c41e65..4b4cfcef0 100644 --- a/tests/memory/test_memory_interface.py +++ b/tests/memory/test_memory_interface.py @@ -62,9 +62,7 @@ def test_add_request_pieces_to_memory( assert len(memory.get_all_prompt_pieces()) == num_conversations -@pytest.mark.parametrize("new_conversation_id1", [None, "12345"]) -@pytest.mark.parametrize("new_conversation_id2", [None, "23456"]) -def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | None, new_conversation_id2: str | None): +def test_duplicate_memory(memory: MemoryInterface): orchestrator1 = Orchestrator() orchestrator2 = Orchestrator() conversation_id_1 = "11111" @@ -114,15 +112,13 @@ def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | N memory.add_request_pieces_to_memory(request_pieces=pieces) assert len(memory.get_all_prompt_pieces()) == 5 orchestrator3 = Orchestrator() - memory.duplicate_conversation_for_new_orchestrator( + new_conversation_id1 = memory.duplicate_conversation_for_new_orchestrator( new_orchestrator_id=orchestrator3.get_identifier()["id"], conversation_id=conversation_id_1, - new_conversation_id=new_conversation_id1, ) - memory.duplicate_conversation_for_new_orchestrator( + new_conversation_id2 = memory.duplicate_conversation_for_new_orchestrator( new_orchestrator_id=orchestrator3.get_identifier()["id"], conversation_id=conversation_id_2, - new_conversation_id=new_conversation_id2, ) all_pieces = memory.get_all_prompt_pieces() assert len(all_pieces) == 9 @@ -138,40 +134,15 @@ def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | N assert len([p for p in all_pieces if p.conversation_id == new_conversation_id2]) == 2 -def test_duplicate_memory_conversation_id_collision(memory: MemoryInterface): - orchestrator1 = Orchestrator() - orchestrator2 = Orchestrator() - conversation_id_1 = "11111" - pieces = [ - PromptRequestPiece( - role="user", - original_value="original prompt text", - converted_value="Hello, how are you?", - conversation_id=conversation_id_1, - sequence=0, - orchestrator_identifier=orchestrator1.get_identifier(), - ), - ] - memory.add_request_pieces_to_memory(request_pieces=pieces) - assert len(memory.get_all_prompt_pieces()) == 1 - with pytest.raises(ValueError): - memory.duplicate_conversation_for_new_orchestrator( - new_orchestrator_id=str(orchestrator2.get_identifier()["id"]), - conversation_id=conversation_id_1, - new_conversation_id=conversation_id_1, - ) - - def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface): orchestrator1 = Orchestrator() - conversation_id_1 = "11111" - conversation_id_2 = "22222" + conversation_id = "11111" pieces = [ PromptRequestPiece( role="user", original_value="original prompt text", converted_value="Hello, how are you?", - conversation_id=conversation_id_1, + conversation_id=conversation_id, sequence=0, orchestrator_identifier=orchestrator1.get_identifier(), ), @@ -181,8 +152,7 @@ def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface): with pytest.raises(ValueError): memory.duplicate_conversation_for_new_orchestrator( new_orchestrator_id=str(orchestrator1.get_identifier()["id"]), - conversation_id=conversation_id_1, - new_conversation_id=conversation_id_2, + conversation_id=conversation_id, )