Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT return ID in conversation duplication code #296

Merged
merged 6 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# Licensed under the MIT license.

import abc
import copy
from pathlib import Path

from typing import MutableSequence, Optional, Sequence
from uuid import uuid4
from typing import MutableSequence, Sequence
import uuid

from pyrit.common.path import RESULTS_PATH
from pyrit.models import (
Expand Down Expand Up @@ -170,13 +171,7 @@ def get_prompt_ids_by_orchestrator(self, *, orchestrator_id: str) -> list[str]:

return prompt_ids

def duplicate_conversation_for_new_orchestrator(
self,
*,
new_orchestrator_id: str,
conversation_id: str,
new_conversation_id: Optional[str] = None,
) -> None:
def duplicate_conversation_for_new_orchestrator(self, *, new_orchestrator_id: str, conversation_id: str) -> str:
"""
Duplicates a conversation from one orchestrator to another.

Expand All @@ -187,21 +182,21 @@ def duplicate_conversation_for_new_orchestrator(
Args:
new_orchestrator_id (str): The new orchestrator ID to assign to the duplicated conversations.
conversation_id (str): The conversation ID with existing conversations.
new_conversation_id (str): The new conversation ID to assign to the duplicated conversations.
If no new_conversation_id is provided, a new one will be generated.
Returns:
The uuid for the new conversation.
"""
new_conversation_id = new_conversation_id or str(uuid4())
if conversation_id == new_conversation_id:
raise ValueError("The new conversation ID must be different from the existing conversation ID.")
prompt_pieces = self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id)
new_conversation_id = str(uuid.uuid4())
# Deep copy objects to prevent any mutability-related issues that could arise due to in-memory databases.
prompt_pieces = copy.deepcopy(self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id))
for piece in prompt_pieces:
piece.id = uuid4()
piece.id = uuid.uuid4()
if piece.orchestrator_identifier["id"] == new_orchestrator_id:
raise ValueError("The new orchestrator ID must be different from the existing orchestrator ID.")
piece.orchestrator_identifier["id"] = new_orchestrator_id
piece.conversation_id = new_conversation_id

self.add_request_pieces_to_memory(request_pieces=prompt_pieces)
return new_conversation_id

def export_conversation_by_orchestrator_id(
self, *, orchestrator_id: str, file_path: Path = None, export_type: str = "json"
Expand Down
42 changes: 6 additions & 36 deletions tests/memory/test_memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def test_add_request_pieces_to_memory(
assert len(memory.get_all_prompt_pieces()) == num_conversations


@pytest.mark.parametrize("new_conversation_id1", [None, "12345"])
@pytest.mark.parametrize("new_conversation_id2", [None, "23456"])
def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | None, new_conversation_id2: str | None):
def test_duplicate_memory(memory: MemoryInterface):
orchestrator1 = Orchestrator()
orchestrator2 = Orchestrator()
conversation_id_1 = "11111"
Expand Down Expand Up @@ -114,15 +112,13 @@ def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | N
memory.add_request_pieces_to_memory(request_pieces=pieces)
assert len(memory.get_all_prompt_pieces()) == 5
orchestrator3 = Orchestrator()
memory.duplicate_conversation_for_new_orchestrator(
new_conversation_id1 = memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=orchestrator3.get_identifier()["id"],
conversation_id=conversation_id_1,
new_conversation_id=new_conversation_id1,
)
memory.duplicate_conversation_for_new_orchestrator(
new_conversation_id2 = memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=orchestrator3.get_identifier()["id"],
conversation_id=conversation_id_2,
new_conversation_id=new_conversation_id2,
)
all_pieces = memory.get_all_prompt_pieces()
assert len(all_pieces) == 9
Expand All @@ -138,40 +134,15 @@ def test_duplicate_memory(memory: MemoryInterface, new_conversation_id1: str | N
assert len([p for p in all_pieces if p.conversation_id == new_conversation_id2]) == 2


def test_duplicate_memory_conversation_id_collision(memory: MemoryInterface):
orchestrator1 = Orchestrator()
orchestrator2 = Orchestrator()
conversation_id_1 = "11111"
pieces = [
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_1,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
]
memory.add_request_pieces_to_memory(request_pieces=pieces)
assert len(memory.get_all_prompt_pieces()) == 1
with pytest.raises(ValueError):
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=str(orchestrator2.get_identifier()["id"]),
conversation_id=conversation_id_1,
new_conversation_id=conversation_id_1,
)


def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface):
orchestrator1 = Orchestrator()
conversation_id_1 = "11111"
conversation_id_2 = "22222"
conversation_id = "11111"
pieces = [
PromptRequestPiece(
role="user",
original_value="original prompt text",
converted_value="Hello, how are you?",
conversation_id=conversation_id_1,
conversation_id=conversation_id,
sequence=0,
orchestrator_identifier=orchestrator1.get_identifier(),
),
Expand All @@ -181,8 +152,7 @@ def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface):
with pytest.raises(ValueError):
memory.duplicate_conversation_for_new_orchestrator(
new_orchestrator_id=str(orchestrator1.get_identifier()["id"]),
conversation_id=conversation_id_1,
new_conversation_id=conversation_id_2,
conversation_id=conversation_id,
)


Expand Down
Loading