Skip to content

Commit

Permalink
fix: Add max_completion_tokens to openai param validation (#6550)
Browse files Browse the repository at this point in the history
* fix: Add max_completion_tokens to openai param validation
* feat(client): Handle max_completion_tokens in openai sdk
  • Loading branch information
cephalization authored Feb 24, 2025
1 parent ba6ffe2 commit c99ee6f
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 3 deletions.
5 changes: 5 additions & 0 deletions js/.changeset/brown-months-shout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@arizeai/phoenix-client": patch
---

Update type definitions to include max_completion_tokens openai parameter
4 changes: 4 additions & 0 deletions js/packages/phoenix-client/src/__generated__/api/v1.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class _ToolKwargs(TypedDict, total=False):
class _InvocationParameters(TypedDict, total=False):
frequency_penalty: float
max_completion_tokens: int
max_tokens: int
presence_penalty: float
reasoning_effort: ChatCompletionReasoningEffort
seed: int
Expand Down Expand Up @@ -197,8 +198,10 @@ def to_openai(
if obj["type"] == "openai":
openai_params: v1.PromptOpenAIInvocationParametersContent
openai_params = obj["openai"]
if "max_completion_tokens" in openai_params:
ans["max_completion_tokens"] = openai_params["max_completion_tokens"]
if "max_tokens" in openai_params:
ans["max_completion_tokens"] = openai_params["max_tokens"]
ans["max_tokens"] = openai_params["max_tokens"]
if "temperature" in openai_params:
ans["temperature"] = openai_params["temperature"]
if "top_p" in openai_params:
Expand All @@ -214,8 +217,10 @@ def to_openai(
elif obj["type"] == "azure_openai":
azure_params: v1.PromptAzureOpenAIInvocationParametersContent
azure_params = obj["azure_openai"]
if "max_completion_tokens" in azure_params:
ans["max_completion_tokens"] = azure_params["max_completion_tokens"]
if "max_tokens" in azure_params:
ans["max_completion_tokens"] = azure_params["max_tokens"]
ans["max_tokens"] = azure_params["max_tokens"]
if "temperature" in azure_params:
ans["temperature"] = azure_params["temperature"]
if "top_p" in azure_params:
Expand Down Expand Up @@ -299,7 +304,9 @@ def from_openai(
else:
assert_never(model_provider)
if "max_completion_tokens" in obj and obj["max_completion_tokens"] is not None:
content["max_tokens"] = obj["max_completion_tokens"]
content["max_completion_tokens"] = obj["max_completion_tokens"]
if "max_tokens" in obj and obj["max_tokens"] is not None:
content["max_tokens"] = obj["max_tokens"]
if "temperature" in obj and obj["temperature"] is not None:
content["temperature"] = obj["temperature"]
if "top_p" in obj and obj["top_p"] is not None:
Expand Down
36 changes: 36 additions & 0 deletions packages/phoenix-client/tests/canary/sdk/openai/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,42 @@ class TestCompletionCreateParamsBase:
"max_completion_tokens": randint(1, 256),
"top_p": random(),
},
{
"model": token_hex(8),
"messages": [
{
"role": "system",
"content": "You are a UI generator. Convert the user input into a UI.",
},
{
"role": "user",
"content": "Make a form for {{ feature }}.",
},
],
"response_format": cast(
"ResponseFormat",
type_to_response_format_param(
create_model("Response", ui=(_UI, ...)),
),
),
"temperature": random(),
"max_tokens": randint(1, 256),
"top_p": random(),
},
{
"model": token_hex(8),
"messages": [
{
"role": "user",
"content": "What is the latest population estimate for {{ location }}?",
}
],
"tools": _TOOLS,
"tool_choice": "required",
"temperature": random(),
"max_tokens": randint(1, 256),
"top_p": random(),
},
],
)
def test_round_trip(self, obj: CompletionCreateParamsBase) -> None:
Expand Down
8 changes: 8 additions & 0 deletions schemas/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2309,6 +2309,10 @@
"type": "integer",
"title": "Max Tokens"
},
"max_completion_tokens": {
"type": "integer",
"title": "Max Completion Tokens"
},
"frequency_penalty": {
"type": "number",
"title": "Frequency Penalty"
Expand Down Expand Up @@ -2539,6 +2543,10 @@
"type": "integer",
"title": "Max Tokens"
},
"max_completion_tokens": {
"type": "integer",
"title": "Max Completion Tokens"
},
"frequency_penalty": {
"type": "number",
"title": "Frequency Penalty"
Expand Down
1 change: 1 addition & 0 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ class AnthropicToolDefinition(PromptModel):
class PromptOpenAIInvocationParametersContent(PromptModel):
temperature: float = UNDEFINED
max_tokens: int = UNDEFINED
max_completion_tokens: int = UNDEFINED
frequency_penalty: float = UNDEFINED
presence_penalty: float = UNDEFINED
top_p: float = UNDEFINED
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/prompts/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ class TestClient:
top_p=random(),
presence_penalty=random(),
frequency_penalty=random(),
max_tokens=randint(1, 256),
seed=randint(24, 42),
messages=[
{"role": "system", "content": "You are {role}."},
Expand All @@ -403,6 +404,7 @@ class TestClient:
top_p=random(),
presence_penalty=random(),
frequency_penalty=random(),
max_completion_tokens=randint(1, 256),
seed=randint(24, 42),
messages=[
{
Expand Down

0 comments on commit c99ee6f

Please sign in to comment.