From e72713a874dd306bd5d4f4d12d48e2f7355b5755 Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Tue, 12 Nov 2024 11:45:52 +0800 Subject: [PATCH] [Feat] Add `Sequential` and `AsyncSequential` agents (#270) * add sequential agents * display agent hierarchy * update * simplify arguments --- lagent/agents/__init__.py | 4 +- lagent/agents/agent.py | 96 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py index 3e673c19..f06972cc 100644 --- a/lagent/agents/__init__.py +++ b/lagent/agents/__init__.py @@ -1,9 +1,9 @@ -from .agent import Agent, AgentDict, AgentList, AsyncAgent +from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential from .react import AsyncReAct, ReAct from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder __all__ = [ 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM', 'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct', - 'AsyncReAct' + 'AsyncReAct', 'Sequential', 'AsyncSequential' ] diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index e3b89338..cebfe423 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -2,6 +2,7 @@ import warnings from collections import OrderedDict, UserDict, UserList, abc from functools import wraps +from itertools import chain, repeat from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from lagent.agents.aggregator import DefaultAggregator @@ -169,7 +170,22 @@ def reset(self, session_id=0): self.memory.reset(session_id=session_id) def __repr__(self): - return f"{self.__class__.__name__}(name='{self.name}', description='{self.description or ''}')" + + def _rcsv_repr(agent, n_indent=1): + res = agent.__class__.__name__ + (f"(name='{agent.name}')" + if agent.name else '') + modules = [ + f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}" + for name, agent in getattr(agent, '_agents', {}).items() + ] + if modules: + res += '(\n' + '\n'.join( + modules) + f'\n{(n_indent - 1) * " "})' + elif not res.endswith(')'): + res += '()' + return res + + return _rcsv_repr(self) class AsyncAgent(Agent): @@ -225,6 +241,78 @@ async def forward(self, return llm_response +class Sequential(Agent): + """Sequential is an agent container that forwards messages to each agent + in the order they are added.""" + + def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs): + super().__init__(**kwargs) + self._agents = OrderedDict() + if not agents: + raise ValueError('At least one agent should be provided') + if isinstance(agents[0], + Iterable) and not isinstance(agents[0], Agent): + if not agents[0]: + raise ValueError('At least one agent should be provided') + agents = agents[0] + for key, agent in enumerate(agents): + if isinstance(agents, Mapping): + key, agent = agent, agents[agent] + elif isinstance(agent, tuple): + key, agent = agent + self.add_agent(key, agent) + + def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]): + assert isinstance( + agent, (Agent, AsyncAgent + )), f'{type(agent)} is not an Agent or AsyncAgent subclass' + self._agents[str(name)] = agent + + def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message, ) + message = agent(*message, session_id=session_id, **kwargs) + return message + + def __getitem__(self, key): + if isinstance(key, int) and key < 0: + assert key >= -len(self), 'index out of range' + key = len(self) + key + return self._agents[str(key)] + + def __len__(self): + return len(self._agents) + + +class AsyncSequential(Sequential, AsyncAgent): + + async def forward(self, + *message: AgentMessage, + session_id=0, + exit_at: Optional[int] = None, + **kwargs) -> AgentMessage: + assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' + if exit_at is None: + exit_at = len(self) - 1 + iterator = chain.from_iterable(repeat(self._agents.values())) + for _ in range(exit_at + 1): + agent = next(iterator) + if isinstance(message, AgentMessage): + message = (message, ) + message = await agent(*message, session_id=session_id, **kwargs) + return message + + class AgentContainerMixin: def __init_subclass__(cls): @@ -276,18 +364,20 @@ def _backup(d): setattr(cls, method, wrap_api(getattr(cls, method))) -class AgentList(UserList, Agent, AgentContainerMixin): +class AgentList(Agent, UserList, AgentContainerMixin): def __init__(self, agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None): Agent.__init__(self, memory=None) UserList.__init__(self, agents) + self.name = None -class AgentDict(UserDict, Agent, AgentContainerMixin): +class AgentDict(Agent, UserDict, AgentContainerMixin): def __init__(self, agents: Optional[Mapping[str, Union[Agent, AsyncAgent]]] = None): Agent.__init__(self, memory=None) UserDict.__init__(self, agents) + self.name = None