Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add data collector for dataset generation #1193

Merged
merged 25 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
15f0990
init
liuxukun2000 Nov 18, 2024
24f6714
Merge branch 'master' into feat/data_collector
liuxukun2000 Nov 18, 2024
5068396
add base data collectors
liuxukun2000 Nov 19, 2024
09bed51
reformat code
liuxukun2000 Nov 19, 2024
252d70e
pass precheck
liuxukun2000 Nov 19, 2024
09b9d89
update 1 based on review comment
Wendong-Fan Nov 19, 2024
9d8e84b
Merge branch 'master' into feat/data_collector
liuxukun2000 Nov 27, 2024
44aaac5
Refine the code according to the comments
liuxukun2000 Nov 27, 2024
2bb9e90
Merge branch 'master' into feat/data_collector
liuxukun2000 Nov 29, 2024
ed208b8
add llm_converter
liuxukun2000 Nov 29, 2024
078537e
update license
liuxukun2000 Nov 29, 2024
75687a5
pass precommit
liuxukun2000 Nov 29, 2024
099a160
pass precommit
liuxukun2000 Nov 29, 2024
f8acfd1
Merge branch 'master' into feat/data_collector
liuxukun2000 Nov 29, 2024
255a1e4
Merge branch 'master' into feat/data_collector
liuxukun2000 Dec 6, 2024
9ce5a8a
get messages from memory
liuxukun2000 Dec 6, 2024
bae8295
pass precommit
liuxukun2000 Dec 6, 2024
31c25fa
small format fix
Wendong-Fan Dec 7, 2024
4d09e94
Merge branch 'master' into feat/data_collector
Wendong-Fan Dec 7, 2024
6ba6708
Merge branch 'master' into feat/data_collector
liuxukun2000 Dec 9, 2024
e6842fc
add pytest
liuxukun2000 Dec 9, 2024
5803941
reformat code
liuxukun2000 Dec 9, 2024
1e0d8f8
Merge branch 'master' into feat/data_collector
liuxukun2000 Dec 11, 2024
443fbfd
use class from messages/conversation
liuxukun2000 Dec 13, 2024
08e15a8
Merge branch 'master' into feat/data_collector
liuxukun2000 Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions camel/data_collector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .alpaca_collector import AlpacaDataCollector
from .base import BaseDataCollector
from .sharegpt_collector import ShareGPTDataCollector

# Names re-exported as the public API of the ``data_collector`` package.
__all__ = [
    "BaseDataCollector",
    "AlpacaDataCollector",
    "ShareGPTDataCollector",
]
71 changes: 71 additions & 0 deletions camel/data_collector/alpaca_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Dict, List, Optional, Self, Union

from camel.agents.chat_agent import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages.base import BaseMessage


class AlpacaDataCollector(BaseDataCollector):
    r"""Collect one user/assistant exchange of a single agent and convert
    it into a dictionary with ``instructions``, ``input`` and ``output``
    fields.
    """

    def __init__(self):
        super().__init__()
        # System message of the injected agent, captured at injection time.
        self.system_message: Optional[BaseMessage] = None
        # History key under which the injected agent's messages are stored.
        self.agent_name: Optional[str] = None

    def inject(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
        name: Optional[Union[str, List[Optional[str]]]] = None,
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent to inject. A list must contain exactly one agent.
            name (Optional[Union[str, List[Optional[str]]]], optional):
                The name of the agent. When a list is given only its first
                element is used. Defaults to None, in which case the
                agent's role name is used.

        Returns:
            Self: The collector itself, to allow call chaining.

        Raises:
            ValueError: If an agent has already been injected, or if a list
                that does not contain exactly one agent is passed.
        """
        # Reject any injection once an agent is registered; the previous
        # ``> 1`` check let a second agent slip through.
        if self.agents:
            raise ValueError("AlpacaDataCollector only supports one agent")
        if isinstance(agent, list):
            if len(agent) != 1:
                raise ValueError("AlpacaDataCollector only supports one agent")
            agent = agent[0]
        if isinstance(name, list):
            # An empty list simply means "no explicit name".
            name = name[0] if name else None

        # Register first so that a failure leaves this collector untouched,
        # then read the name back so ``convert`` looks up exactly the key
        # under which the history is recorded.
        self._inject(agent, name)
        self.agent_name = self.agents[-1][0]
        self.system_message = agent._system_message
        return self

    def convert(self) -> Dict[str, str]:
        r"""Convert the collected data into a dictionary.

        The converted record is also appended to ``self.data``.

        Returns:
            Dict[str, str]: A record with the keys ``instructions`` (system
                message content, or an empty string), ``input`` (content of
                the first recorded message) and ``output`` (content of the
                second recorded message).

        Raises:
            ValueError: If no agent was injected, nothing was recorded, or
                the history does not hold exactly two messages.
        """
        if self.agent_name is None:
            raise ValueError("No agent injected")
        history = self.history.get(self.agent_name)
        if not history:
            raise ValueError("No data collected")
        if len(history) != 2:
            raise ValueError("AlpacaDataCollector only supports one message")

        # Each history entry is (sequence id, role, message).
        (_, _, request), (_, _, reply) = history
        system = self.system_message
        # NOTE(review): the Alpaca dataset convention names this field
        # "instruction" (singular); the existing key is kept here so
        # current consumers keep working - confirm which one is intended.
        data = dict(
            instructions=system.content if system else "",
            input=request.content,
            output=reply.content,
        )
        self.data.append(data)
        return data
141 changes: 141 additions & 0 deletions camel/data_collector/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Self, Tuple, Union

from camel.agents import ChatAgent
from camel.messages.base import BaseMessage
from camel.responses.agent_responses import ChatAgentResponse
from camel.types.enums import OpenAIBackendRole


class BaseDataCollector(ABC):
    r"""Base class for data collectors.

    A collector hooks into the ``update_memory`` method of every injected
    agent and, while recording is switched on, copies each message passed
    to it into ``history`` under the agent's registered name.
    """

    def __init__(self):
        # Registered name -> list of (sequence id, role, message) entries.
        self.history: Dict[
            str, List[Tuple[int, OpenAIBackendRole, BaseMessage]]
        ] = defaultdict(list)
        self._recording = False
        # (registered name, agent) pairs, in injection order.
        self.agents: List[Tuple[str, ChatAgent]] = []
        # Sequence counter shared by all history entries of this collector.
        self._id = 0
        # Converted records; subclasses are expected to append to it.
        self.data: List[Dict[str, Any]] = []

    def step(
        self,
        message: Union[BaseMessage, ChatAgentResponse],
    ) -> Self:
        r"""Record a message.

        Unlike the ``update_memory`` hook, this records regardless of the
        recording flag.

        Args:
            message (Union[BaseMessage, ChatAgentResponse]):
                The message to record. For a response, every message in
                ``message.msgs`` is recorded.

        Returns:
            Self: The collector itself, to allow call chaining.
        """
        if isinstance(message, ChatAgentResponse):
            messages = message.msgs
        else:
            messages = [message]
        for msg in messages:
            # Read name and role from the message instance; the class
            # object itself carries no per-message values.
            # NOTE(review): this stores ``role_type.value`` while the hook
            # installed by ``_inject`` stores the backend role it receives
            # - confirm the two are comparable for downstream converters.
            self.history[msg.role_name].append(
                (self._id, msg.role_type.value, msg)
            )
            self._id += 1
        return self

    def _inject(
        self,
        agent: ChatAgent,
        name: Union[str, None] = None,
    ) -> Self:
        r"""Inject an agent.

        Args:
            agent (ChatAgent): The agent to inject.
            name (Union[str, None], optional): The name to register the
                agent under. Falls back to the agent's role name and then
                to a generated ``<ClassName>_<index>`` name.
                Defaults to None.

        Returns:
            Self: The collector itself, to allow call chaining.

        Raises:
            ValueError: If the resolved name is already registered.
        """
        name = name or agent.role_name
        if not name:
            name = f"{agent.__class__.__name__}_{len(self.agents)}"
        if any(existing == name for existing, _ in self.agents):
            raise ValueError(f"Name {name} already exists")

        self.agents.append((name, agent))

        ori_update_memory = agent.update_memory

        def update_memory(
            message: BaseMessage, role: OpenAIBackendRole
        ) -> None:
            # Mirror the message into the history, then delegate to the
            # agent's own implementation so its behaviour is unchanged.
            if self._recording:
                self.history[name].append((self._id, role, message))
                self._id += 1
            return ori_update_memory(message, role)

        agent.update_memory = update_memory  # type: ignore[method-assign]

        return self

    def inject(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
        name: Union[str, List[Union[str, None]], None] = None,
    ) -> Self:
        r"""Inject agents.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent(s) to inject.
            name (Union[str, List[Union[str, None]], None], optional):
                The name(s) to register the agent(s) under, one per agent.
                Defaults to None, in which case each agent's role name is
                used.

        Returns:
            Self: The collector itself, to allow call chaining.

        Raises:
            ValueError: If the number of names does not match the number
                of agents, or if a name is already registered.
        """
        agents = agent if isinstance(agent, list) else [agent]
        if name is None:
            names = [None] * len(agents)
        elif isinstance(name, list):
            names = name
        else:
            names = [name]
        if len(names) != len(agents):
            raise ValueError(
                "The number of names must match the number of agents"
            )
        for a, n in zip(agents, names):
            self._inject(a, n)
        return self

    def start(self) -> Self:
        r"""Start recording."""
        self._recording = True
        return self

    def stop(self) -> Self:
        r"""Stop recording."""
        self._recording = False
        return self

    @property
    def recording(self) -> bool:
        r"""Whether the collector is recording."""
        return self._recording

    def reset(self, reset_agents: bool = True):
        r"""Reset the collector.

        Clears the history and the sequence counter; ``self.data`` and the
        recording flag are left untouched.

        Args:
            reset_agents (bool, optional):
                Whether to reset the agents. Defaults to True.
        """
        self.history = defaultdict(list)
        self._id = 0
        if reset_agents:
            for _, agent in self.agents:
                agent.reset()

    @abstractmethod
    def convert(self) -> Any:
        r"""Convert the collected data."""
        pass

    def save(self, path: str):
        r"""Save the collected data.

        Args:
            path (str): The path to save the data.

        Raises:
            NotImplementedError: Always; saving is not implemented here.
        """
        raise NotImplementedError
103 changes: 103 additions & 0 deletions camel/data_collector/sharegpt_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import json
from typing import Any, Dict, List, Optional, Self, Union

from camel.agents.chat_agent import ChatAgent
from camel.data_collector.base import BaseDataCollector
from camel.messages.base import BaseMessage
from camel.messages.func_message import FunctionCallingMessage
from camel.toolkits.function_tool import FunctionTool
from camel.types.enums import OpenAIBackendRole


class ShareGPTDataCollector(BaseDataCollector):
    r"""Collect the conversation of a single agent and convert it into a
    dictionary with ``system``, ``tools`` and ``conversations`` fields.
    """

    def __init__(self):
        super().__init__()
        # System message of the injected agent, captured at injection time.
        self.system_message: Optional[BaseMessage] = None
        # History key under which the injected agent's messages are stored.
        self.agent_name: Optional[str] = None
        # Tools of the injected agent, captured at injection time.
        self.tools: List[FunctionTool] = []

    def inject(
        self,
        agent: Union[List[ChatAgent], ChatAgent],
        name: Optional[Union[str, List[Optional[str]]]] = None,
    ) -> Self:
        r"""Inject an agent into the data collector.

        Args:
            agent (Union[List[ChatAgent], ChatAgent]):
                The agent to inject. A list must contain exactly one agent.
            name (Optional[Union[str, List[Optional[str]]]], optional):
                The name of the agent. When a list is given only its first
                element is used. Defaults to None, in which case the
                agent's role name is used.

        Returns:
            Self: The collector itself, to allow call chaining.

        Raises:
            ValueError: If an agent has already been injected, or if a list
                that does not contain exactly one agent is passed.
        """
        # Reject any injection once an agent is registered; the previous
        # ``> 1`` check let a second agent slip through.
        if self.agents:
            raise ValueError("ShareGPTDataCollector only supports one agent")
        if isinstance(agent, list):
            if len(agent) != 1:
                raise ValueError(
                    "ShareGPTDataCollector only supports one agent"
                )
            agent = agent[0]
        if isinstance(name, list):
            # An empty list simply means "no explicit name".
            name = name[0] if name else None

        # Register first so that a failure leaves this collector untouched,
        # then read the name back so ``convert`` looks up exactly the key
        # under which the history is recorded.
        self._inject(agent, name)
        self.agent_name = self.agents[-1][0]
        self.system_message = agent._system_message
        self.tools = list(agent.tool_dict.values())
        return self

    @staticmethod
    def _to_entry(
        role: OpenAIBackendRole, message: BaseMessage
    ) -> Optional[Dict[str, str]]:
        r"""Map one recorded message to a ``{"from", "value"}`` entry, or
        return None for roles that are not exported.
        """
        if role == OpenAIBackendRole.USER:
            return {"from": "human", "value": message.content}
        if role == OpenAIBackendRole.ASSISTANT:
            if isinstance(message, FunctionCallingMessage):
                call = dict(
                    name=message.func_name,
                    arguments=message.args,
                )
                return {"from": "function_call", "value": json.dumps(call)}
            return {"from": "gpt", "value": message.content}
        if role == OpenAIBackendRole.FUNCTION:
            # NOTE(review): a reviewer asked for this role label to be
            # configurable and suggested telling function calls and
            # results apart by checking whether the message carries a
            # result or calls, rather than by role alone. The author
            # asked for clarification; the point is still open.
            return {
                "from": "observation",
                "value": json.dumps(message.result),  # type: ignore[attr-defined]
            }
        return None

    def convert(self) -> Dict[str, Any]:
        r"""Convert the collected data into a dictionary.

        The converted record is also appended to ``self.data``.

        Returns:
            Dict[str, Any]: A record with the keys ``system`` (system
                message content, or an empty string), ``tools`` (JSON
                string of the tools' function schemas) and
                ``conversations`` (list of ``{"from", "value"}`` entries
                in recorded order).

        Raises:
            ValueError: If no agent was injected or nothing was recorded.
        """
        if self.agent_name is None:
            raise ValueError("No agent injected")
        history = self.history.get(self.agent_name)
        if not history:
            raise ValueError("No data collected")

        conversations: List[Any] = []
        for _, role, message in history:
            entry = self._to_entry(role, message)
            # Messages with any other role are left out of the export.
            if entry is not None:
                conversations.append(entry)

        system = self.system_message
        data = dict(
            system=system.content if system else "",
            tools=json.dumps(
                [t.get_openai_tool_schema()["function"] for t in self.tools]
            ),
            conversations=conversations,
        )
        self.data.append(data)
        return data
Loading
Loading