-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from PathOnAI/agent_refactor
make refactoring: 1) ToolRegistry, 2) AgentFactory, 3) AgentManager
- Loading branch information
Showing
24 changed files
with
856 additions
and
1,148 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from litemultiagent.agents.base import BaseAgent | ||
from typing import List, Optional | ||
from litemultiagent.tools.registry import ToolRegistry, Tool | ||
class AtomicAgent(BaseAgent): | ||
def __init__(self, agent_name: str, agent_description, parameter_description, tool_names: List[str], meta_task_id: Optional[str] = None, task_id: Optional[int] = None): | ||
print(tool_names) | ||
tool_registry = ToolRegistry() | ||
available_tools = {} | ||
tools = [] | ||
for tool_name in tool_names: | ||
available_tools[tool_name] = tool_registry.get_tool(tool_name).func | ||
tools.append(tool_registry.get_tool_description(tool_name)) | ||
super().__init__(agent_name, agent_description, parameter_description, tools, available_tools, meta_task_id, task_id) | ||
|
||
def execute(self, task: str) -> str: | ||
return self.send_prompt(task) | ||
|
||
def __call__(self, task: str) -> str: | ||
return self.execute(task) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from litemultiagent.agents.base import BaseAgent | ||
from litemultiagent.tools.registry import ToolRegistry, Tool | ||
from typing import List, Dict, Any, Optional | ||
|
||
class CompositeAgent(BaseAgent): | ||
def __init__(self, agent_name: str, agent_description, parameter_description, sub_agent_configs: List[Dict[str, Any]], tool_names: List[str], meta_task_id: Optional[str] = None, task_id: Optional[int] = None): | ||
#super().__init__(agent_name, tools, meta_task_id, task_id) | ||
self.tool_registry = ToolRegistry() | ||
self.available_tools = {} | ||
self.tools = [] | ||
for tool_name in tool_names: | ||
self.available_tools[tool_name] = self.tool_registry.get_tool(tool_name).func | ||
self.tools.append(self.tool_registry.get_tool_description(tool_name)) | ||
self.sub_agents = self._build_sub_agents(sub_agent_configs) | ||
self._register_sub_agents_as_tools() | ||
super().__init__(agent_name, agent_description, parameter_description, self.tools, self.available_tools, meta_task_id, task_id) | ||
|
||
|
||
def _build_sub_agents(self, sub_agent_configs: List[Dict[str, Any]]) -> List[BaseAgent]: | ||
from litemultiagent.core.agent_factory import AgentFactory # Import here to avoid circular dependency | ||
return [AgentFactory.create_agent(config) for config in sub_agent_configs] | ||
|
||
def _register_sub_agents_as_tools(self): | ||
for sub_agent in self.sub_agents: | ||
ToolRegistry.register(Tool( | ||
sub_agent.agent_name, | ||
sub_agent, | ||
sub_agent.agent_description, | ||
{ | ||
"task": { | ||
"type": "string", | ||
"description": sub_agent.parameter_description, | ||
"required": True | ||
} | ||
} | ||
)) | ||
# Update the tools and available_tools after registering sub-agents | ||
self.tools.extend([self.tool_registry.get_tool_description(sub_agent.agent_name) for sub_agent in self.sub_agents]) | ||
self.available_tools.update({sub_agent.agent_name: sub_agent for sub_agent in self.sub_agents}) | ||
|
||
def execute(self, task: str) -> str: | ||
# Implementation of task execution using sub-agents | ||
# This could involve breaking down the task and delegating to sub-agents | ||
return self.send_prompt(task) | ||
|
||
def __call__(self, task: str) -> str: | ||
return self.execute(task) |
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Dict, Any | ||
from litemultiagent.agents.atomic import AtomicAgent | ||
from litemultiagent.agents.composite import CompositeAgent | ||
from litemultiagent.agents.base import BaseAgent | ||
|
||
class AgentFactory: | ||
@staticmethod | ||
def create_agent(config: Dict[str, Any]) -> BaseAgent: | ||
agent_type = config["type"] | ||
agent_name = config["name"] | ||
meta_task_id = config.get("meta_task_id") | ||
task_id = config.get("task_id") | ||
tools = config.get("tools", []) | ||
agent_description = config["agent_description"] | ||
parameter_description = config["parameter_description"] | ||
|
||
if agent_type == "atomic": | ||
return AtomicAgent(agent_name, agent_description, parameter_description, tools, meta_task_id=meta_task_id, task_id=task_id) | ||
elif agent_type == "composite": | ||
sub_agent_configs = config.get("sub_agents", []) | ||
return CompositeAgent(agent_name, agent_description, parameter_description, sub_agent_configs, tools, meta_task_id=meta_task_id, task_id=task_id) | ||
else: | ||
raise ValueError(f"Unknown agent type: {agent_type}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Dict, Any | ||
from litemultiagent.core.agent_factory import AgentFactory | ||
from litemultiagent.agents.base import BaseAgent | ||
|
||
class AgentManager: | ||
def __init__(self): | ||
self.agents: Dict[str, BaseAgent] = {} | ||
|
||
def get_agent(self, config: Dict[str, Any]) -> BaseAgent: | ||
agent_name = config['name'] | ||
if agent_name not in self.agents: | ||
self.agents[agent_name] = AgentFactory.create_agent(config) | ||
return self.agents[agent_name] | ||
|
||
def execute_task(self, agent_name: str, task: str) -> str: | ||
if agent_name not in self.agents: | ||
raise ValueError(f"Agent '{agent_name}' not found. Create the agent first using get_agent().") | ||
return self.agents[agent_name].execute(task) |
Oops, something went wrong.