-
Notifications
You must be signed in to change notification settings - Fork 69
/
Copy pathmain.py
314 lines (266 loc) · 12.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import asyncio
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import BaseTool
from llama_index.core.workflow import Context
from llama_index.llms.openai import OpenAI
from workflow import (
AgentConfig,
ConciergeAgent,
ProgressEvent,
ToolRequestEvent,
ToolApprovedEvent,
)
from utils import FunctionToolWithContext
def get_initial_state() -> dict:
    """Build the blank per-user state used to seed a conversation.

    Returns:
        A new dict whose ``username``, ``session_token``, ``account_id`` and
        ``account_balance`` entries are all ``None``.
    """
    state_keys = ("username", "session_token", "account_id", "account_balance")
    # dict.fromkeys fills every key with None and preserves the key order.
    return dict.fromkeys(state_keys)
def get_stock_lookup_tools() -> list[BaseTool]:
    """Create the tools offered by the stock lookup agent.

    Both tools are placeholders: the price lookup always reports $100.00 and
    the symbol search simply upper-cases the company name.

    Returns:
        The two functions wrapped with ``FunctionToolWithContext``, price
        lookup first and symbol search second.
    """

    # NOTE(review): the nested function names and docstrings are presumably
    # what FunctionToolWithContext exposes as tool name/description, so they
    # are kept verbatim -- confirm against utils before renaming.
    def lookup_stock_price(ctx: Context, stock_symbol: str) -> str:
        """Useful for looking up a stock price."""
        progress = ProgressEvent(msg=f"Looking up stock price for {stock_symbol}")
        ctx.write_event_to_stream(progress)
        return f"Symbol {stock_symbol} is currently trading at $100.00"

    def search_for_stock_symbol(ctx: Context, company_name: str) -> str:
        """Useful for searching for a stock symbol given a free-form company name."""
        progress = ProgressEvent(msg="Searching for stock symbol")
        ctx.write_event_to_stream(progress)
        return company_name.upper()

    return [
        FunctionToolWithContext.from_defaults(fn=tool_fn)
        for tool_fn in (lookup_stock_price, search_for_stock_symbol)
    ]
def get_authentication_tools() -> list[BaseTool]:
    """Create the tools offered by the authentication agent.

    All three tools read the ``"user_state"`` entry of the workflow context;
    ``store_username`` and ``login`` also write it back.

    Returns:
        ``store_username``, ``login`` and ``is_authenticated`` (in that order)
        wrapped with ``FunctionToolWithContext``.
    """

    # NOTE(review): nested function names and docstrings are kept verbatim
    # because they presumably surface as tool metadata -- confirm in utils.
    async def is_authenticated(ctx: Context) -> bool:
        """Checks if the user has a session token."""
        ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
        state = await ctx.get("user_state")
        token = state["session_token"]
        return token is not None

    async def store_username(ctx: Context, username: str) -> None:
        """Adds the username to the user state."""
        ctx.write_event_to_stream(ProgressEvent(msg="Recording username"))
        state = await ctx.get("user_state")
        state["username"] = username
        await ctx.set("user_state", state)

    async def login(ctx: Context, password: str) -> str:
        """Given a password, logs in and stores a session token in the user state."""
        state = await ctx.get("user_state")
        username = state["username"]
        ctx.write_event_to_stream(ProgressEvent(msg=f"Logging in user {username}"))

        # todo: actually check the password
        session_token = "1234567890"
        account_id = "123"
        account_balance = 1000

        state["session_token"] = session_token
        state["account_id"] = account_id
        state["account_balance"] = account_balance
        await ctx.set("user_state", state)

        return (
            f"Logged in user {username} with session token {session_token}. "
            f"They have an account with id {account_id} "
            f"and a balance of ${account_balance}."
        )

    return [
        FunctionToolWithContext.from_defaults(async_fn=tool_fn)
        for tool_fn in (store_username, login, is_authenticated)
    ]
def get_account_balance_tools() -> list[BaseTool]:
    """Create the tools offered by the account balance agent.

    The lookup tools refuse to run (``ValueError``) unless the context's
    ``"user_state"`` holds a session token.

    Returns:
        ``get_account_id``, ``get_account_balance`` and ``is_authenticated``
        (in that order) wrapped with ``FunctionToolWithContext``.
    """

    # NOTE(review): nested function names and docstrings are kept verbatim
    # because they presumably surface as tool metadata -- confirm in utils.
    async def is_authenticated(ctx: Context) -> bool:
        """Checks if the user has a session token."""
        ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
        state = await ctx.get("user_state")
        token = state["session_token"]
        return token is not None

    async def _require_authentication(ctx: Context) -> None:
        # Shared guard for the lookup tools below; not registered as a tool.
        if await is_authenticated(ctx):
            return
        raise ValueError("User is not authenticated!")

    async def get_account_id(ctx: Context, account_name: str) -> str:
        """Useful for looking up an account ID."""
        await _require_authentication(ctx)

        progress = ProgressEvent(msg=f"Looking up account ID for {account_name}")
        ctx.write_event_to_stream(progress)

        state = await ctx.get("user_state")
        return "Account id is {}".format(state["account_id"])

    async def get_account_balance(ctx: Context, account_id: str) -> str:
        """Useful for looking up an account balance."""
        await _require_authentication(ctx)

        progress = ProgressEvent(msg=f"Looking up account balance for {account_id}")
        ctx.write_event_to_stream(progress)

        state = await ctx.get("user_state")
        balance = state["account_balance"]
        return f"Account {account_id} has a balance of ${balance}"

    return [
        FunctionToolWithContext.from_defaults(async_fn=tool_fn)
        for tool_fn in (get_account_id, get_account_balance, is_authenticated)
    ]
def get_transfer_money_tools() -> list[BaseTool]:
    """Create the tools offered by the transfer money agent.

    Every tool except ``is_authenticated`` raises ``ValueError`` when the
    context's ``"user_state"`` has no session token. ``transfer_money`` only
    reports the transfer; it does not modify the stored balance.

    Returns:
        ``transfer_money``, ``balance_sufficient``, ``has_balance`` and
        ``is_authenticated`` (in that order) wrapped with
        ``FunctionToolWithContext``.
    """

    # NOTE(review): nested function names and docstrings are kept verbatim
    # because they presumably surface as tool metadata -- confirm in utils.
    async def is_authenticated(ctx: Context) -> bool:
        """Checks if the user has a session token."""
        ctx.write_event_to_stream(ProgressEvent(msg="Checking if authenticated"))
        state = await ctx.get("user_state")
        token = state["session_token"]
        return token is not None

    async def _require_authentication(ctx: Context) -> None:
        # Shared guard for the tools below; not registered as a tool.
        if await is_authenticated(ctx):
            return
        raise ValueError("User is not authenticated!")

    async def transfer_money(
        ctx: Context, from_account_id: str, to_account_id: str, amount: int
    ) -> str:
        """Useful for transferring money between accounts."""
        await _require_authentication(ctx)

        progress = ProgressEvent(
            msg=f"Transferring {amount} from {from_account_id} to account {to_account_id}"
        )
        ctx.write_event_to_stream(progress)
        return f"Transferred {amount} to account {to_account_id}"

    async def balance_sufficient(ctx: Context, account_id: str, amount: int) -> bool:
        """Useful for checking if an account has enough money to transfer."""
        await _require_authentication(ctx)

        progress = ProgressEvent(msg="Checking if balance is sufficient")
        ctx.write_event_to_stream(progress)

        state = await ctx.get("user_state")
        # account_id is not consulted: the balance comes from the user state.
        return state["account_balance"] >= amount

    async def has_balance(ctx: Context) -> bool:
        """Useful for checking if an account has a balance."""
        await _require_authentication(ctx)

        progress = ProgressEvent(msg="Checking if account has a balance")
        ctx.write_event_to_stream(progress)

        state = await ctx.get("user_state")
        balance = state["account_balance"]
        if balance is None:
            return False
        return balance > 0

    return [
        FunctionToolWithContext.from_defaults(async_fn=tool_fn)
        for tool_fn in (
            transfer_money,
            balance_sufficient,
            has_balance,
            is_authenticated,
        )
    ]
def get_agent_configs() -> list[AgentConfig]:
    """Describe the four specialist agents available to the workflow.

    Returns:
        Configs for the stock lookup, authentication, account balance and
        transfer money agents, in that order. Only the transfer money agent
        lists a tool (``transfer_money``) that needs human confirmation.
    """
    # The prompt literals stay flush-left so that the prompt text is exactly
    # what appears between the triple quotes, with no added indentation.
    stock_lookup_prompt = """
You are a helpful assistant that is looking up stock prices.
The user may not know the stock symbol of the company they're interested in,
so you can help them look it up by the name of the company.
You can only look up stock symbols given to you by the search_for_stock_symbol tool, don't make them up. Trust the output of the search_for_stock_symbol tool even if it doesn't make sense to you.
"""
    authentication_prompt = """
You are a helpful assistant that is authenticating a user.
Your task is to get a valid session token stored in the user state.
To do this, the user must supply you with a username and a valid password. You can ask them to supply these.
If the user supplies a username and password, call the tool "login" to log them in.
Once the user is logged in and authenticated, you can transfer them to another agent.
"""
    account_balance_prompt = """
You are a helpful assistant that is looking up account balances.
The user may not know the account ID of the account they're interested in,
so you can help them look it up by the name of the account.
The user can only do this if they are authenticated, which you can check with the is_authenticated tool.
If they aren't authenticated, tell them to authenticate first and call the "RequestTransfer" tool.
If they're trying to transfer money, they have to check their account balance first, which you can help with.
"""
    transfer_money_prompt = """
You are a helpful assistant that transfers money between accounts.
The user can only do this if they are authenticated, which you can check with the is_authenticated tool.
If they aren't authenticated, tell them to authenticate first and call the "RequestTransfer" tool.
"""

    stock_lookup_agent = AgentConfig(
        name="Stock Lookup Agent",
        description="Looks up stock prices and symbols",
        system_prompt=stock_lookup_prompt,
        tools=get_stock_lookup_tools(),
    )
    authentication_agent = AgentConfig(
        name="Authentication Agent",
        description="Handles user authentication",
        system_prompt=authentication_prompt,
        tools=get_authentication_tools(),
    )
    account_balance_agent = AgentConfig(
        name="Account Balance Agent",
        description="Checks account balances",
        system_prompt=account_balance_prompt,
        tools=get_account_balance_tools(),
    )
    transfer_money_agent = AgentConfig(
        name="Transfer Money Agent",
        description="Handles money transfers between accounts",
        system_prompt=transfer_money_prompt,
        tools=get_transfer_money_tools(),
        tools_requiring_human_confirmation=["transfer_money"],
    )

    return [
        stock_lookup_agent,
        authentication_agent,
        account_balance_agent,
        transfer_money_agent,
    ]
async def main():
    """Run the concierge workflow as an interactive command-line chat.

    Starts the workflow with a greeting, then loops: stream the run's events
    (printing progress messages and asking the operator to approve or reject
    tool calls that request confirmation), print the agent's response, record
    the new chat messages in memory, and read the next user message. Typing
    ``exit``, ``quit`` or ``bye`` ends the session.
    """
    from colorama import Fore, Style

    llm = OpenAI(model="gpt-4o", temperature=0.4)
    memory = ChatMemoryBuffer.from_defaults(llm=llm)
    initial_state = get_initial_state()
    agent_configs = get_agent_configs()
    workflow = ConciergeAgent(timeout=None)

    # draw a diagram of the workflow
    # draw_all_possible_flows(workflow, filename="workflow.html")

    handler = workflow.run(
        user_msg="Hello!",
        agent_configs=agent_configs,
        llm=llm,
        chat_history=[],
        initial_state=initial_state,
    )

    while True:
        async for event in handler.stream_events():
            if isinstance(event, ToolRequestEvent):
                print(
                    Fore.GREEN
                    + "SYSTEM >> I need approval for the following tool call:"
                    + Style.RESET_ALL
                )
                print(event.tool_name)
                print(event.tool_kwargs)
                print()

                approved = input("Do you approve? (y/n): ")
                # Require an explicit "y"/"yes". A substring test would treat
                # answers such as "nay" or "not yet" as approval.
                if approved.strip().lower() in ("y", "yes"):
                    handler.ctx.send_event(
                        ToolApprovedEvent(
                            tool_id=event.tool_id,
                            tool_name=event.tool_name,
                            tool_kwargs=event.tool_kwargs,
                            approved=True,
                        )
                    )
                else:
                    reason = input("Why not? (reason): ")
                    handler.ctx.send_event(
                        ToolApprovedEvent(
                            tool_name=event.tool_name,
                            tool_id=event.tool_id,
                            tool_kwargs=event.tool_kwargs,
                            approved=False,
                            response=reason,
                        )
                    )
            elif isinstance(event, ProgressEvent):
                print(Fore.GREEN + f"SYSTEM >> {event.msg}" + Style.RESET_ALL)

        result = await handler
        print(Fore.BLUE + f"AGENT >> {result['response']}" + Style.RESET_ALL)

        # update the memory with only the new chat history; the number of
        # messages already in memory is read once, before anything is added.
        known_count = len(memory.get())
        for i, msg in enumerate(result["chat_history"]):
            if i >= known_count:
                memory.put(msg)

        user_msg = input("USER >> ")
        if user_msg.strip().lower() in ["exit", "quit", "bye"]:
            break

        # pass in the existing context and continue the conversation
        handler = workflow.run(
            ctx=handler.ctx,
            user_msg=user_msg,
            agent_configs=agent_configs,
            llm=llm,
            chat_history=memory.get(),
            initial_state=initial_state,
        )
if __name__ == "__main__":
    # Script entry point: drive the interactive chat on a fresh event loop.
    chat_session = main()
    asyncio.run(chat_session)