Skip to content

Commit

Permalink
Add UAI to llama deployment (#2473)
Browse files Browse the repository at this point in the history
* add uai

* fix typo

* fix typo

* reformat
  • Loading branch information
xuke444 authored Jul 24, 2023
1 parent 6871e32 commit 1ed88da
Show file tree
Hide file tree
Showing 4 changed files with 536 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,37 @@ def get_parameter_type(sample_input_ex, sample_output_ex=None):
model = load_model(model_path)


def get_aacs_access_key():
key = os.environ.get("CONTENT_SAFETY_KEY")

if key:
return key

uai_client_id = os.environ.get("UAI_CLIENT_ID")
if not uai_client_id:
raise RuntimeError(
"Cannot get AACS access key, both UAI_CLIENT_ID and CONTENT_SAFETY_KEY are not set, exiting..."
)

subscription_id = os.environ.get("SUBSCRIPTION_ID")
resource_group_name = os.environ.get("RESOURCE_GROUP_NAME")
aacs_account_name = os.environ.get("CONTENT_SAFETY_ACCOUNT_NAME")
from azure.mgmt.cognitiveservices import CognitiveServicesManagementClient
from azure.identity import ManagedIdentityCredential

credential = ManagedIdentityCredential(client_id=uai_client_id)
cs_client = CognitiveServicesManagementClient(credential, subscription_id)
key = cs_client.accounts.list_keys(
resource_group_name=resource_group_name, account_name=aacs_account_name
).key1

return key


def init():
global inputs_collector, outputs_collector, aacs_client
endpoint = os.environ.get("CONTENT_SAFETY_ENDPOINT")
key = os.environ.get("CONTENT_SAFETY_KEY")
key = get_aacs_access_key()

# Create an Content Safety client
headers_policy = HeadersPolicy()
Expand Down
Loading

0 comments on commit 1ed88da

Please sign in to comment.