Skip to content

Commit

Permalink
Merge pull request #51 from PathOnAI/agent_refactor
Browse files Browse the repository at this point in the history
make refactoring: 1) ToolRegistry, 2) AgentFactory, 3) AgentManager
  • Loading branch information
TataKKKL authored Sep 28, 2024
2 parents 744e62e + 7d5049a commit 54fe764
Show file tree
Hide file tree
Showing 24 changed files with 856 additions and 1,148 deletions.
Binary file added files/attention.pdf
Binary file not shown.
19 changes: 19 additions & 0 deletions litemultiagent/agents/atomic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from litemultiagent.agents.base import BaseAgent
from typing import List, Optional
from litemultiagent.tools.registry import ToolRegistry, Tool
class AtomicAgent(BaseAgent):
def __init__(self, agent_name: str, agent_description, parameter_description, tool_names: List[str], meta_task_id: Optional[str] = None, task_id: Optional[int] = None):
print(tool_names)
tool_registry = ToolRegistry()
available_tools = {}
tools = []
for tool_name in tool_names:
available_tools[tool_name] = tool_registry.get_tool(tool_name).func
tools.append(tool_registry.get_tool_description(tool_name))
super().__init__(agent_name, agent_description, parameter_description, tools, available_tools, meta_task_id, task_id)

def execute(self, task: str) -> str:
return self.send_prompt(task)

def __call__(self, task: str) -> str:
return self.execute(task)
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,74 @@
from openai import OpenAI
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from config import agent_to_model, model_cost
#from litemultiagent.config.agent_config import agent_to_model, model_cost
# from litemultiagent.agents.base import BaseAgent
# from litemultiagent.tools.registry import ToolRegistry, Tool
from supabase import create_client, Client
from litellm import completion
import os
from dotenv import load_dotenv
_ = load_dotenv()
from litellm import completion
from datetime import datetime
import csv
import json
import os
from datetime import datetime

import csv
import json
import os
from datetime import datetime
_ = load_dotenv()

logger = logging.getLogger(__name__)

# Initialize Supabase client only if environment variables are set

model_cost = {
"gpt-4o-mini": {
"input_price_per_1m": 0.15,
"output_price_per_1m": 0.6,
},
"gemini/gemini-pro": {
"input_price_per_1m": 0,
"output_price_per_1m": 0,
},
"claude-3-5-sonnet-20240620": {
"input_price_per_1m": 3,
"output_price_per_1m": 15,
},
"groq/llama3-8b-8192": {
"input_price_per_1m": 0.05,
"output_price_per_1m": 0.08,
},
}

agent_to_model = {
"main_agent":
{
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"io_agent": {
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"retrieval_agent": {
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"web_retrieval_agent":{
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"db_retrieval_agent":{
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"exec_agent":{
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
"file_retrieval_agent":{
"model_name" : "gpt-4o-mini",
"tool_choice" : "auto",
},
}

# Initialize Supabase client
url = os.getenv("SUPABASE_URL")
key = os.getenv("SUPABASE_ANON_KEY")
supabase: Optional[Client] = None
Expand All @@ -32,9 +80,8 @@
except Exception as e:
logger.error(f"Failed to initialize Supabase client: {e}")


class BaseAgent:
def __init__(self, agent_name: str, tools: List[Dict[str, Any]], available_tools: Dict[str, callable],
def __init__(self, agent_name: str, agent_description, parameter_description, tools: List[Dict[str, Any]], available_tools: Dict[str, callable],
meta_task_id: Optional[str] = None, task_id: Optional[int] = None, save_to="csv", log="log"):
self.agent_name = agent_name
self.tools = tools
Expand All @@ -46,13 +93,14 @@ def __init__(self, agent_name: str, tools: List[Dict[str, Any]], available_tools
self.task_id = task_id
self.save_to = save_to
self.log = log


self.agent_description = agent_description
self.parameter_description = parameter_description

def send_prompt(self, content: str) -> str:
self.messages.append({"role": "user", "content": content})
return self._send_completion_request()

# ... (rest of the BaseAgent methods remain the same)
def _send_completion_request(self, depth: int = 0) -> str:
if depth >= 8:
return None
Expand Down
47 changes: 47 additions & 0 deletions litemultiagent/agents/composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from litemultiagent.agents.base import BaseAgent
from litemultiagent.tools.registry import ToolRegistry, Tool
from typing import List, Dict, Any, Optional

class CompositeAgent(BaseAgent):
def __init__(self, agent_name: str, agent_description, parameter_description, sub_agent_configs: List[Dict[str, Any]], tool_names: List[str], meta_task_id: Optional[str] = None, task_id: Optional[int] = None):
#super().__init__(agent_name, tools, meta_task_id, task_id)
self.tool_registry = ToolRegistry()
self.available_tools = {}
self.tools = []
for tool_name in tool_names:
self.available_tools[tool_name] = self.tool_registry.get_tool(tool_name).func
self.tools.append(self.tool_registry.get_tool_description(tool_name))
self.sub_agents = self._build_sub_agents(sub_agent_configs)
self._register_sub_agents_as_tools()
super().__init__(agent_name, agent_description, parameter_description, self.tools, self.available_tools, meta_task_id, task_id)


def _build_sub_agents(self, sub_agent_configs: List[Dict[str, Any]]) -> List[BaseAgent]:
from litemultiagent.core.agent_factory import AgentFactory # Import here to avoid circular dependency
return [AgentFactory.create_agent(config) for config in sub_agent_configs]

def _register_sub_agents_as_tools(self):
for sub_agent in self.sub_agents:
ToolRegistry.register(Tool(
sub_agent.agent_name,
sub_agent,
sub_agent.agent_description,
{
"task": {
"type": "string",
"description": sub_agent.parameter_description,
"required": True
}
}
))
# Update the tools and available_tools after registering sub-agents
self.tools.extend([self.tool_registry.get_tool_description(sub_agent.agent_name) for sub_agent in self.sub_agents])
self.available_tools.update({sub_agent.agent_name: sub_agent for sub_agent in self.sub_agents})

def execute(self, task: str) -> str:
# Implementation of task execution using sub-agents
# This could involve breaking down the task and delegating to sub-agents
return self.send_prompt(task)

def __call__(self, task: str) -> str:
return self.execute(task)
64 changes: 0 additions & 64 deletions litemultiagent/config.py

This file was deleted.

File renamed without changes.
23 changes: 23 additions & 0 deletions litemultiagent/core/agent_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Dict, Any
from litemultiagent.agents.atomic import AtomicAgent
from litemultiagent.agents.composite import CompositeAgent
from litemultiagent.agents.base import BaseAgent

class AgentFactory:
@staticmethod
def create_agent(config: Dict[str, Any]) -> BaseAgent:
agent_type = config["type"]
agent_name = config["name"]
meta_task_id = config.get("meta_task_id")
task_id = config.get("task_id")
tools = config.get("tools", [])
agent_description = config["agent_description"]
parameter_description = config["parameter_description"]

if agent_type == "atomic":
return AtomicAgent(agent_name, agent_description, parameter_description, tools, meta_task_id=meta_task_id, task_id=task_id)
elif agent_type == "composite":
sub_agent_configs = config.get("sub_agents", [])
return CompositeAgent(agent_name, agent_description, parameter_description, sub_agent_configs, tools, meta_task_id=meta_task_id, task_id=task_id)
else:
raise ValueError(f"Unknown agent type: {agent_type}")
18 changes: 18 additions & 0 deletions litemultiagent/core/agent_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Dict, Any
from litemultiagent.core.agent_factory import AgentFactory
from litemultiagent.agents.base import BaseAgent

class AgentManager:
def __init__(self):
self.agents: Dict[str, BaseAgent] = {}

def get_agent(self, config: Dict[str, Any]) -> BaseAgent:
agent_name = config['name']
if agent_name not in self.agents:
self.agents[agent_name] = AgentFactory.create_agent(config)
return self.agents[agent_name]

def execute_task(self, agent_name: str, task: str) -> str:
if agent_name not in self.agents:
raise ValueError(f"Agent '{agent_name}' not found. Create the agent first using get_agent().")
return self.agents[agent_name].execute(task)
Loading

0 comments on commit 54fe764

Please sign in to comment.