From cccdaa78d6bc16f8761d9de42d435fb6e920cf39 Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Wed, 9 Oct 2024 13:18:06 -0400 Subject: [PATCH 1/7] Adding an argument for an Azure Credential --- pyrit/auth/azure_auth.py | 14 +++++++++++--- .../azure_openai_gpto_chat_target.py | 15 +++++++++++---- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index a28bfc435..f3ea08580 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,7 +5,7 @@ import msal import logging -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, TokenProvider from azure.identity import AzureCliCredential from azure.identity import ManagedIdentityCredential from azure.identity import InteractiveBrowserCredential @@ -119,15 +119,23 @@ def get_access_token_from_interactive_login(scope: str = AZURE_COGNITIVE_SERVICE logger.error(f"Failed to obtain token for '{scope}': {e}") raise - def get_token_provider_from_default_azure_credential(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE): """Connect to an AOAI endpoint via default Azure credential. + Returns: + Authentication token provider + """ + return get_token_provider_from_azure_credential(DefaultAzureCredential(), scope) + + +def get_token_provider_from_azure_credential(credential: TokenProvider, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE): + """Connect to an AOAI endpoint via default Azure credential. + Returns: Authentication token provider """ try: - token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope) + token_provider = get_bearer_token_provider(credential, scope) return token_provider except Exception as e: logger.error(f"Failed to obtain token for '{scope}': {e}") diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py index a4609ec8d..3c09f1362 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py @@ -3,14 +3,16 @@ import logging import json -from typing import MutableSequence, Optional +from typing import MutableSequence, Optional, Union + +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential from openai import AsyncAzureOpenAI from openai import BadRequestError from openai.types.chat import ChatCompletion -from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential +from pyrit.auth.azure_auth import get_token_provider_from_azure_credential from pyrit.common import default_values from pyrit.exceptions import PyritException, EmptyResponseException from pyrit.exceptions import handle_bad_request_exception, pyrit_target_retry @@ -43,6 +45,7 @@ def __init__( api_key: str = None, headers: str = None, use_aad_auth: bool = False, + aad_credential: Union[DefaultAzureCredential, ManagedIdentityCredential, None] = None, memory: MemoryInterface = None, api_version: str = "2024-02-01", max_tokens: int = 1024, @@ -114,8 +117,12 @@ def __init__( logger.info("No headers have been passed, setting empty default headers") if use_aad_auth: - logger.info("Authenticating with DefaultAzureCredential() for Azure Cognitive Services") - token_provider = get_token_provider_from_default_azure_credential() + logger.info("Using aad_auth") + if aad_credential is None: + logger.info("Authenticating with DefaultAzureCredential() for Azure Cognitive Services") + aad_credential = DefaultAzureCredential() + + token_provider = get_token_provider_from_azure_credential(aad_credential) self._async_client = AsyncAzureOpenAI( azure_ad_token_provider=token_provider, From 49becfe30b7ffe70eb386e9cf82aea6acbb54c6b Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Thu, 10 Oct 2024 15:47:16 -0400 Subject: [PATCH 2/7] Revert "Adding an argument for an Azure Credential" This reverts commit cccdaa78d6bc16f8761d9de42d435fb6e920cf39. --- pyrit/auth/azure_auth.py | 14 +++----------- .../azure_openai_gpto_chat_target.py | 15 ++++----------- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index f3ea08580..a28bfc435 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,7 +5,7 @@ import msal import logging -from azure.core.credentials import AccessToken, TokenProvider +from azure.core.credentials import AccessToken from azure.identity import AzureCliCredential from azure.identity import ManagedIdentityCredential from azure.identity import InteractiveBrowserCredential @@ -119,23 +119,15 @@ def get_access_token_from_interactive_login(scope: str = AZURE_COGNITIVE_SERVICE logger.error(f"Failed to obtain token for '{scope}': {e}") raise -def get_token_provider_from_default_azure_credential(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE): - """Connect to an AOAI endpoint via default Azure credential. - - Returns: - Authentication token provider - """ - return get_token_provider_from_azure_credential(DefaultAzureCredential(), scope) - -def get_token_provider_from_azure_credential(credential: TokenProvider, scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE): +def get_token_provider_from_default_azure_credential(scope: str = AZURE_COGNITIVE_SERVICES_DEFAULT_SCOPE): """Connect to an AOAI endpoint via default Azure credential. Returns: Authentication token provider """ try: - token_provider = get_bearer_token_provider(credential, scope) + token_provider = get_bearer_token_provider(DefaultAzureCredential(), scope) return token_provider except Exception as e: logger.error(f"Failed to obtain token for '{scope}': {e}") diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py index 3c09f1362..a4609ec8d 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gpto_chat_target.py @@ -3,16 +3,14 @@ import logging import json -from typing import MutableSequence, Optional, Union - -from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from typing import MutableSequence, Optional from openai import AsyncAzureOpenAI from openai import BadRequestError from openai.types.chat import ChatCompletion -from pyrit.auth.azure_auth import get_token_provider_from_azure_credential +from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential from pyrit.common import default_values from pyrit.exceptions import PyritException, EmptyResponseException from pyrit.exceptions import handle_bad_request_exception, pyrit_target_retry @@ -45,7 +43,6 @@ def __init__( api_key: str = None, headers: str = None, use_aad_auth: bool = False, - aad_credential: Union[DefaultAzureCredential, ManagedIdentityCredential, None] = None, memory: MemoryInterface = None, api_version: str = "2024-02-01", max_tokens: int = 1024, @@ -117,12 +114,8 @@ def __init__( logger.info("No headers have been passed, setting empty default headers") if use_aad_auth: - logger.info("Using aad_auth") - if aad_credential is None: - logger.info("Authenticating with DefaultAzureCredential() for Azure Cognitive Services") - aad_credential = DefaultAzureCredential() - - token_provider = get_token_provider_from_azure_credential(aad_credential) + logger.info("Authenticating with DefaultAzureCredential() for Azure Cognitive Services") + token_provider = get_token_provider_from_default_azure_credential() self._async_client = AsyncAzureOpenAI( azure_ad_token_provider=token_provider, From 529a023b7fa95beff9404678d19ee4520533ee19 Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Thu, 10 Oct 2024 15:52:43 -0400 Subject: [PATCH 3/7] Start hooking in DefaultAzureCredential --- pyrit/score/azure_content_filter_scorer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyrit/score/azure_content_filter_scorer.py b/pyrit/score/azure_content_filter_scorer.py index 99aff034e..d709c81a0 100644 --- a/pyrit/score/azure_content_filter_scorer.py +++ b/pyrit/score/azure_content_filter_scorer.py @@ -12,6 +12,7 @@ from azure.ai.contentsafety.models import AnalyzeTextOptions, AnalyzeImageOptions, TextCategory, ImageData from azure.ai.contentsafety import ContentSafetyClient from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential # Supported image formats for Azure as per https://learn.microsoft.com/en-us/azure/ai-services/content-safety/ @@ -38,6 +39,7 @@ def __init__( *, endpoint: str = None, api_key: str = None, + use_aad_auth: bool = False, harm_categories: list[TextCategory] = None, memory: MemoryInterface = None, ) -> None: @@ -51,6 +53,8 @@ def __init__( Defaults to the API_KEY_ENVIRONMENT_VARIABLE environment variable. endpoint (str, optional): The endpoint URL for the Azure OpenAI service. Defaults to the ENDPOINT_URI_ENVIRONMENT_VARIABLE environment variable. + use_aad_auth (bool, optional): Attempt to use DefaultAzureCredential + If set to true, and api_key is None, attempt to use DefaultAzureCredential for auth harm_categories: The harm categories you want to query for as per defined in azure.ai.contentsafety.models.TextCategory. """ @@ -69,8 +73,10 @@ def __init__( if self._api_key is not None and self._endpoint is not None: self._azure_cf_client = ContentSafetyClient(self._endpoint, AzureKeyCredential(self._api_key)) + elif use_aad_auth and self._endpoint is not None: + self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=DefaultAzureCredential()) else: - raise ValueError("Please provide the Azure Content Safety API key and endpoint") + raise ValueError("Please provide the Azure Content Safety endpoint, and either set api_key or use_aad_auth") async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: """Evaluating the input text or image using the Azure Content Filter API From f5aa63b7e4a8ec6035ac2ed76fd1171a4b53309d Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Fri, 11 Oct 2024 08:59:10 -0400 Subject: [PATCH 4/7] Requested changes --- pyrit/score/azure_content_filter_scorer.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pyrit/score/azure_content_filter_scorer.py b/pyrit/score/azure_content_filter_scorer.py index d709c81a0..24ae99977 100644 --- a/pyrit/score/azure_content_filter_scorer.py +++ b/pyrit/score/azure_content_filter_scorer.py @@ -54,7 +54,7 @@ def __init__( endpoint (str, optional): The endpoint URL for the Azure OpenAI service. Defaults to the ENDPOINT_URI_ENVIRONMENT_VARIABLE environment variable. use_aad_auth (bool, optional): Attempt to use DefaultAzureCredential - If set to true, and api_key is None, attempt to use DefaultAzureCredential for auth + If set to true, attempt to use DefaultAzureCredential for auth harm_categories: The harm categories you want to query for as per defined in azure.ai.contentsafety.models.TextCategory. """ @@ -64,19 +64,26 @@ def __init__( else: self._score_categories = [category.value for category in TextCategory] - self._api_key = default_values.get_required_value( - env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key - ) self._endpoint = default_values.get_required_value( env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint ) + if not use_aad_auth: + self._api_key = default_values.get_required_value( + env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key + ) + else: + if api_key: + raise ValueError("Please specify either use_add_auth or api_key") + else: + self._api_key = None + if self._api_key is not None and self._endpoint is not None: self._azure_cf_client = ContentSafetyClient(self._endpoint, AzureKeyCredential(self._api_key)) elif use_aad_auth and self._endpoint is not None: self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=DefaultAzureCredential()) else: - raise ValueError("Please provide the Azure Content Safety endpoint, and either set api_key or use_aad_auth") + raise ValueError("Please provide the Azure Content Safety endpoint") async def score_async(self, request_response: PromptRequestPiece, *, task: Optional[str] = None) -> list[Score]: """Evaluating the input text or image using the Azure Content Filter API From 10ab174929af2b99019453af9436087657786437 Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Fri, 11 Oct 2024 09:25:05 -0400 Subject: [PATCH 5/7] Tweaking text --- .../scoring/2_azure_content_safety_scorers.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/doc/code/scoring/2_azure_content_safety_scorers.py b/doc/code/scoring/2_azure_content_safety_scorers.py index acd577c16..a02b37763 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.py +++ b/doc/code/scoring/2_azure_content_safety_scorers.py @@ -22,13 +22,16 @@ # In order to use this API, you need to configure a few environment variables: # # - AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API -# - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API +# - AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth) +# +# As an alternative to key-based authentication, you may set `use_aad_auth=True` and use identity-based authentication. # # Note that this api returns a value between 0 and 7. This is different from likert scales, which return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating point values between 0.0 and 1.0 and can be directly compared. This is sometimes interesting as an operator e.g. if there are scenarios where a `SelfAskLikertScorer` and `AzureContentFilterScorer` produce very different values. # # Before you begin, ensure you are setup with the correct version of PyRIT installed and have secrets configured as described [here](../../setup/). # %% +import asyncio import os from pyrit.score import AzureContentFilterScorer from pyrit.common import default_values @@ -41,7 +44,9 @@ # Set up the Azure Content Filter azure_content_filter = AzureContentFilterScorer( - api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"), + # Comment out either api_key or use_aad_auth + # api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"), + use_aad_auth=True, endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"), memory=memory, ) @@ -55,9 +60,12 @@ # need to write it manually to memory as score table has a foreign key constraint memory.add_request_response_to_memory(request=PromptRequestResponse([response])) -scores = await azure_content_filter.score_async(response) # type: ignore -assert scores[0].get_value() > 0 # azure_severity should be value 2 base on the documentation +# Run the request +scores = asyncio.run(azure_content_filter.score_async(response)) # type: ignore +assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation for score in scores: # score_metadata contains azure_severity original value print(f"{score} {score.score_metadata}") + +# %% From 6f8c5caf2fd3d3a5b5a5e365f04135f28601ea46 Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Fri, 11 Oct 2024 11:01:29 -0400 Subject: [PATCH 6/7] Asyncio appears unsatisfiable, so make the notebook work --- .../2_azure_content_safety_scorers.ipynb | 36 ++++++++++++------- .../scoring/2_azure_content_safety_scorers.py | 3 +- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/doc/code/scoring/2_azure_content_safety_scorers.ipynb b/doc/code/scoring/2_azure_content_safety_scorers.ipynb index 07e173f1d..c9d381693 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.ipynb +++ b/doc/code/scoring/2_azure_content_safety_scorers.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "9868c425", + "id": "2faca040", "metadata": {}, "source": [ "\n", @@ -13,7 +13,9 @@ "In order to use this API, you need to configure a few environment variables:\n", "\n", "- AZURE_CONTENT_SAFETY_API_ENDPOINT: The endpoint for the Azure Content Safety API\n", - "- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API\n", + "- AZURE_CONTENT_SAFETY_API_KEY: The API key for the Azure Content Safety API (if not using AAD Auth)\n", + "\n", + "As an alternative to key-based authentication, you may set `use_aad_auth=True` and use identity-based authentication.\n", "\n", "Note that this api returns a value between 0 and 7. This is different from likert scales, which return a value between 1 and 5. Because both are `float_scale` scores, these values are all normalized to floating point values between 0.0 and 1.0 and can be directly compared. This is sometimes interesting as an operator e.g. if there are scenarios where a `SelfAskLikertScorer` and `AzureContentFilterScorer` produce very different values.\n", "\n", @@ -23,15 +25,14 @@ { "cell_type": "code", "execution_count": 1, - "id": "5adb7d1d", + "id": "e916de4f", "metadata": { "execution": { - "iopub.execute_input": "2024-06-06T18:25:17.882704Z", - "iopub.status.busy": "2024-06-06T18:25:17.882704Z", - "iopub.status.idle": "2024-06-06T18:25:23.864123Z", - "shell.execute_reply": "2024-06-06T18:25:23.864123Z" - }, - "lines_to_next_cell": 2 + "iopub.execute_input": "2024-10-11T14:58:19.595111Z", + "iopub.status.busy": "2024-10-11T14:58:19.594125Z", + "iopub.status.idle": "2024-10-11T14:58:33.275086Z", + "shell.execute_reply": "2024-10-11T14:58:33.273539Z" + } }, "outputs": [ { @@ -58,7 +59,9 @@ "\n", "# Set up the Azure Content Filter\n", "azure_content_filter = AzureContentFilterScorer(\n", - " api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n", + " # Comment out either api_key or use_aad_auth\n", + " # api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n", + " use_aad_auth=True,\n", " endpoint=os.environ.get(\"AZURE_CONTENT_SAFETY_API_ENDPOINT\"),\n", " memory=memory,\n", ")\n", @@ -72,13 +75,22 @@ "# need to write it manually to memory as score table has a foreign key constraint\n", "memory.add_request_response_to_memory(request=PromptRequestResponse([response]))\n", "\n", + "# Run the request\n", "scores = await azure_content_filter.score_async(response) # type: ignore\n", - "assert scores[0].get_value() > 0 # azure_severity should be value 2 base on the documentation\n", + "assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation\n", "\n", "for score in scores:\n", " # score_metadata contains azure_severity original value\n", " print(f\"{score} {score.score_metadata}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab88b29a", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -100,7 +112,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/doc/code/scoring/2_azure_content_safety_scorers.py b/doc/code/scoring/2_azure_content_safety_scorers.py index a02b37763..5b3c37def 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.py +++ b/doc/code/scoring/2_azure_content_safety_scorers.py @@ -31,7 +31,6 @@ # Before you begin, ensure you are setup with the correct version of PyRIT installed and have secrets configured as described [here](../../setup/). # %% -import asyncio import os from pyrit.score import AzureContentFilterScorer from pyrit.common import default_values @@ -61,7 +60,7 @@ memory.add_request_response_to_memory(request=PromptRequestResponse([response])) # Run the request -scores = asyncio.run(azure_content_filter.score_async(response)) # type: ignore +scores = await azure_content_filter.score_async(response) # type: ignore assert scores[0].get_value() > 0 # azure_severity should be value 2 based on the documentation for score in scores: From e9aee309e062db1a403d1a5bc265c35a25c79968 Mon Sep 17 00:00:00 2001 From: "Richard Edgar (Microsoft)" Date: Tue, 15 Oct 2024 09:32:58 -0400 Subject: [PATCH 7/7] Swap default auth method in notebook --- .../2_azure_content_safety_scorers.ipynb | 46 ++++--------------- .../scoring/2_azure_content_safety_scorers.py | 4 +- 2 files changed, 10 insertions(+), 40 deletions(-) diff --git a/doc/code/scoring/2_azure_content_safety_scorers.ipynb b/doc/code/scoring/2_azure_content_safety_scorers.ipynb index c9d381693..6fce16bee 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.ipynb +++ b/doc/code/scoring/2_azure_content_safety_scorers.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2faca040", + "id": "7f1cbefc", "metadata": {}, "source": [ "\n", @@ -24,28 +24,10 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "e916de4f", - "metadata": { - "execution": { - "iopub.execute_input": "2024-10-11T14:58:19.595111Z", - "iopub.status.busy": "2024-10-11T14:58:19.594125Z", - "iopub.status.idle": "2024-10-11T14:58:33.275086Z", - "shell.execute_reply": "2024-10-11T14:58:33.273539Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "AzureContentFilterScorer: Hate: 0.2857142857142857 {'azure_severity': '2'}\n", - "AzureContentFilterScorer: SelfHarm: 0.0 {'azure_severity': '0'}\n", - "AzureContentFilterScorer: Sexual: 0.0 {'azure_severity': '0'}\n", - "AzureContentFilterScorer: Violence: 0.0 {'azure_severity': '0'}\n" - ] - } - ], + "execution_count": null, + "id": "53426524", + "metadata": {}, + "outputs": [], "source": [ "import os\n", "from pyrit.score import AzureContentFilterScorer\n", @@ -60,8 +42,8 @@ "# Set up the Azure Content Filter\n", "azure_content_filter = AzureContentFilterScorer(\n", " # Comment out either api_key or use_aad_auth\n", - " # api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n", - " use_aad_auth=True,\n", + " api_key=os.environ.get(\"AZURE_CONTENT_SAFETY_API_KEY\"),\n", + " # use_aad_auth=True,\n", " endpoint=os.environ.get(\"AZURE_CONTENT_SAFETY_API_ENDPOINT\"),\n", " memory=memory,\n", ")\n", @@ -87,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab88b29a", + "id": "9cb0dc4e", "metadata": {}, "outputs": [], "source": [] @@ -101,18 +83,6 @@ "display_name": "pyrit-311", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.10" } }, "nbformat": 4, diff --git a/doc/code/scoring/2_azure_content_safety_scorers.py b/doc/code/scoring/2_azure_content_safety_scorers.py index 5b3c37def..2f92bffb4 100644 --- a/doc/code/scoring/2_azure_content_safety_scorers.py +++ b/doc/code/scoring/2_azure_content_safety_scorers.py @@ -44,8 +44,8 @@ # Set up the Azure Content Filter azure_content_filter = AzureContentFilterScorer( # Comment out either api_key or use_aad_auth - # api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"), - use_aad_auth=True, + api_key=os.environ.get("AZURE_CONTENT_SAFETY_API_KEY"), + # use_aad_auth=True, endpoint=os.environ.get("AZURE_CONTENT_SAFETY_API_ENDPOINT"), memory=memory, )