fix: parameter extractor mock
Yeuoly committed Jan 10, 2025
1 parent dc0657f commit 38ab4fe
Showing 5 changed files with 378 additions and 46 deletions.
@@ -107,11 +107,46 @@ def invoke(
        content_list = []
        usage = LLMUsage.empty_usage()
        system_fingerprint = None
        tools_calls: list[AssistantPromptMessage.ToolCall] = []

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            def get_tool_call(tool_name: str):
                if not tool_name:
                    return tools_calls[-1]

                tool_call = next(
                    (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
                )
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id="",
                        type="",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # find (or create) the accumulated tool call for this function name
                tool_call = get_tool_call(new_tool_call.function.name)
                # merge the streamed delta into the accumulated tool call
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        for chunk in result:
            if isinstance(chunk.delta.message.content, str):
                content += chunk.delta.message.content
            elif isinstance(chunk.delta.message.content, list):
                content_list.extend(chunk.delta.message.content)
            if chunk.delta.message.tool_calls:
                increase_tool_call(chunk.delta.message.tool_calls)

            usage = chunk.delta.usage or LLMUsage.empty_usage()
            system_fingerprint = chunk.system_fingerprint
@@ -120,7 +155,10 @@ def invoke(
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=content or content_list),
+            message=AssistantPromptMessage(
+                content=content or content_list,
+                tool_calls=tools_calls,
+            ),
            usage=usage,
            system_fingerprint=system_fingerprint,
        )
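The hunk above accumulates streamed tool-call deltas: chunks that share a function name are merged into a single ToolCall, with argument fragments concatenated in arrival order. A minimal stand-alone sketch of that merge rule, using a plain dataclass instead of AssistantPromptMessage.ToolCall (all names below are illustrative and not part of the commit):

from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    name: str
    arguments: str


def merge_deltas(deltas: list[ToolCallDelta]) -> dict[str, str]:
    # group fragments by function name and concatenate their argument strings
    merged: dict[str, str] = {}
    for delta in deltas:
        merged[delta.name] = merged.get(delta.name, "") + delta.arguments
    return merged


# two fragments of the same call arriving in separate chunks
fragments = [ToolCallDelta("extract", '{"query": "ka'), ToolCallDelta("extract", 'waii"}')]
assert merge_deltas(fragments) == {"extract": '{"query": "kawaii"}'}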
44 changes: 44 additions & 0 deletions api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py
@@ -0,0 +1,44 @@
import os
from collections.abc import Callable

import pytest

# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch

from core.plugin.manager.model import PluginModelManager
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass


def mock_plugin_daemon(
    monkeypatch: MonkeyPatch,
) -> Callable[[], None]:
    """
    Patch PluginModelManager's model methods with mock implementations.
    :param monkeypatch: pytest monkeypatch fixture
    :return: unpatch function
    """

    def unpatch() -> None:
        monkeypatch.undo()

    monkeypatch.setattr(PluginModelManager, "invoke_llm", MockModelClass.invoke_llm)
    monkeypatch.setattr(PluginModelManager, "fetch_model_providers", MockModelClass.fetch_model_providers)
    monkeypatch.setattr(PluginModelManager, "get_model_schema", MockModelClass.get_model_schema)

    return unpatch


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_model_mock(monkeypatch):
    if MOCK:
        unpatch = mock_plugin_daemon(monkeypatch)

    yield

    if MOCK:
        unpatch()
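A hedged usage sketch for the fixture above (the test module and function names are hypothetical; only setup_model_mock and the MOCK_SWITCH environment variable come from the code in this commit). With MOCK_SWITCH=true, simply requesting the fixture routes PluginModelManager calls to MockModelClass for the duration of the test:

# hypothetical test module
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock  # noqa: F401


def test_with_mocked_plugin_daemon(setup_model_mock):
    # While the fixture is active, invoke_llm, fetch_model_providers and
    # get_model_schema on PluginModelManager resolve to the MockModelClass implementations.
    ...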
249 changes: 249 additions & 0 deletions api/tests/integration_tests/model_runtime/__mock/plugin_model.py
@@ -0,0 +1,249 @@
import datetime
import uuid
from collections.abc import Generator
from decimal import Decimal
from json import dumps
from typing import Optional, Sequence

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager


class MockModelClass(PluginModelManager):
    def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
        """
        Fetch model providers for the given tenant.
        """
        return [
            PluginModelProviderEntity(
                id=uuid.uuid4().hex,
                created_at=datetime.datetime.now(),
                updated_at=datetime.datetime.now(),
                provider="openai",
                tenant_id=tenant_id,
                plugin_unique_identifier="langgenius/openai/openai",
                plugin_id="langgenius/openai",
                declaration=ProviderEntity(
                    provider="openai",
                    label=I18nObject(
                        en_US="OpenAI",
                        zh_Hans="OpenAI",
                    ),
                    description=I18nObject(
                        en_US="OpenAI",
                        zh_Hans="OpenAI",
                    ),
                    icon_small=I18nObject(
                        en_US="https://example.com/icon_small.png",
                        zh_Hans="https://example.com/icon_small.png",
                    ),
                    icon_large=I18nObject(
                        en_US="https://example.com/icon_large.png",
                        zh_Hans="https://example.com/icon_large.png",
                    ),
                    supported_model_types=[ModelType.LLM],
                    configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
                    models=[
                        AIModelEntity(
                            model="gpt-3.5-turbo",
                            label=I18nObject(
                                en_US="gpt-3.5-turbo",
                                zh_Hans="gpt-3.5-turbo",
                            ),
                            model_type=ModelType.LLM,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties={},
                            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL],
                        ),
                        AIModelEntity(
                            model="gpt-3.5-turbo-instruct",
                            label=I18nObject(
                                en_US="gpt-3.5-turbo-instruct",
                                zh_Hans="gpt-3.5-turbo-instruct",
                            ),
                            model_type=ModelType.LLM,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties={
                                ModelPropertyKey.MODE: LLMMode.COMPLETION,
                            },
                            features=[],
                        ),
                    ],
                ),
            )
        ]

    def get_model_schema(
        self,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model_type: str,
        model: str,
        credentials: dict,
    ) -> AIModelEntity | None:
        """
        Get model schema
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US="OpenAI",
                zh_Hans="OpenAI",
            ),
            model_type=ModelType(model_type),
            fetch_from=FetchFrom.PREDEFINED_MODEL,
            model_properties={},
            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL] if model == "gpt-3.5-turbo" else [],
        )

    @staticmethod
    def generate_function_call(
        tools: Optional[list[PromptMessageTool]],
    ) -> Optional[AssistantPromptMessage.ToolCall]:
        if not tools or len(tools) == 0:
            return None
        function: PromptMessageTool = tools[0]
        function_name = function.name
        function_parameters = function.parameters
        function_parameters_type = function_parameters["type"]
        if function_parameters_type != "object":
            return None
        function_parameters_properties = function_parameters["properties"]
        function_parameters_required = function_parameters["required"]
        parameters = {}
        for parameter_name, parameter in function_parameters_properties.items():
            if parameter_name not in function_parameters_required:
                continue
            parameter_type = parameter["type"]
            if parameter_type == "string":
                if "enum" in parameter:
                    if len(parameter["enum"]) == 0:
                        continue
                    parameters[parameter_name] = parameter["enum"][0]
                else:
                    parameters[parameter_name] = "kawaii"
            elif parameter_type == "integer":
                parameters[parameter_name] = 114514
            elif parameter_type == "number":
                parameters[parameter_name] = 1919810.0
            elif parameter_type == "boolean":
                parameters[parameter_name] = True

        return AssistantPromptMessage.ToolCall(
            id=str(uuid.uuid4()),
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=function_name,
                arguments=dumps(parameters),
            ),
        )

    @staticmethod
    def mocked_chat_create_sync(
        model: str,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> LLMResult:
        tool_call = MockModelClass.generate_function_call(tools=tools)

        return LLMResult(
            id=str(uuid.uuid4()),
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content="elaina", tool_calls=[tool_call] if tool_call else []),
            usage=LLMUsage(
                prompt_tokens=2,
                completion_tokens=1,
                total_tokens=3,
                prompt_unit_price=Decimal(0.0001),
                completion_unit_price=Decimal(0.0002),
                prompt_price_unit=Decimal(1),
                prompt_price=Decimal(0.0001),
                completion_price_unit=Decimal(1),
                completion_price=Decimal(0.0002),
                total_price=Decimal(0.0003),
                currency="USD",
                latency=0.001,
            ),
        )

    @staticmethod
    def mocked_chat_create_stream(
        model: str,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Generator[LLMResultChunk, None, None]:
        tool_call = MockModelClass.generate_function_call(tools=tools)

        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
        for i in range(0, len(full_text) + 1):
            if i == len(full_text):
                # final chunk: no more text, carries the accumulated tool call
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content="",
                            tool_calls=[tool_call] if tool_call else [],
                        ),
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=full_text[i],
                            tool_calls=[],  # the tool call is emitted only once, on the final chunk
                        ),
                        usage=LLMUsage(
                            prompt_tokens=2,
                            completion_tokens=17,
                            total_tokens=19,
                            prompt_unit_price=Decimal(0.0001),
                            completion_unit_price=Decimal(0.0002),
                            prompt_price_unit=Decimal(1),
                            prompt_price=Decimal(0.0001),
                            completion_price_unit=Decimal(1),
                            completion_price=Decimal(0.0002),
                            total_price=Decimal(0.0003),
                            currency="USD",
                            latency=0.001,
                        ),
                    ),
                )

    def invoke_llm(
        self: PluginModelManager,
        *,
        tenant_id: str,
        user_id: str,
        plugin_id: str,
        provider: str,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
    ):
        return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
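A sketch of how the streaming mock behaves when drained the same way the invoke() hunk above drains it. This is illustrative only; the PromptMessageTool fields (name, description, parameters) are assumed from how the class is used elsewhere, not defined in this commit:

from core.model_runtime.entities.message_entities import PromptMessageTool

tool = PromptMessageTool(
    name="extract",
    description="extract a query parameter",
    parameters={
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
)

text = ""
for chunk in MockModelClass.mocked_chat_create_stream(model="gpt-3.5-turbo", prompt_messages=[], tools=[tool]):
    if isinstance(chunk.delta.message.content, str):
        text += chunk.delta.message.content
# text now holds the full "Hello, world!..." string, and the final chunk carried a single
# tool call named "extract" whose arguments are dumps({"query": "kawaii"}).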
8 changes: 4 additions & 4 deletions api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -12,7 +12,7 @@
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers import ModelProviderFactory
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@@ -26,8 +26,8 @@
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
-from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
-from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock  # noqa
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock  # noqa


def init_llm_node(config: dict) -> LLMNode:
@@ -103,7 +103,7 @@ def test_execute_llm(setup_openai_mock):

    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}

-    provider_instance = ModelProviderFactory().get_provider_instance("openai")
+    provider_instance = ModelProviderFactory("aa").get_provider_instance("openai")
    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
    provider_model_bundle = ProviderModelBundle(
        configuration=ProviderConfiguration(
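The test_llm.py hunks track the relocated ModelProviderFactory import and a constructor that now appears to take a tenant identifier as its first argument; "aa" is just a throwaway tenant id for the test. A sketch of the assumed call pattern:

# assumed post-refactor usage; "aa" is a placeholder tenant_id
factory = ModelProviderFactory("aa")
provider_instance = factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)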