Skip to content

Commit 71e95fb

Browse files
authored
fix: camel model will not work when using open source models (#45)
1 parent 47a910b commit 71e95fb

File tree

3 files changed

+1254
-1180
lines changed

3 files changed

+1254
-1180
lines changed

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,6 @@ _build/
168168
# model parameter
169169
*.pth
170170

171-
logs/
171+
logs/
172+
173+
.DS_Store

crab/agents/backend_models/camel_model.py

+25-23
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
try:
2424
from camel.agents import ChatAgent
25-
from camel.configs import ChatGPTConfig
2625
from camel.messages import BaseMessage
2726
from camel.models import ModelFactory
2827
from camel.toolkits import OpenAIFunction
@@ -33,29 +32,34 @@
3332
CAMEL_ENABLED = False
3433

3534

36-
def _find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
37-
for platform in ModelPlatformType:
38-
if platform.value.lower() == model_platform_name.lower():
39-
return platform
40-
all_models = [platform.value for platform in ModelPlatformType]
41-
raise ValueError(
42-
f"Model {model_platform_name} not found. Supported models are {all_models}"
43-
)
35+
def _get_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
36+
try:
37+
return ModelPlatformType(model_platform_name)
38+
except ValueError:
39+
all_models = [platform.value for platform in ModelPlatformType]
40+
raise ValueError(
41+
f"Model {model_platform_name} not found. Supported models are {all_models}"
42+
)
4443

4544

46-
def _find_model_type(model_name: str) -> "str | ModelType":
47-
for model in ModelType:
48-
if model.value.lower() == model_name.lower():
49-
return model
50-
return model_name
45+
def _get_model_type(model_name: str) -> "str | ModelType":
46+
try:
47+
return ModelType(model_name)
48+
except ValueError:
49+
return model_name
5150

5251

5352
def _convert_action_to_schema(
5453
action_space: list[Action] | None,
5554
) -> "list[OpenAIFunction] | None":
5655
if action_space is None:
5756
return None
58-
return [OpenAIFunction(action.entry) for action in action_space]
57+
schema_list = []
58+
for action in action_space:
59+
new_action = action.to_openai_json_schema()
60+
schema = {"type": "function", "function": new_action}
61+
schema_list.append(OpenAIFunction(action.entry, schema))
62+
return schema_list
5963

6064

6165
def _convert_tool_calls_to_action_list(
@@ -84,9 +88,8 @@ def __init__(
8488
if not CAMEL_ENABLED:
8589
raise ImportError("Please install camel-ai to use CamelModel")
8690
self.parameters = parameters or {}
87-
# TODO: a better way?
88-
self.model_type = _find_model_type(model)
89-
self.model_platform_type = _find_model_platform_type(model_platform)
91+
self.model_type = _get_model_type(model)
92+
self.model_platform_type = _get_model_platform_type(model_platform)
9093
self.client: ChatAgent | None = None
9194
self.token_usage = 0
9295

@@ -104,15 +107,14 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
104107
config = self.parameters.copy()
105108
if action_schema is not None:
106109
config["tool_choice"] = "required"
107-
config["tools"] = action_schema
110+
config["tools"] = [
111+
schema.get_openai_tool_schema() for schema in action_schema
112+
]
108113

109-
chatgpt_config = ChatGPTConfig(
110-
**config,
111-
)
112114
backend_model = ModelFactory.create(
113115
self.model_platform_type,
114116
self.model_type,
115-
model_config_dict=chatgpt_config.as_dict(),
117+
model_config_dict=config,
116118
)
117119
sysmsg = BaseMessage.make_assistant_message(
118120
role_name="Assistant",

0 commit comments

Comments
 (0)