Skip to content

Commit

Permalink
fix(langchain_tools_demo): fix agent concurrency between restarts (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtisvg committed Jan 29, 2024
1 parent 8730268 commit 2584154
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 33 deletions.
14 changes: 10 additions & 4 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core import messages

from tools import initialize_tools

Expand Down Expand Up @@ -72,17 +73,22 @@ async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientS
connector=await get_connector(),
connector_owner=False,
headers=headers,
raise_for_status=handle_error_response,
raise_for_status=True,
)


# Agent
async def init_agent(user_id_token: Optional[Any]) -> UserAgent:
async def init_agent(
user_id_token: Optional[Any], history: list[messages.BaseMessage]
) -> UserAgent:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512, model_name="gemini-pro")
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
chat_memory=ChatMessageHistory(messages=history),
memory_key="chat_history",
input_key="input",
output_key="output",
)
client = await create_client_session(user_id_token)
tools = await initialize_tools(client)
Expand Down
66 changes: 42 additions & 24 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@
import os
import uuid
from contextlib import asynccontextmanager
from typing import Any

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
from fastapi.responses import PlainTextResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

Expand All @@ -47,28 +56,23 @@ async def lifespan(app: FastAPI):
# TODO: set secret_key for production
app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY")
templates = Jinja2Templates(directory="templates")
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]
BASE_HISTORY: list[BaseMessage] = [
AIMessage(content="I am an SFO Airport Assistant, ready to assist you.")
]


@app.route("/", methods=["GET", "POST"])
async def index(request: Request):
"""Render the default template."""
request.session["client_id"] = os.getenv("CLIENT_ID")
if "uuid" not in request.session:
request.session["uuid"] = str(uuid.uuid4())
request.session["messages"] = BASE_HISTORY
# Agent setup
if request.session["uuid"] in user_agents:
user_agent = user_agents[request.session["uuid"]]
else:
user_agent = await init_agent(user_id_token=None)
user_agents[request.session["uuid"]] = user_agent
agent = await get_agent(request.session)
print(request.session["history"])
return templates.TemplateResponse(
"index.html",
{
"request": request,
"messages": request.session["messages"],
"client_id": request.session["client_id"],
"messages": request.session["history"],
"client_id": request.session.get("client_id"),
},
)

Expand All @@ -82,10 +86,7 @@ async def login_google(
if user_id_token is None:
raise HTTPException(status_code=401, detail="No user credentials found")
# create new request session
request.session["uuid"] = str(uuid.uuid4())
request.session["messages"] = BASE_HISTORY
user_agent = await init_agent(user_id_token)
user_agents[request.session["uuid"]] = user_agent
_ = await get_agent(request.session)
print("Logged in to Google.")

# Redirect to source URL
Expand All @@ -105,26 +106,43 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
)

# Add user message to chat history
request.session["messages"] += [{"role": "user", "content": prompt}]
user_agent = user_agents[request.session["uuid"]]
request.session["history"].append(message_to_dict(HumanMessage(content=prompt)))
user_agent = await get_agent(request.session)
try:
# Send prompt to LLM
response = await user_agent.agent.ainvoke({"input": prompt})
request.session["messages"] += [
{"role": "assistant", "content": response["output"]}
]
# Return assistant response
request.session["history"].append(
message_to_dict(AIMessage(content=response["output"]))
)
return markdown(response["output"])
except Exception as err:
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")


async def get_agent(session: dict[str, Any]) -> UserAgent:
    """Return the cached agent for this session, creating one on first use.

    Lazily initializes the session's ``uuid`` and ``history`` keys, then
    looks up (or builds and caches) the per-session agent in the module-level
    ``user_agents`` dict so repeated requests reuse one agent instead of
    re-initializing it.

    Args:
        session: the Starlette request session (a mutable mapping persisted
            via SessionMiddleware). Mutated in place to add ``uuid`` and
            ``history`` when missing.

    Returns:
        The ``UserAgent`` associated with this session's uuid.
    """
    global user_agents
    if "uuid" not in session:
        session["uuid"] = str(uuid.uuid4())
    # Avoid shadowing the builtin `id` (the original used `id = ...`).
    session_id = session["uuid"]
    if "history" not in session:
        session["history"] = messages_to_dict(BASE_HISTORY)
    # BUG FIX: the original wrote `if uuid not in user_agents:`, which tests
    # the imported `uuid` *module object* against the dict keys — always
    # true — so a brand-new agent was built on every request and the cache
    # was never hit. Test the session id instead.
    if session_id not in user_agents:
        # NOTE(review): the session uuid is passed as init_agent's
        # `user_id_token` parameter — looks suspect (a uuid is not an ID
        # token), but preserved as-is; confirm against init_agent's contract.
        user_agents[session_id] = await init_agent(
            session_id, messages_from_dict(session["history"])
        )
    return user_agents[session_id]


@app.post("/reset")
async def reset(request: Request):
"""Reset agent"""
global user_agents
uuid = request.session["uuid"]

if "uuid" not in request.session:
raise HTTPException(status_code=400, detail=f"No session to reset.")

uuid = request.session["uuid"]
global user_agents
if uuid not in user_agents.keys():
raise HTTPException(status_code=500, detail=f"Current agent not found")

Expand Down
4 changes: 2 additions & 2 deletions langchain_tools_demo/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ div.chat-content>span {
padding: 0;
}

div.chat-wrapper div.chat-content span.assistant {
div.chat-wrapper div.chat-content span.ai {
position: relative;
width: 70%;
height: auto;
Expand All @@ -119,7 +119,7 @@ div.chat-wrapper div.chat-content span.assistant {
border-radius: 2px 15px 15px 15px;
}

div.chat-wrapper div.chat-content span.user {
div.chat-wrapper div.chat-content span.human {
position: relative;
float: right;
width: 70%;
Expand Down
4 changes: 2 additions & 2 deletions langchain_tools_demo/static/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ $('#resetButton').click(async (e) => {
async function submitMessage() {
let msg = $('.chat-bar input').val();
// Add message to UI
log("user", msg)
log("human", msg)
// Clear message
$('.chat-bar input').val('');
$('.mdl-progress').show()
Expand All @@ -43,7 +43,7 @@ async function submitMessage() {
let answer = await askQuestion(msg);
$('.mdl-progress').hide();
// Add response to UI
log("assistant", answer)
log("ai", answer)
} catch (err) {
window.alert(`Error when submitting question: ${err}`);
}
Expand Down
2 changes: 1 addition & 1 deletion langchain_tools_demo/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ <h1>SFO Airport Assistant</h1>
{# Add Chat history #}
{% if messages %}
{% for message in messages %}
<span class="chat-bubble {{ message["role"] }}">{{ message["content"] | safe }}</span>
<span class="chat-bubble {{ message["type"] }}">{{ message["data"]["content"] | safe }}</span>
{% endfor %}
{% endif %}
</div>
Expand Down
1 change: 1 addition & 0 deletions langchain_tools_demo/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import aiohttp
import google.oauth2.id_token # type: ignore
from google.auth.transport.requests import Request # type: ignore
from langchain.agents.agent import ExceptionTool # type: ignore
from langchain.tools import StructuredTool
from pydantic.v1 import BaseModel, Field

Expand Down

0 comments on commit 2584154

Please sign in to comment.