diff --git a/doc/code/scoring/2_azure_content_safety_scorers.ipynb b/doc/code/scoring/2_azure_content_safety_scorers.ipynb
index 07e173f1d..6fce16bee 100644
--- a/doc/code/scoring/2_azure_content_safety_scorers.ipynb
+++ b/doc/code/scoring/2_azure_content_safety_scorers.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "9868c425",
+   "id": "7f1cbefc",
    "metadata": {},
    "source": [
     "\n",
@@ -13,7 +13,9 @@
     "In order to use this API, you need to configure a few environment variables:\n",
     "\n",
     "- AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API\n",
-    "- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API\n",
+    "- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth)\n",
+    "\n",
+    "As an alternative to key-based authentication, you may set `use_aad_auth=True` and use identity-based authentication.\n",
     "\n",
     "Note that this api returns a value between 0 and 7. This is different from likert scales, which return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating point values between 0.0 and 1.0 and can be directly compared. This is sometimes interesting as an operator e.g. if there are scenarios where a `SelfAskLikertScorer` and `AzureContentFilterScorer` produce very different values.\n",
     "\n",
@@ -22,29 +24,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "id": "5adb7d1d",
-   "metadata": {
-    "execution": {
-     "iopub.execute_input": "2024-06-06T18:25:17.882704Z",
-     "iopub.status.busy": "2024-06-06T18:25:17.882704Z",
-     "iopub.status.idle": "2024-06-06T18:25:23.864123Z",
-     "shell.execute_reply": "2024-06-06T18:25:23.864123Z"
-    },
-    "lines_to_next_cell": 2
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "AzureContentFilterScorer: Hate: 0.2857142857142857 {'azure_severity': '2'}\n",
-      "AzureContentFilterScorer: SelfHarm: 0.0 {'azure_severity': '0'}\n",
-      "AzureContentFilterScorer: Sexual: 0.0 {'azure_severity': '0'}\n",
-      "AzureContentFilterScorer: Violence: 0.0 {'azure_severity': '0'}\n"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "id": "53426524",
+   "metadata": {},
+   "outputs": [],
    "source": [
     "import os\n",
     "from pyrit.score import AzureContentFilterScorer\n",
@@ -58,7 +41,9 @@
     "\n",
     "# Set up the Azure Content Filter\n",
     "azure_content_filter = AzureContentFilterScorer(\n",
+    "    # Comment out either api_key or use_aad_auth\n",
     "    api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n",
+    "    # use_aad_auth=True,\n",
     "    endpoint=os.environ.get(\"AZURE_CONTENT_SAFETY_API_ENDPOINT\"),\n",
     "    memory=memory,\n",
     ")\n",
@@ -72,13 +57,22 @@
     "# need to write it manually to memory as score table has a foreign key constraint\n",
     "memory.add_request_response_to_memory(request=PromptRequestResponse([response]))\n",
     "\n",
+    "# Run the request\n",
     "scores = await azure_content_filter.score_async(response)  # type: ignore\n",
-    "assert scores[0].get_value() > 0  # azure_severity should be value 2 base on the documentation\n",
+    "assert scores[0].get_value() > 0  # azure_severity should be value 2 based on the documentation\n",
     "\n",
     "for score in scores:\n",
     "    # score_metadata contains azure_severity original value\n",
     "    print(f\"{score} {score.score_metadata}\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9cb0dc4e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -89,18 +83,6 @@
    "display_name": "pyrit-311",
    "language": "python",
    "name": "python3"
-  },
"language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" } }, "nbformat": 4, diff --git a/doc/code/scoring/2_azure_content_safety_scorers.py b/doc/code/scoring/2_azure_content_safety_scorers.py index acd577c16..2f92bffb4 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.py +++ b/doc/code/scoring/2_azure_content_safety_scorers.py @@ -22,7 +22,9 @@ # In order to use this API, you need to configure a few environment variables: # # - AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API -# - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API +# - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth) +# +# As an alternative to key-based authentication, you may set `use_aad_auth=True` and use identity-based authentication. # # Note that this api returns a value between 0 and 7. This is different from likert scales, which return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating point values between 0.0 and 1.0 and can be directly compared. This is sometimes interesting as an operator e.g. if there are scenarios where a `SelfAskLikertScorer` and `AzureContentFilterScorer` produce very different values. # @@ -41,7 +43,9 @@ # Set up the Azure Content Filter azure_content_filter = AzureContentFilterScorer( + # Comment out either api_key or use_aad_auth api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"), + # use_aad_auth=True, endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"), memory=memory, ) @@ -55,9 +59,12 @@ # need to write it manually to memory as score table has a foreign key constraint memory.add_request_response_to_memory(request=PromptRequestResponse([response])) +# Run the request scores = await azure_content_filter.score_async(response) # type: ignore -assert scores[0].get_value() > 0 # azure_severity should be value 2 base on the documentation +assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation for score in scores: # score_metadata contains azure_severity original value print(f"{score} {score.score_metadata}") + +# %% diff --git a/pyrit/score/azure_content_filter_scorer.py b/pyrit/score/azure_content_filter_scorer.py index 99aff034e..24ae99977 100644 --- a/pyrit/score/azure_content_filter_scorer.py +++ b/pyrit/score/azure_content_filter_scorer.py @@ -12,6 +12,7 @@ from azure.ai.contentsafety.models import AnalyzeTextOptions, AnalyzeImageOptions, TextCategory, ImageData from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential # Supported image formats for Azure as per https://learn.microsoft.com/en-us/azure/ai-services/content-safety/ @@ -38,6 +39,7 @@ def __init__( *, endpoint: str = None, api_key: str = None, + use_aad_auth: bool = False, harm_categories: list[TextCategory] = None, memory: MemoryInterface = None, ) -> None: @@ -51,6 +53,8 @@ def __init__( Defaults to the API_KEY_ENVIRONMENT_VARIABLE environment variable. endpoint (str, optional): The endpoint URL for the Azure OpenAI service. Defaults to the ENDPOINT_URI_ENVIRONMENT_VARIABLE environment variable. 
+            use_aad_auth (bool, optional): When set to True, authenticate with
+                DefaultAzureCredential (identity-based auth) instead of an API key. Defaults to False.
             harm_categories: The harm categories you want to query for as per defined in
                 azure.ai.contentsafety.models.TextCategory.
         """
@@ -60,17 +64,26 @@
         else:
             self._score_categories = [category.value for category in TextCategory]
 
-        self._api_key = default_values.get_required_value(
-            env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
-        )
         self._endpoint = default_values.get_required_value(
             env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
         )
 
+        if not use_aad_auth:
+            self._api_key = default_values.get_required_value(
+                env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
+            )
+        else:
+            if api_key:
+                raise ValueError("Please specify either use_aad_auth or api_key, not both")
+            else:
+                self._api_key = None
+
         if self._api_key is not None and self._endpoint is not None:
             self._azure_cf_client = ContentSafetyClient(self._endpoint, AzureKeyCredential(self._api_key))
+        elif use_aad_auth and self._endpoint is not None:
+            self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=DefaultAzureCredential())
         else:
-            raise ValueError("Please provide the Azure Content Safety API key and endpoint")
+            raise ValueError("Please provide the Azure Content Safety endpoint")
 
     async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]:
         """Evaluating the input text or image using the Azure Content Filter API
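Reviewer note (not part of the patch): a minimal end-to-end sketch of the identity-based path this diff enables. It assumes AZURE_CONTENT_SAFETY_API_ENDPOINT is set, that `DefaultAzureCredential` can resolve a credential (for example via `az login` or a managed identity), and that the `PromptRequestPiece(role=..., original_value=...)` construction matches the notebook cell elided from the hunks above; argument names may differ slightly between PyRIT versions.

```python
import os

from pyrit.common import default_values
from pyrit.memory import DuckDBMemory
from pyrit.models import PromptRequestPiece, PromptRequestResponse
from pyrit.score import AzureContentFilterScorer

default_values.load_default_env()
memory = DuckDBMemory()

# Identity-based (AAD) auth: no api_key is passed, so the scorer builds its
# ContentSafetyClient with DefaultAzureCredential against the configured endpoint.
azure_content_filter = AzureContentFilterScorer(
    use_aad_auth=True,
    endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"),
    memory=memory,
)

# Assumed response construction; the exact cell is elided from the diff above.
response = PromptRequestPiece(role="assistant", original_value="I hate you.")

# Scores have a foreign key constraint on the request, so write it to memory first.
memory.add_request_response_to_memory(request=PromptRequestResponse([response]))

scores = await azure_content_filter.score_async(response)  # type: ignore
for score in scores:
    print(f"{score} {score.score_metadata}")  # score_metadata carries the raw azure_severity
```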