From b616e064f62ad889ff06b62bc2b55bf6467dbf3f Mon Sep 17 00:00:00 2001
From: sethjuarez
Date: Thu, 14 Nov 2024 14:30:44 -0800
Subject: [PATCH] added headers to non-streaming chat calls for AOAI execute
 invoker

---
 runtime/prompty/prompty/azure/executor.py | 38 ++++++++++++++++++++---
 runtime/prompty/prompty/core.py           |  7 ++---
 runtime/prompty/prompty/tracer.py         |  4 +--
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/runtime/prompty/prompty/azure/executor.py b/runtime/prompty/prompty/azure/executor.py
index a6b571a..f77a4b2 100644
--- a/runtime/prompty/prompty/azure/executor.py
+++ b/runtime/prompty/prompty/azure/executor.py
@@ -1,10 +1,12 @@
+import json
 import azure.identity
 import importlib.metadata
 from typing import AsyncIterator, Iterator
-from openai import AzureOpenAI, AsyncAzureOpenAI
+from openai import APIResponse, AzureOpenAI, AsyncAzureOpenAI
 
-from prompty.tracer import Tracer
+from prompty.tracer import Tracer, sanitize
 from ..core import AsyncPromptyStream, Prompty, PromptyStream
+from openai.types.chat.chat_completion import ChatCompletion
 from ..invoker import Invoker, InvokerFactory
 
 VERSION = importlib.metadata.version("prompty")
@@ -86,7 +88,21 @@ def invoke(self, data: any) -> any:
                 **self.parameters,
             }
             trace("inputs", args)
-            response = client.chat.completions.create(**args)
+
+            if "stream" in args and args["stream"] == True:
+                response = client.chat.completions.create(**args)
+            else:
+                raw: APIResponse = client.chat.completions.with_raw_response.create(
+                    **args
+                )
+                response = ChatCompletion.model_validate_json(raw.text)
+
+                for k, v in raw.headers.raw:
+                    trace(k.decode("utf-8"), v.decode("utf-8"))
+
+                trace("request_id", raw.request_id)
+                trace("retries_taken", raw.retries_taken)
+
             trace("result", response)
 
         elif self.api == "completion":
@@ -171,7 +187,20 @@ async def invoke_async(self, data: str) -> str:
                 **self.parameters,
             }
             trace("inputs", args)
-            response = await client.chat.completions.create(**args)
+
+            if "stream" in args and args["stream"] == True:
+                response = await client.chat.completions.create(**args)
+            else:
+                raw: APIResponse = await client.chat.completions.with_raw_response.create(
+                    **args
+                )
+                response = ChatCompletion.model_validate_json(raw.text)
+                for k, v in raw.headers.raw:
+                    trace(k.decode("utf-8"), v.decode("utf-8"))
+
+                trace("request_id", raw.request_id)
+                trace("retries_taken", raw.retries_taken)
+
             trace("result", response)
 
         elif self.api == "completion":
@@ -182,6 +211,7 @@ async def invoke_async(self, data: str) -> str:
                 **self.parameters,
             }
             trace("inputs", args)
+
             response = await client.completions.create(**args)
             trace("result", response)
 
diff --git a/runtime/prompty/prompty/core.py b/runtime/prompty/prompty/core.py
index f10a43f..25399cf 100644
--- a/runtime/prompty/prompty/core.py
+++ b/runtime/prompty/prompty/core.py
@@ -3,7 +3,7 @@
 import os
 from pathlib import Path
 
-from .tracer import Tracer, to_dict
+from .tracer import Tracer, to_dict, sanitize
 from pydantic import BaseModel, Field, FilePath
 from typing import AsyncIterator, Iterator, List, Literal, Dict, Callable, Set, Tuple
 
@@ -88,10 +88,7 @@ def model_dump(
             serialize_as_any=serialize_as_any,
         )
 
-        d["configuration"] = {
-            k: "*" * len(v) if "key" in k.lower() or "secret" in k.lower() else v
-            for k, v in d["configuration"].items()
-        }
+        d["configuration"] = {k: sanitize(k, v) for k, v in d["configuration"].items()}
 
         return d
 
diff --git a/runtime/prompty/prompty/tracer.py b/runtime/prompty/prompty/tracer.py
index 8c73df2..943fd5a 100644
--- a/runtime/prompty/prompty/tracer.py
+++ b/runtime/prompty/prompty/tracer.py
@@ -16,9 +16,9 @@
 # clean up key value pairs for sensitive values
 def sanitize(key: str, value: Any) -> Any:
     if isinstance(value, str) and any(
-        [s in key.lower() for s in ["key", "token", "secret", "password", "credential"]]
+        [s in key.lower() for s in ["key", "secret", "password", "credential"]]
     ):
-        return len(str(value)) * "*"
+        return 10 * "*"
     elif isinstance(value, dict):
         return {k: sanitize(k, v) for k, v in value.items()}
     else: