Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT generalize XPIA orchestrator #163

Merged
merged 15 commits into from
Apr 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
985 changes: 491 additions & 494 deletions doc/demo/5_xpia.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions doc/demo/5_xpia.py
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@
from pyrit.score import SubStringScorer
from pyrit.orchestrator import XPIATestOrchestrator

abs_prompt_target = AzureBlobStorageTarget(
abs_target = AzureBlobStorageTarget(
container_url=os.environ.get("AZURE_STORAGE_ACCOUNT_CONTAINER_URL"),
sas_token=os.environ.get("AZURE_STORAGE_ACCOUNT_SAS_TOKEN"),
)
@@ -88,12 +88,12 @@
attack_content=jailbreak_prompt,
processing_prompt=processing_prompt_template,
processing_target=processing_target,
prompt_target=abs_prompt_target,
attack_setup_target=abs_target,
scorer=scorer,
verbose=True,
)

score = xpia_orchestrator.process()
score = xpia_orchestrator.execute()
print(score)

# clean up storage container
2 changes: 1 addition & 1 deletion pyrit/memory/memory_embedding.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
# Licensed under the MIT license.

import os
from pyrit.embedding.azure_text_embedding import AzureTextEmbedding
from pyrit.embedding import AzureTextEmbedding
from pyrit.interfaces import EmbeddingSupport
from pyrit.memory.memory_models import EmbeddingData, PromptRequestPiece

2 changes: 1 addition & 1 deletion pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.models import ChatMessage
from pyrit.common.path import RESULTS_PATH
from pyrit.models.prompt_request_response import group_conversation_request_pieces_by_sequence
from pyrit.models import group_conversation_request_pieces_by_sequence


class MemoryInterface(abc.ABC):
11 changes: 9 additions & 2 deletions pyrit/models/__init__.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,15 @@
from pyrit.models.models import * # noqa: F403, F401

from pyrit.models.prompt_request_piece import PromptRequestPiece, PromptResponseError, PromptDataType
from pyrit.models.prompt_request_response import PromptRequestResponse
from pyrit.models.prompt_request_response import PromptRequestResponse, group_conversation_request_pieces_by_sequence
from pyrit.models.identifiers import Identifier


__all__ = ["PromptRequestPiece", "PromptResponseError", "PromptDataType", "PromptRequestResponse", "Identifier"]
__all__ = [
"PromptRequestPiece",
"PromptResponseError",
"PromptDataType",
"PromptRequestResponse",
"Identifier",
"group_conversation_request_pieces_by_sequence",
]
8 changes: 7 additions & 1 deletion pyrit/orchestrator/__init__.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,11 @@
from pyrit.orchestrator.red_teaming_orchestrator import RedTeamingOrchestrator
from pyrit.orchestrator.end_token_red_teaming_orchestrator import EndTokenRedTeamingOrchestrator
from pyrit.orchestrator.scoring_red_teaming_orchestrator import ScoringRedTeamingOrchestrator
from pyrit.orchestrator.xpia_orchestrator import XPIATestOrchestrator
from pyrit.orchestrator.xpia_orchestrator import (
XPIATestOrchestrator,
XPIAOrchestrator,
XPIAManualProcessingOrchestrator,
)

__all__ = [
"Orchestrator",
@@ -15,4 +19,6 @@
"EndTokenRedTeamingOrchestrator",
"ScoringRedTeamingOrchestrator",
"XPIATestOrchestrator",
"XPIAOrchestrator",
"XPIAManualProcessingOrchestrator",
]
2 changes: 1 addition & 1 deletion pyrit/orchestrator/orchestrator_class.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from pyrit.memory import MemoryInterface, DuckDBMemory
from pyrit.models import PromptDataType, Identifier
from pyrit.prompt_converter import PromptConverter
from pyrit.prompt_normalizer.normalizer_request import NormalizerRequest, NormalizerRequestPiece
from pyrit.prompt_normalizer import NormalizerRequest, NormalizerRequestPiece

logger = logging.getLogger(__name__)

3 changes: 1 addition & 2 deletions pyrit/orchestrator/red_teaming_orchestrator.py
Original file line number Diff line number Diff line change
@@ -9,8 +9,7 @@
from pyrit.memory import MemoryInterface
from pyrit.models import AttackStrategy, ChatMessage
from pyrit.orchestrator import Orchestrator
from pyrit.prompt_normalizer import NormalizerRequestPiece, PromptNormalizer
from pyrit.prompt_normalizer.normalizer_request import NormalizerRequest
from pyrit.prompt_normalizer import NormalizerRequestPiece, PromptNormalizer, NormalizerRequest
from pyrit.prompt_target import PromptTarget, PromptChatTarget
from pyrit.prompt_converter import PromptConverter

168 changes: 142 additions & 26 deletions pyrit/orchestrator/xpia_orchestrator.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
# Licensed under the MIT license.

import logging
from typing import Optional
from typing import Callable, Optional, Union
from uuid import uuid4
from pyrit.score import SupportTextClassification

@@ -16,59 +16,65 @@
logger = logging.getLogger(__name__)


class XPIATestOrchestrator(Orchestrator):
class XPIAOrchestrator(Orchestrator):
_memory: MemoryInterface

def __init__(
self,
*,
attack_content: str,
processing_prompt: str,
processing_target: PromptTarget,
prompt_target: PromptTarget,
scorer: SupportTextClassification,
attack_setup_target: PromptTarget,
processing_callback: Callable[[], str],
scorer: Optional[SupportTextClassification] = None,
prompt_converters: Optional[list[PromptConverter]] = None,
memory: Optional[MemoryInterface] = None,
memory_labels: dict[str, str] = None,
verbose: bool = False,
prompt_target_conversation_id: Optional[str] = None,
attack_setup_target_conversation_id: Optional[str] = None,
) -> None:
"""Creates an orchestrator to set up a cross-domain prompt injection attack (XPIA) on a processing target.

The prompt_target creates the attack prompt using the attack_content,
The attack_setup_target creates the attack prompt using the attack_content,
applies converters (if any), and puts it into the attack location.
The processing_target processes the processing_prompt which should include
a reference to the attack prompt to allow it to retrieve the attack prompt.
Then, the processing_callback is executed.
The scorer scores the processing response to determine the success of the attack.

Args:
attack_content: The content to attack the processing target with, e.g., a jailbreak.
processing_prompt: The prompt to send to the processing target. This should include
placeholders to invoke plugins (if any).
processing_target: The target of the attack which processes the processing prompt.
prompt_target: The target that generates the attack prompt and gets it into the attack location.
scorer: The scorer to use to score the processing response.
attack_setup_target: The target that generates the attack prompt and gets it into the attack location.
processing_callback: The callback to execute after the attack prompt is positioned in the attack location.
This is generic on purpose to allow for flexibility.
The callback should return the processing response.
scorer: The scorer to use to score the processing response. This is optional.
If no scorer is provided the orchestrator will skip scoring.
prompt_converters: The converters to apply to the attack content before sending it to the prompt target.
memory: The memory to use to store the chat messages. If not provided, a DuckDBMemory will be used.
memory_labels: The labels to use for the memory. This is useful to identify the bot messages in the memory.
verbose: Whether to print debug information.
attack_setup_target_conversation_id: The conversation ID to use for the prompt target.
If not provided, a new one will be generated.
"""
super().__init__(
prompt_converters=prompt_converters, memory=memory, memory_labels=memory_labels, verbose=verbose
)

self._prompt_target = prompt_target
self._processing_target = processing_target
self._attack_setup_target = attack_setup_target
self._processing_callback = processing_callback
self._scorer = scorer

self._prompt_normalizer = PromptNormalizer(memory=self._memory)
self._prompt_target._memory = self._memory
self._prompt_target_conversation_id = prompt_target_conversation_id or str(uuid4())
self._attack_setup_target._memory = self._memory
self._attack_setup_target_conversation_id = attack_setup_target_conversation_id or str(uuid4())
self._processing_conversation_id = str(uuid4())
self._attack_content = str(attack_content)
self._processing_prompt = processing_prompt

def process(self) -> Score:
def execute(self) -> Union[Score, None]:
"""Executes the entire XPIA operation.

This method sends the attack content to the prompt target, processes the response
using the processing callback, and scores the processing response using the scorer.
If no scorer was provided, the method will skip scoring.
"""
logger.info(
"Sending the following prompt to the prompt target (after applying prompt "
f'converter operations) "{self._attack_content}"',
@@ -80,13 +86,79 @@ def process(self) -> Score:

response = self._prompt_normalizer.send_prompt(
normalizer_request=target_request,
target=self._prompt_target,
target=self._attack_setup_target,
labels=self._global_memory_labels,
orchestrator_identifier=self.get_identifier(),
)

logger.info(f'Received the following response from the prompt target "{response}"')

processing_response = self._processing_callback()

logger.info(f'Received the following response from the processing target "{processing_response}"')

if not self._scorer:
logger.info("No scorer provided, skipping scoring")
return None
score = self._scorer.score_text(processing_response)
logger.info(f"Score of the processing response: {score}")
return score


class XPIATestOrchestrator(XPIAOrchestrator):
def __init__(
self,
*,
attack_content: str,
processing_prompt: str,
processing_target: PromptTarget,
attack_setup_target: PromptTarget,
scorer: SupportTextClassification,
prompt_converters: Optional[list[PromptConverter]] = None,
memory: Optional[MemoryInterface] = None,
memory_labels: dict[str, str] = None,
verbose: bool = False,
attack_setup_target_conversation_id: Optional[str] = None,
) -> None:
"""Creates an orchestrator to set up a cross-domain prompt injection attack (XPIA) on a processing target.

The attack_setup_target creates the attack prompt using the attack_content,
applies converters (if any), and puts it into the attack location.
The processing_target processes the processing_prompt which should include
a reference to the attack prompt to allow it to retrieve the attack prompt.
The scorer scores the processing response to determine the success of the attack.

Args:
attack_content: The content to attack the processing target with, e.g., a jailbreak.
processing_prompt: The prompt to send to the processing target. This should include
placeholders to invoke plugins (if any).
processing_target: The target of the attack which processes the processing prompt.
attack_setup_target: The target that generates the attack prompt and gets it into the attack location.
scorer: The scorer to use to score the processing response.
prompt_converters: The converters to apply to the attack content before sending it to the prompt target.
memory: The memory to use to store the chat messages. If not provided, a DuckDBMemory will be used.
memory_labels: The labels to use for the memory. This is useful to identify the bot messages in the memory.
verbose: Whether to print debug information.
attack_setup_target_conversation_id: The conversation ID to use for the prompt target.
If not provided, a new one will be generated.
"""
super().__init__(
attack_content=attack_content,
attack_setup_target=attack_setup_target,
scorer=scorer,
processing_callback=self._process,
prompt_converters=prompt_converters,
memory=memory,
memory_labels=memory_labels,
verbose=verbose,
attack_setup_target_conversation_id=attack_setup_target_conversation_id,
)

self._processing_target = processing_target
self._processing_conversation_id = str(uuid4())
self._processing_prompt = processing_prompt

def _process(self) -> str:
processing_prompt_req = self._create_normalizer_request(
converters=[], prompt_text=self._processing_prompt, prompt_type="text"
)
@@ -98,8 +170,52 @@ def process(self) -> Score:
orchestrator_identifier=self.get_identifier(),
)

logger.info(f'Received the following response from the processing target "{processing_response}"')
return processing_response.request_pieces[0].converted_prompt_text

score = self._scorer.score_text(processing_response.request_pieces[0].converted_prompt_text)
logger.info(f"Score of the processing response: {score}")
return score

class XPIAManualProcessingOrchestrator(XPIAOrchestrator):
    """XPIA orchestrator whose processing step is performed by a human operator.

    The attack content is placed via the attack setup target exactly as in
    XPIAOrchestrator; the processing callback then blocks on console input so
    the operator can trigger the processing target manually and paste back
    its output, which is handed to the scorer.
    """

    def __init__(
        self,
        *,
        attack_content: str,
        attack_setup_target: PromptTarget,
        scorer: SupportTextClassification,
        prompt_converters: Optional[list[PromptConverter]] = None,
        memory: Optional[MemoryInterface] = None,
        memory_labels: Optional[dict[str, str]] = None,
        verbose: bool = False,
        attack_setup_target_conversation_id: Optional[str] = None,
    ) -> None:
        """Creates an orchestrator to set up a cross-domain prompt injection attack (XPIA) on a processing target.

        The attack_setup_target creates the attack prompt using the attack_content,
        applies converters (if any), and puts it into the attack location.
        Then, the orchestrator stops to wait for the operator to trigger the processing target's execution.
        The operator should paste the output of the processing target into the console.
        Finally, the scorer scores the processing response to determine the success of the attack.

        Args:
            attack_content: The content to attack the processing target with, e.g., a jailbreak.
            attack_setup_target: The target that generates the attack prompt and gets it into the attack location.
            scorer: The scorer to use to score the processing response.
            prompt_converters: The converters to apply to the attack content before sending it to the
                attack setup target.
            memory: The memory to use to store the chat messages. If not provided, a DuckDBMemory will be used.
            memory_labels: The labels to use for the memory. This is useful to identify the bot messages in the memory.
            verbose: Whether to print debug information.
            attack_setup_target_conversation_id: The conversation ID to use for the attack setup target.
                If not provided, a new one will be generated.
        """
        super().__init__(
            attack_content=attack_content,
            attack_setup_target=attack_setup_target,
            scorer=scorer,
            # The bound method is only stored here; it is invoked later from execute().
            processing_callback=self._input,
            prompt_converters=prompt_converters,
            memory=memory,
            memory_labels=memory_labels,
            verbose=verbose,
            attack_setup_target_conversation_id=attack_setup_target_conversation_id,
        )

    def _input(self) -> str:
        """Block on console input and return the text pasted by the operator."""
        return input("Please trigger the processing target's execution and paste the output here: ")
3 changes: 1 addition & 2 deletions pyrit/prompt_converter/prompt_converter.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,7 @@
import abc
from dataclasses import dataclass

from pyrit.models import PromptDataType
from pyrit.models.identifiers import Identifier
from pyrit.models import PromptDataType, Identifier


@dataclass
70 changes: 61 additions & 9 deletions tests/orchestrator/test_xpia_orchestrator.py
Original file line number Diff line number Diff line change
@@ -3,12 +3,16 @@

from typing import Generator
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.orchestrator.xpia_orchestrator import XPIATestOrchestrator
from pyrit.orchestrator import (
XPIATestOrchestrator,
XPIAOrchestrator,
XPIAManualProcessingOrchestrator,
)
import pytest

from unittest.mock import Mock, patch

from pyrit.prompt_target.prompt_target import PromptTarget
from pyrit.prompt_target import PromptTarget
from pyrit.score import Score, SupportTextClassification
from tests.mocks import get_memory_interface, MockPromptTarget

@@ -19,7 +23,7 @@ def memory_interface() -> Generator[MemoryInterface, None, None]:


@pytest.fixture
def prompt_target(memory_interface) -> PromptTarget:
def attack_setup_target(memory_interface) -> PromptTarget:
return MockPromptTarget(memory=memory_interface)


@@ -35,33 +39,81 @@ def success_scorer() -> SupportTextClassification:
return scorer


def test_xpia_orchestrator_process(prompt_target, processing_target, success_scorer):
def test_xpia_orchestrator_execute_no_scorer(attack_setup_target):
    """execute() returns None when the orchestrator is built without a scorer."""

    def processing_callback() -> str:
        # processing_callback is declared as Callable[[], str], so the stub
        # returns plain text rather than a mocked request/response object.
        return "test"

    xpia_orchestrator = XPIAOrchestrator(
        attack_content="test",
        attack_setup_target=attack_setup_target,
        processing_callback=processing_callback,
    )
    assert xpia_orchestrator.execute() is None


def test_xpia_orchestrator_execute(attack_setup_target, success_scorer):
    """execute() passes the callback's text to the scorer and returns its score."""

    def processing_callback() -> str:
        # processing_callback is declared as Callable[[], str], so the stub
        # returns plain text rather than a mocked request/response object.
        return "test"

    xpia_orchestrator = XPIAOrchestrator(
        attack_content="test",
        attack_setup_target=attack_setup_target,
        scorer=success_scorer,
        processing_callback=processing_callback,
    )
    score = xpia_orchestrator.execute()
    assert score.score_value
    # `mock.called_once` is an auto-created, always-truthy Mock attribute and
    # verifies nothing; use the real assertion helper instead.
    # NOTE(review): assumes the success_scorer fixture exposes score_text as a Mock - confirm.
    success_scorer.score_text.assert_called_once_with("test")


def test_xpia_manual_processing_orchestrator_execute(attack_setup_target, success_scorer, monkeypatch):
    """The manual orchestrator scores whatever text the operator types at the prompt."""
    # Mocking user input to be "test"
    monkeypatch.setattr("builtins.input", lambda _: "test")
    xpia_orchestrator = XPIAManualProcessingOrchestrator(
        attack_content="test",
        attack_setup_target=attack_setup_target,
        scorer=success_scorer,
    )
    score = xpia_orchestrator.execute()
    assert score.score_value
    # `mock.called_once` is an auto-created, always-truthy Mock attribute and
    # verifies nothing; use the real assertion helper instead.
    # NOTE(review): assumes the success_scorer fixture exposes score_text as a Mock - confirm.
    success_scorer.score_text.assert_called_once_with("test")


def test_xpia_test_orchestrator_execute(attack_setup_target, processing_target, success_scorer):
with patch.object(processing_target, "send_prompt") as mock_send_to_processing_target:
xpia_orchestrator = XPIATestOrchestrator(
attack_content="test",
processing_prompt="some instructions and the required <test>",
processing_target=processing_target,
prompt_target=prompt_target,
attack_setup_target=attack_setup_target,
scorer=success_scorer,
)
score = xpia_orchestrator.process()
score = xpia_orchestrator.execute()
assert score.score_value
assert success_scorer.score_text.called_once
assert mock_send_to_processing_target.called_once


def test_xpia_orchestrator_process_async(prompt_target, processing_target, success_scorer):
def test_xpia_orchestrator_process_async(attack_setup_target, processing_target, success_scorer):
with patch.object(processing_target, "send_prompt") as mock_send_to_processing_target:
with patch.object(processing_target, "send_prompt_async") as mock_send_async_to_processing_target:
mock_send_to_processing_target.side_effect = NotImplementedError()
xpia_orchestrator = XPIATestOrchestrator(
attack_content="test",
processing_prompt="some instructions and the required <test>",
processing_target=processing_target,
prompt_target=prompt_target,
attack_setup_target=attack_setup_target,
scorer=success_scorer,
)
score = xpia_orchestrator.process()
score = xpia_orchestrator.execute()
assert score.score_value
assert success_scorer.score_text.called_once
assert mock_send_to_processing_target.called_once