[FIX] Allow AAD Auth for AzureContentFilterScorer #455

Merged (9 commits) on Oct 15, 2024
doc/code/scoring/2_azure_content_safety_scorers.ipynb (36 changes: 24 additions & 12 deletions)
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "9868c425",
"id": "2faca040",
"metadata": {},
"source": [
"\n",
@@ -13,7 +13,9 @@
"In order to use this API, you need to configure a few environment variables:\n",
"\n",
"- AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API\n",
"- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API\n",
"- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth)\n",
"\n",
"As an alternative to key-based authentication, you may set `use_aad_auth=True` and use identity-based authentication.\n",
"\n",
"Note that this api returns a value between 0 and 7. This is different from likert scales, which return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating point values between 0.0 and 1.0 and can be directly compared. This is sometimes interesting as an operator e.g. if there are scenarios where a `SelfAskLikertScorer` and `AzureContentFilterScorer` produce very different values.\n",
"\n",
@@ -23,15 +25,14 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "5adb7d1d",
"id": "e916de4f",
"metadata": {
"execution": {
"iopub.execute_input": "2024-06-06T18:25:17.882704Z",
"iopub.status.busy": "2024-06-06T18:25:17.882704Z",
"iopub.status.idle": "2024-06-06T18:25:23.864123Z",
"shell.execute_reply": "2024-06-06T18:25:23.864123Z"
},
"lines_to_next_cell": 2
"iopub.execute_input": "2024-10-11T14:58:19.595111Z",
"iopub.status.busy": "2024-10-11T14:58:19.594125Z",
"iopub.status.idle": "2024-10-11T14:58:33.275086Z",
"shell.execute_reply": "2024-10-11T14:58:33.273539Z"
}
},
"outputs": [
{
@@ -58,7 +59,9 @@
"\n",
"# Set up the Azure Content Filter\n",
"azure_content_filter = AzureContentFilterScorer(\n",
" api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n",
" # Comment out either api_key or use_aad_auth\n",
" # api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n",
" use_aad_auth=True,\n",
" endpoint=os.environ.get(\"AZURE_CONTENT_SAFETY_API_ENDPOINT\"),\n",
" memory=memory,\n",
")\n",
@@ -72,13 +75,22 @@
"# need to write it manually to memory as score table has a foreign key constraint\n",
"memory.add_request_response_to_memory(request=PromptRequestResponse([response]))\n",
"\n",
"# Run the request\n",
"scores = await azure_content_filter.score_async(response) # type: ignore\n",
"assert scores[0].get_value() > 0 # azure_severity should be value 2 base on the documentation\n",
"assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation\n",
"\n",
"for score in scores:\n",
" # score_metadata contains azure_severity original value\n",
" print(f\"{score} {score.score_metadata}\")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ab88b29a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
],
"metadata": {
@@ -100,7 +112,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.10"
}
},
"nbformat": 4,
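The normalization the markdown cell describes (raw 0-7 severity mapped onto the 0.0-1.0 `float_scale` range) is worth seeing concretely. A minimal sketch, assuming the mapping is a simple linear division; the helper name is illustrative, not part of PyRIT:

```python
def normalize_severity(azure_severity: int, max_severity: int = 7) -> float:
    """Map Azure Content Safety's 0-7 severity onto the 0.0-1.0 float_scale range."""
    return azure_severity / max_severity

# The severity of 2 that the example's assert expects lands at roughly 0.29,
# directly comparable with a normalized 1-5 Likert score.
print(normalize_severity(2))  # 0.2857142857142857
```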
doc/code/scoring/2_azure_content_safety_scorers.py (13 changes: 10 additions & 3 deletions)
@@ -22,7 +22,9 @@
# In order to use this API, you need to configure a few environment variables:
#
# - AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API
- # - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API
+ # - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth)
+ #
+ # As an alternative to key-based authentication, you may set `use_aad_auth=True` to use identity-based authentication.
#
# Note that this API returns a value between 0 and 7, whereas Likert scales return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating-point values between 0.0 and 1.0 and can be directly compared. This is sometimes useful as an operator, e.g., in scenarios where a `SelfAskLikertScorer` and an `AzureContentFilterScorer` produce very different values.
#
@@ -41,7 +43,9 @@

# Set up the Azure Content Filter
azure_content_filter = AzureContentFilterScorer(
- api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"),
+ # Use either api_key or use_aad_auth; comment out whichever you are not using
+ # api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"),
+ use_aad_auth=True,
endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"),
memory=memory,
)
@@ -55,9 +59,12 @@
# need to write it manually to memory as score table has a foreign key constraint
memory.add_request_response_to_memory(request=PromptRequestResponse([response]))

+ # Run the request
scores = await azure_content_filter.score_async(response) # type: ignore
- assert scores[0].get_value() > 0 # azure_severity should be value 2 base on the documentation
+ assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation

for score in scores:
# score_metadata contains azure_severity original value
print(f"{score} {score.score_metadata}")
+
+ # %%
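For completeness, the key-based configuration that the example comments out would look like the sketch below. It assumes AZURE_CONTENT_SAFETY_API_KEY is set, reuses the `memory` instance created earlier in the example, and uses the `pyrit.score` import path that matches this module's location:

```python
import os

from pyrit.score import AzureContentFilterScorer

# Key-based auth: pass the API key and omit use_aad_auth
azure_content_filter = AzureContentFilterScorer(
    api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"),
    endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"),
    memory=memory,  # the MemoryInterface instance created earlier
)
```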
pyrit/score/azure_content_filter_scorer.py (21 changes: 17 additions & 4 deletions)
@@ -12,6 +12,7 @@
from azure.ai.contentsafety.models import AnalyzeTextOptions, AnalyzeImageOptions, TextCategory, ImageData
from azure.ai.contentsafety import ContentSafetyClient
from azure.core.credentials import AzureKeyCredential
+ from azure.identity import DefaultAzureCredential


# Supported image formats for Azure as per https://learn.microsoft.com/en-us/azure/ai-services/content-safety/
@@ -38,6 +39,7 @@ def __init__(
*,
endpoint: str = None,
api_key: str = None,
+ use_aad_auth: bool = False,
harm_categories: list[TextCategory] = None,
memory: MemoryInterface = None,
) -> None:
@@ -51,6 +53,8 @@ def __init__(
Defaults to the API_KEY_ENVIRONMENT_VARIABLE environment variable.
endpoint (str, optional): The endpoint URL for the Azure OpenAI service.
Defaults to the ENDPOINT_URI_ENVIRONMENT_VARIABLE environment variable.
+ use_aad_auth (bool, optional): If set to True, attempt to use DefaultAzureCredential
+ for authentication instead of an API key. Defaults to False.
harm_categories: The harm categories you want to query for as per defined in
azure.ai.contentsafety.models.TextCategory.
"""
@@ -60,17 +64,26 @@ def __init__(
else:
self._score_categories = [category.value for category in TextCategory]

- self._api_key = default_values.get_required_value(
- env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
- )
self._endpoint = default_values.get_required_value(
env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint
)

+ if not use_aad_auth:
+ self._api_key = default_values.get_required_value(
+ env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
+ )
+ else:
+ if api_key:
+ raise ValueError("Please specify either use_aad_auth or api_key")
+ else:
+ self._api_key = None
+
if self._api_key is not None and self._endpoint is not None:
self._azure_cf_client = ContentSafetyClient(self._endpoint, AzureKeyCredential(self._api_key))
+ elif use_aad_auth and self._endpoint is not None:
+ self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=DefaultAzureCredential())
else:
raise ValueError("Please provide the Azure Content Safety API key and endpoint")
raise ValueError("Please provide the Azure Content Safety endpoint")

async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]:
"""Evaluating the input text or image using the Azure Content Filter API
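Taken together, the new constructor accepts exactly one of the two authentication mechanisms. A rough usage sketch of both paths; the endpoint value is a placeholder and the import path assumes the `pyrit.score` export:

```python
from pyrit.score import AzureContentFilterScorer

endpoint = "https://<your-resource>.cognitiveservices.azure.com/"  # placeholder

# AAD path: no key needed; DefaultAzureCredential resolves the ambient identity
# (environment variables, managed identity, Azure CLI login, and so on)
scorer = AzureContentFilterScorer(use_aad_auth=True, endpoint=endpoint)

# Supplying both an API key and use_aad_auth=True is rejected
try:
    AzureContentFilterScorer(api_key="not-a-real-key", use_aad_auth=True, endpoint=endpoint)
except ValueError as error:
    print(error)  # Please specify either use_aad_auth or api_key
```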