Skip to content

Commit

Permalink
Added response headers to tracing for non-streaming chat calls in the AOAI execute invoker
Browse files Browse the repository at this point in the history
  • Loading branch information
sethjuarez committed Nov 14, 2024
1 parent 1990149 commit b616e06
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
38 changes: 34 additions & 4 deletions runtime/prompty/prompty/azure/executor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import azure.identity
import importlib.metadata
from typing import AsyncIterator, Iterator
from openai import AzureOpenAI, AsyncAzureOpenAI
from openai import APIResponse, AzureOpenAI, AsyncAzureOpenAI

from prompty.tracer import Tracer
from prompty.tracer import Tracer, sanitize
from ..core import AsyncPromptyStream, Prompty, PromptyStream
from openai.types.chat.chat_completion import ChatCompletion
from ..invoker import Invoker, InvokerFactory

VERSION = importlib.metadata.version("prompty")
Expand Down Expand Up @@ -86,7 +88,21 @@ def invoke(self, data: any) -> any:
**self.parameters,
}
trace("inputs", args)
response = client.chat.completions.create(**args)

if "stream" in args and args["stream"] == True:
response = client.chat.completions.create(**args)
else:
raw: APIResponse = client.chat.completions.with_raw_response.create(
**args
)
response = ChatCompletion.model_validate_json(raw.text)

for k, v in raw.headers.raw:
trace(k.decode("utf-8"), v.decode("utf-8"))

trace("request_id", raw.request_id)
trace("retries_taken", raw.retries_taken)

trace("result", response)

elif self.api == "completion":
Expand Down Expand Up @@ -171,7 +187,20 @@ async def invoke_async(self, data: str) -> str:
**self.parameters,
}
trace("inputs", args)
response = await client.chat.completions.create(**args)

if "stream" in args and args["stream"] == True:
response = await client.chat.completions.create(**args)
else:
raw: APIResponse = await client.chat.completions.with_raw_response.create(
**args
)
response = ChatCompletion.model_validate_json(raw.text)
for k, v in raw.headers.raw:
trace(k.decode("utf-8"), v.decode("utf-8"))

trace("request_id", raw.request_id)
trace("retries_taken", raw.retries_taken)

trace("result", response)

elif self.api == "completion":
Expand All @@ -182,6 +211,7 @@ async def invoke_async(self, data: str) -> str:
**self.parameters,
}
trace("inputs", args)

response = await client.completions.create(**args)
trace("result", response)

Expand Down
7 changes: 2 additions & 5 deletions runtime/prompty/prompty/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from pathlib import Path

from .tracer import Tracer, to_dict
from .tracer import Tracer, to_dict, sanitize
from pydantic import BaseModel, Field, FilePath
from typing import AsyncIterator, Iterator, List, Literal, Dict, Callable, Set, Tuple

Expand Down Expand Up @@ -88,10 +88,7 @@ def model_dump(
serialize_as_any=serialize_as_any,
)

d["configuration"] = {
k: "*" * len(v) if "key" in k.lower() or "secret" in k.lower() else v
for k, v in d["configuration"].items()
}
d["configuration"] = {k: sanitize(k, v) for k, v in d["configuration"].items()}
return d


Expand Down
4 changes: 2 additions & 2 deletions runtime/prompty/prompty/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# clean up key value pairs for sensitive values
def sanitize(key: str, value: Any) -> Any:
if isinstance(value, str) and any(
[s in key.lower() for s in ["key", "token", "secret", "password", "credential"]]
[s in key.lower() for s in ["key", "secret", "password", "credential"]]
):
return len(str(value)) * "*"
return 10 * "*"
elif isinstance(value, dict):
return {k: sanitize(k, v) for k, v in value.items()}
else:
Expand Down

0 comments on commit b616e06

Please sign in to comment.