Skip to content

Commit

Permalink
oai assistant and pattern fixes (autogenhub#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored May 28, 2024
1 parent f4a5835 commit ecbc3b7
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 111 deletions.
46 changes: 33 additions & 13 deletions examples/patterns.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import argparse
import asyncio
from typing import Any
import logging

import openai
from agnext.agent_components.model_client import OpenAI
from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.messages import ChatMessage
from agnext.chat.patterns.group_chat import GroupChat, Output
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from typing_extensions import Any, override

logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)

class ConcatOutput(Output):

class ConcatOutput(GroupChatOutput):
def __init__(self) -> None:
self._output = ""

Expand All @@ -32,8 +37,26 @@ def reset(self) -> None:
self._output = ""


class LoggingHandler(DefaultInterventionHandler):
@override
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
if sender is None:
print(f"Sending message to {recipient.name}: {message}")
else:
print(f"Sending message from {sender.name} to {recipient.name}: {message}")
return message

@override
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
if recipient is None:
print(f"Received response from {sender.name}: {message}")
else:
print(f"Received response from {sender.name} to {recipient.name}: {message}")
return message


async def group_chat(message: str) -> None:
runtime = SingleThreadedAgentRuntime()
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())

joe_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
Expand Down Expand Up @@ -67,16 +90,16 @@ async def group_chat(message: str) -> None:

chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())

response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
response = runtime.send_message(TextMessage(content=message, source="host"), chat)

while not response.done():
await runtime.process_next()

print((await response).body) # type: ignore
await response


async def orchestrator(message: str) -> None:
runtime = SingleThreadedAgentRuntime()
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())

developer_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
Expand Down Expand Up @@ -117,17 +140,14 @@ async def orchestrator(message: str) -> None:
)

response = runtime.send_message(
ChatMessage(
body=message,
sender="customer",
),
TextMessage(content=message, source="customer"),
chat,
)

while not response.done():
await runtime.process_next()

print((await response).body) # type: ignore
print((await response).content) # type: ignore


if __name__ == "__main__":
Expand Down
21 changes: 11 additions & 10 deletions src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Dict

import openai

from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
Expand All @@ -15,22 +17,18 @@ def __init__(
client: openai.AsyncClient,
assistant_id: str,
thread_id: str,
tools: Dict[str, Callable[..., str]] | None = None,
) -> None:
super().__init__(name, description, runtime)
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
self._current_session_window_length = 0
# TODO: investigate why this is 1, as setting this to 0 causes the earlest message in the window to be ignored.
self._current_session_window_length = 1
self._tools = tools or {}

# TODO: use require_response
@message_handler(TextMessage)
async def on_chat_message_with_cancellation(
self, message: TextMessage, cancellation_token: CancellationToken
) -> None:
print("---------------")
print(f"{self.name} received message from {message.source}: {message.content}")
print("---------------")

async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
thread_id=self._thread_id,
Expand All @@ -43,7 +41,7 @@ async def on_chat_message_with_cancellation(
@message_handler(Reset)
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 0
self._current_session_window_length = 1

@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
Expand All @@ -61,6 +59,9 @@ async def on_respond_now(self, message: RespondNow, cancellation_token: Cancella
# TODO: handle other statuses.
raise ValueError(f"Run did not complete successfully: {run}")

# Increment the current session window length.
self._current_session_window_length += 1

# Get the last message from the run.
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
last_message_content = response.data[0].content
Expand Down
39 changes: 21 additions & 18 deletions src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,64 @@
from typing import Any, List, Protocol, Sequence

from agnext.chat.types import Reset, RespondNow

from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, TextMessage


class Output(Protocol):
class GroupChatOutput(Protocol):
def on_message_received(self, message: Any) -> None: ...

def get_output(self) -> Any: ...

def reset(self) -> None: ...


class GroupChat(BaseChatAgent):
class GroupChat(BaseChatAgent, TypeRoutedAgent):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agents: Sequence[BaseChatAgent],
num_rounds: int,
output: Output,
output: GroupChatOutput,
) -> None:
super().__init__(name, description, runtime)
self._agents = agents
self._num_rounds = num_rounds
self._history: List[Any] = []
self._output = output
super().__init__(name, description, runtime)

@property
def subscriptions(self) -> Sequence[type]:
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]

async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
if isinstance(message, Reset):
# Reset the history.
self._history = []
# TODO: reset sub-agents?
@message_handler(Reset)
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
self._history.clear()

if isinstance(message, RespondNow):
# TODO reset...
return self._output.get_output()
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
return self._output.get_output()

@message_handler(TextMessage)
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Should this class disallow it?

self._history.append(message)
round = 0
prev_speaker = None

while round < self._num_rounds:
# TODO: add support for advanced speaker selection.
# Select speaker (round-robin for now).
speaker = self._agents[round % len(self._agents)]

# Send the last message to all agents.
for agent in [agent for agent in self._agents]:
# Send the last message to all agents except the previous speaker.
for agent in [agent for agent in self._agents if agent is not prev_speaker]:
# TODO gather and await
_ = await self._send_message(
self._history[-1],
Expand All @@ -66,19 +67,21 @@ async def on_message(self, message: Any, cancellation_token: CancellationToken)
)
# TODO handle if response is not None

# Request the speaker to speak.
response = await self._send_message(
RespondNow(),
speaker,
cancellation_token=cancellation_token,
)

if response is not None:
# 4. Append the response to the history.
# Append the response to the history.
self._history.append(response)
self._output.on_message_received(response)

# 6. Increment the round.
# Increment the round.
round += 1
prev_speaker = speaker

output = self._output.get_output()
self._output.reset()
Expand Down
Loading

0 comments on commit ecbc3b7

Please sign in to comment.