Skip to content

Commit

Permalink
fix: Azure OpenAI o1 max_completion_token and get_num_token_from_messages error (langgenius#9326)
Browse files Browse the repository at this point in the history

Co-authored-by: wwwc <wwwc@outlook.com>
  • Loading branch information
2 people authored and JunXu01 committed Nov 9, 2024
1 parent 9cbf7a9 commit afd9972
Show file tree
Hide file tree
Showing 9 changed files with 644 additions and 387 deletions.
16 changes: 16 additions & 0 deletions api/core/model_runtime/model_providers/azure_openai/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,14 @@ class AzureBaseModel(BaseModel):
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
Expand Down Expand Up @@ -1135,6 +1143,14 @@ class AzureBaseModel(BaseModel):
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
Expand Down
10 changes: 9 additions & 1 deletion api/core/model_runtime/model_providers/azure_openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))

if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=1,
max_completion_tokens=20,
stream=False,
)
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
Expand Down
697 changes: 392 additions & 305 deletions api/poetry.lock

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,19 @@ ignore = [
]
"tests/*" = [
"F811", # redefined-while-unused
"F401", # unused-import
"PT001", # pytest-fixture-incorrect-parentheses-style
"PT004", # pytest-missing-fixture-name-underscore
]
"core/rag/extractor/word_extractor.py" = [
"RUF100", # Unused `noqa` directive
]
"core/tools/provider/builtin/gitlab/tools/gitlab_commits.py" = [
"PLR1714", # Consider merging multiple comparisons
]

[tool.ruff.lint.pyflakes]
allowed-unused-imports=[
extend-generics=[
"_pytest.monkeypatch",
"tests.integration_tests",
]
Expand Down Expand Up @@ -149,7 +158,7 @@ nomic = "~3.1.2"
novita-client = "~0.5.7"
numpy = "~1.26.4"
oci = "~2.135.1"
openai = "~1.29.0"
openai = "~1.51.2"
openpyxl = "~3.1.5"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
psycopg2-binary = "~2.9.6"
Expand All @@ -172,7 +181,7 @@ scikit-learn = "~1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tiktoken = "~0.7.0"
tiktoken = "~0.8.0"
tokenizers = "~0.15.0"
transformers = "~4.35.0"
unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def moderation_create(
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
"illicit": False,
"illicit/violent": False,
}
moderation_categories_scores = {
"harassment": 1.0,
Expand All @@ -54,13 +56,30 @@ def moderation_create(
"sexual/minors": 1.0,
"violence": 1.0,
"violence/graphic": 1.0,
"illicit": 1.0,
"illicit/violent": 1.0,
}
category_applied_input_types = {
"sexual": ["text", "image"],
"hate": ["text"],
"harassment": ["text"],
"self-harm": ["text", "image"],
"sexual/minors": ["text"],
"hate/threatening": ["text"],
"violence/graphic": ["text", "image"],
"self-harm/intent": ["text", "image"],
"self-harm/instructions": ["text", "image"],
"harassment/threatening": ["text"],
"violence": ["text", "image"],
"illicit": ["text"],
"illicit/violent": ["text"],
}

result.append(
Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
category_applied_input_types=category_applied_input_types,
)
)
else:
Expand All @@ -76,6 +95,8 @@ def moderation_create(
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
"illicit": False,
"illicit/violent": False,
}
moderation_categories_scores = {
"harassment": 0.0,
Expand All @@ -89,12 +110,30 @@ def moderation_create(
"sexual/minors": 0.0,
"violence": 0.0,
"violence/graphic": 0.0,
"illicit": 0.0,
"illicit/violent": 0.0,
}
category_applied_input_types = {
"sexual": ["text", "image"],
"hate": ["text"],
"harassment": ["text"],
"self-harm": ["text", "image"],
"sexual/minors": ["text"],
"hate/threatening": ["text"],
"violence/graphic": ["text", "image"],
"self-harm/intent": ["text", "image"],
"self-harm/instructions": ["text", "image"],
"harassment/threatening": ["text"],
"violence": ["text", "image"],
"illicit": ["text"],
"illicit/violent": ["text"],
}
result.append(
Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
category_applied_input_types=category_applied_input_types,
)
)

Expand Down
2 changes: 1 addition & 1 deletion sdks/python-client/dify_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from dify_client.client import ChatClient, CompletionClient, DifyClient
from dify_client.client import ChatClient, CompletionClient, DifyClient
Loading

0 comments on commit afd9972

Please sign in to comment.