From 4245020a7548d61b6e76eec21d7a798672a12b08 Mon Sep 17 00:00:00 2001 From: Averi Kitsch Date: Thu, 26 Oct 2023 18:27:31 -0700 Subject: [PATCH] feat: refactor demo frontend (#15) Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com> --- DEVELOPER.md | 104 ++++++++---- cloudrun_instructions.md | 119 +++++++++++++ extension_service/app/routes.py | 2 +- langchain_tools_demo/Dockerfile | 32 ++++ langchain_tools_demo/agent.py | 194 ++++++++++++++++++++++ langchain_tools_demo/main.py | 81 +++++++++ langchain_tools_demo/pyproject.toml | 2 +- langchain_tools_demo/requirements.txt | 12 +- langchain_tools_demo/run_app.py | 88 ---------- langchain_tools_demo/static/favicon.png | Bin 0 -> 310 bytes langchain_tools_demo/static/index.css | 130 +++++++++++++++ langchain_tools_demo/static/index.js | 70 ++++++++ langchain_tools_demo/templates/index.html | 69 ++++++++ 13 files changed, 782 insertions(+), 121 deletions(-) create mode 100644 cloudrun_instructions.md create mode 100644 langchain_tools_demo/Dockerfile create mode 100644 langchain_tools_demo/agent.py create mode 100644 langchain_tools_demo/main.py delete mode 100644 langchain_tools_demo/run_app.py create mode 100644 langchain_tools_demo/static/favicon.png create mode 100644 langchain_tools_demo/static/index.css create mode 100644 langchain_tools_demo/static/index.js create mode 100644 langchain_tools_demo/templates/index.html diff --git a/DEVELOPER.md b/DEVELOPER.md index f57b833d..e49a8bf5 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -1,5 +1,9 @@ # DEVELOPER.md +## Pre-reqs + +See [Pre-reqs](./cloudrun_instructions.md). + ## Setup We recommend using Python 3.11+ and installing the requirements into a virtualenv: @@ -12,41 +16,85 @@ If you are developing or otherwise running tests, install the test requirements pip install -r extension_service/requirements-test.txt -r langchain_tools_demo/requirements-test.txt ``` -## Running the server +## Run the app locally +### Running the extension service -Create your database config: -```bash -cd extension_service -cp example-config.yml config.yml -``` +1. Change into the service directory: -Add your values to `config.yml` + ```bash + cd extension_service + ``` -Prepare the database: -```bash -python run_database_init.py -``` +1. Create your database config: -To run the app using uvicorn, execute the following: -```bash -python run_app.py -``` + ```bash + cp example-config.yml config.yml + ``` -## Running the frontend +1. Add your database config to `config.yml`: -To run the app using streamlit, execute the following: -```bash -cd langchain_tools_demo -streamlit run run_app.py -``` +1. Start the Cloud SQL Proxy or AlloyDB SSH tunnel. + +1. To run the app using uvicorn, execute the following: + + ```bash + python run_app.py + ``` + +### Running the frontend + +1. [Optional] Set up [Application Default Credentials](https://cloud.google.com/docs/authentication/application-default-credentials#GAC): + + ```bash + export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json + ``` + +1. Change into the demo directory: + + ```bash + cd langchain_tools_demo + ``` + +1. Set the server port: + + ```bash + export PORT=9090 + ``` + +1. [Optional] Set `BASE_URL` environment variable: + + ```bash + export BASE_URL= + ``` + +1. [Optional] Set `DEBUG` environment variable: + + ```bash + export DEBUG=True + ``` + +1. To run the app using uvicorn, execute the following: + + ```bash + python main.py + ``` + + Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --port 9090 --reload` + +1. View app at `http://localhost:9090/` ## Testing -Run pytest to automatically run all tests: -```bash -export DB_USER="" -export DB_PASS="" -export DB_NAME="" -pytest -``` +1. Set environment variables: + + ```bash + export DB_USER="" + export DB_PASS="" + export DB_NAME="" + ``` + +1. Run pytest to automatically run all tests: + ```bash + pytest + ``` diff --git a/cloudrun_instructions.md b/cloudrun_instructions.md new file mode 100644 index 00000000..c0a9111c --- /dev/null +++ b/cloudrun_instructions.md @@ -0,0 +1,119 @@ +# Deploy to Cloud Run + +## Pre-reqs + +* Google Cloud Project +* Enabled APIs: + * Cloud Run + * Vertex AI + * Cloud SQL or AlloyDB + * Compute + * Cloud Build + * Artifact Registry + * Service Networking +* Cloud SQL PostgreSQL instance or AlloyDB cluster and primary instance + +## Datastore Setup + + +## Deployment + +1. For easier deployment, set environment variables: + + ```bash + export PROJECT_ID= + ``` + +1. Create a backend service account: + + ```bash + gcloud iam service-accounts create extension-identity + ``` + +1. Grant permissions to access Cloud SQL and/or AlloyDB: + + ```bash + gcloud projects add-iam-policy-binding $PROJECT_ID \ + --member serviceAccount:extension-identity@$PROJECT_ID.iam.gserviceaccount.com \ + --role roles/cloudsql.client + ``` + + ```bash + gcloud projects add-iam-policy-binding $PROJECT_ID \ + --member serviceAccount:extension-identity@$PROJECT_ID.iam.gserviceaccount.com \ + --role roles/alloydb.client + ``` + +1. Change into the service directory: + + ```bash + cd extension_service + ``` + +1. Deploy backend service to Cloud Run: + + * For Cloud SQL: + + ```bash + gcloud run deploy extension-service \ + --source . \ + --no-allow-unauthenticated \ + --service-account extension-identity \ + --region us-central1 \ + --add-cloudsql-instances + ``` + + * For AlloyDB: + + ```bash + gcloud alpha run deploy extension-service \ + --source . \ + --no-allow-unauthenticated \ + --service-account extension-identity \ + --region us-central1 \ + --network=default \ + --subnet=default + ``` + +1. Retrieve extension URL: + + ```bash + export EXTENSION_URL=$(gcloud run services describe extension-service --format 'value(status.url)') + ``` + +1. Create a frontend service account: + + ```bash + gcloud iam service-accounts create demo-identity + ``` + +1. Grant the service account access to invoke the backend service and VertexAI API: + + ```bash + gcloud run services add-iam-policy-binding extension-service \ + --member serviceAccount:demo-identity@$PROJECT_ID.iam.gserviceaccount.com \ + --role roles/run.invoker + ``` + ```bash + gcloud projects add-iam-policy-binding $PROJECT_ID \ + --member serviceAccount:demo-identity@$PROJECT_ID.iam.gserviceaccount.com \ + --role roles/aiplatform.user + ``` + +1. Change into the service directory: + + ```bash + cd langchain_tools-demos + ``` + +1. Deploy to Cloud Run: + + ```bash + gcloud run deploy demo-service \ + --source . \ + --allow-unauthenticated \ + --set-env-vars=BASE_URL=$EXTENSION_URL \ + --service-account demo-identity + ``` + + Note: Your organization may not allow unauthenticated requests. Deploy with `--no-allow-unauthenticated` and use the proxy to view the frontend: `gcloud run services proxy demo-service`. \ No newline at end of file diff --git a/extension_service/app/routes.py b/extension_service/app/routes.py index 96efe412..6e1fe7f3 100644 --- a/extension_service/app/routes.py +++ b/extension_service/app/routes.py @@ -49,7 +49,7 @@ async def amenities_search(query: str, top_k: int, request: Request): embed_service: Embeddings = request.app.state.embed_service query_embedding = embed_service.embed_query(query) - results = await ds.amenities_search(query_embedding, 0.7, top_k) + results = await ds.amenities_search(query_embedding, 0.3, top_k) return results diff --git a/langchain_tools_demo/Dockerfile b/langchain_tools_demo/Dockerfile new file mode 100644 index 00000000..bb9d537c --- /dev/null +++ b/langchain_tools_demo/Dockerfile @@ -0,0 +1,32 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Use the official lightweight Python image. +# https://hub.docker.com/_/python +FROM python:3.11-slim + +# Allow statements and log messages to immediately appear in the logs +ENV PYTHONUNBUFFERED True + +WORKDIR /app + +# Install production dependencies. +COPY ./requirements.txt requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Copy local code to the container image. +COPY . ./ + +# Run the web service on container startup. +CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT diff --git a/langchain_tools_demo/agent.py b/langchain_tools_demo/agent.py new file mode 100644 index 00000000..35819063 --- /dev/null +++ b/langchain_tools_demo/agent.py @@ -0,0 +1,194 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional + +import google.auth.transport.requests # type: ignore +import google.oauth2.id_token # type: ignore +import requests +from langchain.agents import AgentType, initialize_agent +from langchain.agents.agent import AgentExecutor +from langchain.llms.vertexai import VertexAI +from langchain.memory import ConversationBufferMemory +from langchain.tools import tool +from pydantic.v1 import BaseModel, Field + +DEBUG = bool(os.getenv("DEBUG", default=False)) +BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") + + +# Agent +def init_agent() -> AgentExecutor: + """Load an agent executor with tools and LLM""" + print("Initializing agent..") + llm = VertexAI(max_output_tokens=512, verbose=DEBUG) + memory = ConversationBufferMemory( + memory_key="chat_history", + ) + agent = initialize_agent( + tools, + llm, + agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, + verbose=DEBUG, + memory=memory, + handle_parsing_errors=True, + max_iterations=3, + ) + agent.agent.llm_chain.verbose = DEBUG # type: ignore + + return agent + + +# Helper functions +def get_request(url: str, params: dict) -> requests.Response: + """Helper method to make backend requests""" + if "http://" in url: + response = requests.get( + url, + params=params, + ) + else: + response = requests.get( + url, + params=params, + headers={"Authorization": f"Bearer {get_id_token(url)}"}, + ) + return response + + +def get_id_token(url: str) -> str: + """Helper method to generate ID tokens for authenticated requests""" + auth_req = google.auth.transport.requests.Request() + return google.oauth2.id_token.fetch_id_token(auth_req, url) + + +def get_date(): + from datetime import datetime + + now = datetime.now() + return now.strftime("%Y-%m-%dT%H:%M:%S") + + +# Arg Schema for tools +class IdInput(BaseModel): + id: int = Field(description="Unique identifier") + + +class QueryInput(BaseModel): + query: str = Field(description="Search query") + + +class ListFlights(BaseModel): + departure_airport: Optional[str] = Field( + description="Departure airport 3-letter code" + ) + arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code") + date: str = Field(description="Date of flight departure", default=get_date()) + + +# Tool Functions +@tool( + "Get Flight", + args_schema=IdInput, +) +def get_flight(id: int): + """ + Use this tool to get info for a specific flight. + Takes an id and returns info on the flight. + """ + response = get_request( + f"{BASE_URL}/flights", + {"flight_id": id}, + ) + if response.status_code != 200: + return f"Error trying to find flight: {response.text}" + + return response.json() + + +@tool( + "List Flights", + args_schema=ListFlights, +) +def list_flights(departure_airport: str, arrival_airport: str, date: str): + """Use this tool to list all flights matching search criteria.""" + response = get_request( + f"{BASE_URL}/flights/search", + { + "departure_airport": departure_airport, + "arrival_airport": arrival_airport, + "date": date, + }, + ) + if response.status_code != 200: + return f"Error searching flights: {response.text}" + + return response.json() + + +@tool("Get Amenity", args_schema=IdInput) +def get_amenity(id: int): + """ + Use this tool to get info for a specific airport amenity. + Takes an id and returns info on the amenity. + Always use the id from the search_amenities tool. + """ + response = get_request( + f"{BASE_URL}/amenities", + {"id": id}, + ) + if response.status_code != 200: + return f"Error trying to find amenity: {response.text}" + + return response.json() + + +@tool("Search Amenities", args_schema=QueryInput) +def search_amenities(query: str): + """Use this tool to recommended airport amenities at SFO. + Returns several amenities that are related to the query. + Only recommend amenities that are returned by this query. + """ + response = get_request( + f"{BASE_URL}/amenities/search", {"top_k": "5", "query": query} + ) + if response.status_code != 200: + return f"Error searching amenities: {response.text}" + + return response.json() + + +@tool( + "Get Airport", + args_schema=IdInput, +) +def get_airport(id: int): + """ + Use this tool to get info for a specific airport. + Takes an id and returns info on the airport. + Always use the id from the search_airports tool. + """ + response = get_request( + f"{BASE_URL}/airports", + {"id": id}, + ) + if response.status_code != 200: + return f"Error trying to find airport: {response.text}" + + return response.json() + + +# Tools for agent +tools = [get_flight, list_flights, get_amenity, search_amenities, get_airport] diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py new file mode 100644 index 00000000..79463065 --- /dev/null +++ b/langchain_tools_demo/main.py @@ -0,0 +1,81 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import uvicorn +from fastapi import Body, FastAPI, HTTPException, Request +from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from langchain.agents.agent import AgentExecutor +from markdown import markdown +from starlette.middleware.sessions import SessionMiddleware + +from agent import init_agent + +app = FastAPI() +app.mount("/static", StaticFiles(directory="static"), name="static") +# TODO: set secret_key for production +app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY") +templates = Jinja2Templates(directory="templates") + +agents: dict[str, AgentExecutor] = {} +BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}] + + +@app.get("/", response_class=HTMLResponse) +def index(request: Request): + """Render the default template.""" + request.session.clear() # Clear chat history, if needed + if "uuid" not in request.session: + request.session["uuid"] = str(uuid.uuid4()) + request.session["messages"] = BASE_HISTORY + return templates.TemplateResponse( + "index.html", {"request": request, "messages": request.session["messages"]} + ) + + +@app.post("/chat", response_class=PlainTextResponse) +async def chat_handler(request: Request, prompt: str = Body(embed=True)): + """Handler for LangChain chat requests""" + # Retrieve user prompt + if not prompt: + raise HTTPException(status_code=400, detail="Error: No user query") + + # Add user message to chat history + request.session["messages"] += [{"role": "user", "content": prompt}] + # Agent setup + if "uuid" in request.session and request.session["uuid"] in agents: + agent = agents[request.session["uuid"]] + else: + agent = init_agent() + agents[request.session["uuid"]] = agent + try: + # Send prompt to LLM + response = agent.invoke({"input": prompt}) + request.session["messages"] += [ + {"role": "assistant", "content": response["output"]} + ] + # Return assistant response + return markdown(response["output"]) + except Exception as err: + print(err) + raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") + + +if __name__ == "__main__": + PORT = int(os.getenv("PORT", default=8080)) + uvicorn.run(app, host="0.0.0.0", port=PORT) diff --git a/langchain_tools_demo/pyproject.toml b/langchain_tools_demo/pyproject.toml index 80defaec..7c546217 100644 --- a/langchain_tools_demo/pyproject.toml +++ b/langchain_tools_demo/pyproject.toml @@ -3,4 +3,4 @@ profile = "black" [tool.mypy] python_version = 3.11 -warn_unused_configs = true +warn_unused_configs = true \ No newline at end of file diff --git a/langchain_tools_demo/requirements.txt b/langchain_tools_demo/requirements.txt index 798d3662..94a041ec 100644 --- a/langchain_tools_demo/requirements.txt +++ b/langchain_tools_demo/requirements.txt @@ -1,3 +1,9 @@ -google-cloud-aiplatform==1.34.0 -langchain==0.0.310 -streamlit==1.27.2 \ No newline at end of file +fastapi==0.104.0 +google-cloud-aiplatform==1.35.0 +google-auth==2.23.3 +itsdangerous==2.1.2 +jinja2==3.1.2 +langchain==0.0.320 +markdown==3.5 +types-Markdown==3.5.0.0 +uvicorn[standard]==0.23.2 \ No newline at end of file diff --git a/langchain_tools_demo/run_app.py b/langchain_tools_demo/run_app.py deleted file mode 100644 index d09e1668..00000000 --- a/langchain_tools_demo/run_app.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import requests -import streamlit as st -from langchain.agents import AgentType, initialize_agent -from langchain.llms import VertexAI -from langchain.memory import ConversationBufferMemory -from langchain.tools import Tool - -st.title("Database Extension Testing") -BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") - - -def find_similar_toys(desc: str) -> str: - params = {"top_k": "5", "query": desc} - response = requests.get(f"{BASE_URL}/semantic_similarity_search", params) - - if response.status_code != 200: - return f"Error trying to find similar toys: {response.text}" - - results = [ - "Here are is list of toys related to the query in JSON format. Only use this list in making recommendations to the customer. " - ] + [f"{r}" for r in response.json()] - if len(results) <= 1: - return "There are no toys matching that query. Please try again or let the user know there are no results." - output = "\n".join(results) - # print(results) - return output - - -# Initialize chat history -if "messages" not in st.session_state: - st.session_state.messages = [] - -if "agent" not in st.session_state: - tools = [ - Tool.from_function( - name="find_similar_toys", - func=find_similar_toys, - description="useful when you need a toy recommendation. Returns several toys that are related to the query. Only recommend toys that are returned by this query.", - ), - ] - llm = VertexAI(max_output_tokens=512, verbose=True) - memory = ConversationBufferMemory(memory_key="chat_history") - st.session_state.agent = initialize_agent( - tools, - llm, - agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - memory=memory, - ) - st.session_state.agent.agent.llm_chain.verbose = True # type: ignore - -# Display chat messages from history on app rerun -for message in st.session_state.messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) - - -# React to user input -if prompt := st.chat_input("What is up?"): - # Display user message in chat message container - st.chat_message("user").markdown(prompt) - # Add user message to chat history - st.session_state.messages.append({"role": "user", "content": prompt}) - - response = st.session_state.agent.invoke({"input": prompt}) - # Display assistant response in chat message container - with st.chat_message("assistant"): - st.markdown(response["output"]) - # Add assistant response to chat history - st.session_state.messages.append( - {"role": "assistant", "content": response["output"]} - ) diff --git a/langchain_tools_demo/static/favicon.png b/langchain_tools_demo/static/favicon.png new file mode 100644 index 0000000000000000000000000000000000000000..0400c9f04712e3603b09d23ad9c96545ee265bad GIT binary patch literal 310 zcmV-60m=S}P)Px#@JU2LR7gwhRapwcFch4gOg));GWBH2NJl2GSss-lLi3X*q;Iw?6E<$s#`W=Y zfFppjspMGg4&X5lE-fc&3RKjKWt_YMSO7%;E^x+T$(~I;!JUBrbFy`d3zm=7ri=jn($W4nF61`j#a_Xw^oxm`CkM~0M zfD3Z12PfXDnKMBGEM5>9y=c1M<4N`+u5v+S1dw6gUGx6{FYvKJ8399@Bme*a07*qo IM6N<$f}~w~d;kCd literal 0 HcmV?d00001 diff --git a/langchain_tools_demo/static/index.css b/langchain_tools_demo/static/index.css new file mode 100644 index 00000000..22a4eaca --- /dev/null +++ b/langchain_tools_demo/static/index.css @@ -0,0 +1,130 @@ +/** + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the `License`); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an `AS IS` BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +html, +body { + height: 100%; +} + +body { + color: #212A3E; + font-family: 'DM Sans', sans-serif; +} + +.container { + position: relative; + width: 70%; + margin: 0px auto; + height: 100%; + display: flex; + flex-direction: column; +} + +.chat-header { + position: relative; + font-size: 16px; + font-weight: 500; + text-align: center; +} + +.chat-wrapper { + display: flex; + flex-direction: column; + overflow: hidden; + height: 100%; +} + +.chat-content { + overflow-y: scroll; +} + +div.chat-content>span { + margin-bottom: 12px; +} + +.chat-bar { + position: relative; + margin-top: auto; + border: 1px solid #212A3E; + border-radius: 10px; + overflow: hidden; + min-height: 48.1px; +} + +.chat-input-container { + display: flex; +} + +.chat-input-container .ip-msg { + width: 100%; + font-size: 14px; + padding: 15px; + color: #212A3E; + border: none; +} + +.chat-input-container span.btn-group { + position: relative; + margin: auto; + margin-right: 10px; + display: flex; +} + +.chat-bubble { + display: block; + padding: 50px; + overflow-wrap: break-word; +} + +.chat-bubble p { + margin: 0; + padding: 0; +} + +div.chat-wrapper div.chat-content span.assistant { + position: relative; + width: 70%; + height: auto; + display: inline-block; + padding: 10px 12px; + background: #9BA4B5; + color: #212A3E; + border-radius: 2px 15px 15px 15px; +} + +div.chat-wrapper div.chat-content span.user { + position: relative; + float: right; + width: 70%; + height: auto; + display: inline-block; + padding: 10px 12px; + background: #212A3E; + color: #F1F6F9; + border-radius: 15px 2px 15px 15px; +} + +*:focus { + outline: none; +} + +.mdl-progress { + width: 100%; + display: none; +} +.mdl-progress.mdl-progress__indeterminate>.bar1 { + background-color: #394867 +} diff --git a/langchain_tools_demo/static/index.js b/langchain_tools_demo/static/index.js new file mode 100644 index 00000000..564660ca --- /dev/null +++ b/langchain_tools_demo/static/index.js @@ -0,0 +1,70 @@ +/** + * Copyright 2023 Google, LLC + * + * Licensed under the Apache License, Version 2.0 (the `License`); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an `AS IS` BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Submit chat message via click +$('.btn-group span').click(async (e) => { + await submitMessage(); +}); + +// Submit chat message via enter +$(document).on("keypress",async (e) => { + if (e.which == 13) { + await submitMessage(); + } +}); + +async function submitMessage() { + let msg = $('.chat-bar input').val(); + // Add message to UI + log("user", msg) + // Clear message + $('.chat-bar input').val(''); + $('.mdl-progress').show() + try { + // Prompt LLM + let answer = await askQuestion(msg); + $('.mdl-progress').hide(); + // Add response to UI + log("assistant", answer) + } catch (err) { + window.alert(`Error when submitting question: ${err}`); + } +} + +// Send request to backend +async function askQuestion(prompt) { + const response = await fetch('/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ prompt }), + }); + if (response.ok) { + const text = await response.text(); + return text + } else { + console.error(await response.text()) + return "Sorry, we couldn't answer your question 😢" + } +} + +// Helper function to print to chatroom +function log(name, msg) { + let message = `${msg}`; + $('.chat-content').append(message); + $('.chat-content').scrollTop($('.chat-content').prop("scrollHeight")); +} diff --git a/langchain_tools_demo/templates/index.html b/langchain_tools_demo/templates/index.html new file mode 100644 index 00000000..7bb31693 --- /dev/null +++ b/langchain_tools_demo/templates/index.html @@ -0,0 +1,69 @@ + + + + + + + + Assistant + + + + + + + + + + + + +
+
+

SFO Airport Assistant

+
+
+
+ {# Add Chat history #} + {% if messages %} + {% for message in messages %} + {{ message["content"] | safe }} + {% endfor %} + {% endif %} +
+
+
+
+ + + + send + + +
+
+
+
+ + + + +