Skip to content

Commit

Permalink
DST-151: add code for integration (#8)
Browse files Browse the repository at this point in the history
Co-authored-by: Yoom Lam <yoom@navapbc.com>

Integrates chainlit code with ingestion and retrieval function from vector db Chroma.
  • Loading branch information
ccheng26 authored Apr 4, 2024
1 parent 73f2961 commit 7284f1f
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 23 deletions.
2 changes: 1 addition & 1 deletion 02-household-queries/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ tanf.pdf
guru_cards_for_nava.json
chroma_db/
*cache/
*.log
*.log
189 changes: 173 additions & 16 deletions 02-household-queries/chainlit-household-bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,28 @@
import chainlit as cl
from chainlit.input_widget import Select, Switch, Slider

from llm import ollama_client
import chromadb
from chromadb.config import Settings

from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain_google_genai import GoogleGenerativeAIEmbeddings

import os

from ingest import add_json_html_data_to_vector_db, add_pdf_to_vector_db, ingest_call
from llm import google_gemini_client, ollama_client #, gpt4all_client
from retrieval import retrieval_call

OLLAMA_LLMS = ["openhermes", "llama2", "mistral"]
OTHER_LLMS = ["someOtherLLM"]
GOOGLE_LLMS = ["gemini-pro"]
# GPT4ALL_LLMS = ["gpt4all"]

GOOGLE_EMBEDDINGS=["Google::models/embedding-001"]
OPEN_SOURCE_EMBEDDINGS=["all-MiniLM-L6-v2"]
HUGGING_FACE_EMBEDDINGS=["HuggingFace::all-MiniLM-L6-v2", "HuggingFace::all-mpnet-base-v2"]

@cl.on_chat_start
async def init_chat():
Expand All @@ -25,19 +42,28 @@ async def init_chat():
value="chooseBetter",
label="Demo choosing better response",
),
cl.Action(name="uploadDefaultFiles", value="upload_default_files", label="Load default files into vector DB"),
cl.Action(name="uploadFilesToVectorAct", value="upload_files_to_vector", label="Upload files for vector DB"),
cl.Action(name="resetDB", value="reset_db", label="Reset DB"),
],
).send()

# memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", output_key="output", return_messages=True)

settings = await cl.ChatSettings(
[
Select(
id="model",
label="LM Model",
values=OLLAMA_LLMS + OTHER_LLMS,
label="LLM Model",
values=OLLAMA_LLMS + GOOGLE_LLMS,
# values=OLLAMA_LLMS + GOOGLE_LLMS + GPT4ALL_LLMS,
initial_index=0,
),
Select(
id="embedding",
label="Embeddings",
values= GOOGLE_EMBEDDINGS + OPEN_SOURCE_EMBEDDINGS + HUGGING_FACE_EMBEDDINGS,
initial_index=0,
),
Switch(id="use_vector_db", label="Use vector db sources", initial=os.environ.get("USE_VECTOR_DB", False)),
Slider(
id="temperature",
label="LLM Temperature",
Expand Down Expand Up @@ -72,6 +98,12 @@ async def on_click_settings(action: cl.Action):
E.g. **bold**, *italic*, `code`, [links](https://www.example.com), etc.
"""

@cl.action_callback("resetDB")
async def on_click_resetDB(action: cl.Action):
# reset db after changing to avoid error: embedding dimension does not match collection dimensionality
await init_persistent_client_if_needed()
persistent_client = cl.user_session.get("persistent_client")
persistent_client.reset()

@cl.action_callback("stepsDemoAct")
async def on_click_stepsDemo(action: cl.Action):
Expand Down Expand Up @@ -113,6 +145,9 @@ async def update_settings(settings):
print("Settings updated:", pprint.pformat(settings, indent=4))
cl.user_session.set("settings", settings)
await set_llm_model()
await set_embeddings()
if settings["use_vector_db"]:
await set_vector_db()


async def set_llm_model():
Expand All @@ -126,9 +161,10 @@ async def set_llm_model():
client = None
if llm_name in OLLAMA_LLMS:
client = ollama_client(llm_name, settings=llm_settings)
elif llm_name in OTHER_LLMS:
await cl.Message(content=f"TODO: Initialize {llm_name} client").send()
client = None # TODO: Initialize LLM client here...
elif llm_name in GOOGLE_LLMS:
client = google_gemini_client(llm_name, settings=llm_settings)
# elif llm_name in GPT4ALL_LLMS:
# client = gpt4all_client()
else:
await cl.Message(content=f"Could not initialize model: {llm_name}").send()
return
Expand All @@ -137,18 +173,70 @@ async def set_llm_model():
await msg.stream_token(f"Done setting up {llm_name} LLM")
await msg.send()

async def set_embeddings():
settings = cl.user_session.get("settings")
embeddings = settings["embedding"]
msg = cl.Message(
author="backend",
content=f"Setting up embedding: `{embeddings}`...\n",
)
embedding = None
if embeddings in GOOGLE_EMBEDDINGS:
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
model_name= embeddings.split('::')[1]
embedding = GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=GOOGLE_API_KEY)
elif embeddings in OPEN_SOURCE_EMBEDDINGS:
embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
elif embeddings in HUGGING_FACE_EMBEDDINGS:
model_name= embeddings.split('::')[1]
embeddings = HuggingFaceEmbeddings(model_name=model_name)
else:
await cl.Message(content=f"Could not initialize embedding: {embeddings}").send()
return
cl.user_session.set("embedding", embedding)
await msg.stream_token(f"Done setting up {embeddings} embedding")
await msg.send()

async def set_vector_db():
await init_embedding_function_if_needed()
embeddings = cl.user_session.get("embedding")
msg = cl.Message(
author="backend",
content=f"Setting up Chroma DB with `{embeddings}`...\n",
)
persistent_client = chromadb.PersistentClient(settings=Settings(allow_reset=True), path="./chroma_db")
cl.user_session.set("persistent_client", persistent_client)
vectordb=Chroma(
client=persistent_client,
collection_name="resources",
persist_directory="./chroma_db",
embedding_function=embeddings
)

cl.user_session.set("vectordb", vectordb)
await msg.stream_token("Done setting up vector db")
await msg.send()

async def init_llm_client_if_needed():
client = cl.user_session.get("client")
if not client:
await set_llm_model()

async def init_embedding_function_if_needed():
embedding = cl.user_session.get("embedding")
if not embedding:
await set_embeddings()
async def init_persistent_client_if_needed():
persistent_client=cl.user_session.get("persistent_client")
if persistent_client is None:
await set_vector_db()

@cl.on_message
async def message_submitted(message: cl.Message):
await init_llm_client_if_needed()
await init_embedding_function_if_needed()
settings = cl.user_session.get("settings")

client=cl.user_session.get("client")
vectordb=cl.user_session.get("vectordb")
# 3 ways to manage history for LLM:
# 1. Use Chainlit
# message_history = cl.user_session.get("message_history")
Expand All @@ -157,13 +245,22 @@ async def message_submitted(message: cl.Message):
# 3. Use LlmPrompts lp.register_answer

# Reminder to use make_async for long running tasks: https://docs.chainlit.io/guides/sync-async#long-running-synchronous-tasks

# If options `streaming` is set, or `use_vector_db` is not set, the RAG chain will not be called
if settings["streaming"]:
await call_llm_async(message)
if settings["use_vector_db"]:
await cl.Message("Change the setting to use non-streaming instead").send()
else:
await call_llm_async(message)

else:
response = call_llm(message)
await cl.Message(content=f"*Response*: {response}").send()

if settings["use_vector_db"] and vectordb:
await retrieval_function(vectordb=vectordb, llm=client)
response = retrieval_call(client, vectordb, message.content)
answer = f"Result: {response['result']} \nSources: \n" + "\n".join([doc.metadata for doc in response['source_documents']])
await cl.Message(content=answer).send()
else:
response = call_llm(message)
await cl.Message(content=f"*Response*: {response}").send()

@cl.step(type="llm", show_input=True)
async def call_llm_async(message: cl.Message):
Expand All @@ -182,3 +279,63 @@ def call_llm(message: cl.Message):
client = cl.user_session.get("client")
response = client.invoke(message.content)
return response


@cl.action_callback("uploadDefaultFiles")
async def on_click_upload_default_files(action: cl.Action):
await set_vector_db()
vectordb= cl.user_session.get("vectordb")
msg = cl.Message(content="Processing files...", disable_feedback=True)
await msg.send()

ingest_call(vectordb)
msg.content = "Processing default files done. You can now ask questions!"
await msg.update()

@cl.action_callback("uploadFilesToVectorAct")
async def on_click_upload_file_query(action: cl.Action):
files = None
# Wait for the user to upload a file
while files is None:
files = await cl.AskFileMessage(
content="Please upload a pdf or json file to begin!",
accept=["text/plain", "application/pdf", "application/json"],
max_size_mb=20,
timeout=180,
).send()
file = files[0]
# initialize db
await set_vector_db()
vectordb=cl.user_session.get("vectordb")
if(file.type == "application/pdf"):
add_pdf_to_vector_db(vectordb=vectordb, file_path=file.path)
elif(file.type == "application/json"):
add_json_html_data_to_vector_db(vectordb=vectordb, file_path=file.path, content_key="content", index_key="preferredPhrase")
msg = cl.Message(content=f"Processing `{file.name}`...", disable_feedback=True)
await msg.send()
msg.content = f"Processing `{file.name}` done. You can now ask questions!"
await msg.update()


async def retrieval_function(vectordb, llm):
retriever = vectordb.as_retriever(search_kwargs={"k": 1})
message_history = ChatMessageHistory()

memory = ConversationBufferMemory(
memory_key="chat_history",
output_key="answer",
chat_memory=message_history,
return_messages=True,
)

# Create a chain that uses the Chroma vector store
chain = ConversationalRetrievalChain.from_llm(
llm,
chain_type="stuff",
retriever=retriever,
memory=memory,
return_source_documents=True,
)

# Let the user know that the system is ready
cl.user_session.set("chain", chain)
52 changes: 50 additions & 2 deletions 02-household-queries/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import dotenv

from langchain_community.llms.ollama import Ollama
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms.ollama import Ollama
from langchain_community.llms import GPT4All
from langchain_google_genai import ChatGoogleGenerativeAI

import os
dotenv.load_dotenv()

def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdout=False):
if not callbacks:
Expand All @@ -13,7 +19,7 @@ def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdou
# "temperature": 0.1,
# "system": "",
# "template": "",
# See langchain_community/llms/ollama.py
# See https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/ollama.py
"stop": None
}

Expand All @@ -22,3 +28,45 @@ def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdou
return Ollama(model=model_name, callbacks=callbacks, **settings)

# Add LLM client for other LLMs here...
def gpt4all_client(model_path="./models/mistral-7b-instruct-v0.1.Q4_0.gguf", callbacks=None, settings=None,print_to_stdout=False):
# Open source option
# download Mistral at https://mistral.ai/news/announcing-mistral-7b/
if not callbacks:
callbacks = []
if print_to_stdout:
callbacks.append(StreamingStdOutCallbackHandler())

if not settings:
settings = {
# "temp": 0.1,
# See https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/gpt4all.py
"stop": None
}

print("LLM settings:", settings)

return GPT4All(model=model_path,max_tokens=1000, verbose=True,repeat_last_n=0, **settings)

def google_gemini_client(model_name="gemini-pro", callbacks=None, settings=None, print_to_stdout=False):
# Get a Google API key by following the steps after clicking on Get an API key button
# at https://ai.google.dev/tutorials/setup
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
if not callbacks:
callbacks = []
if print_to_stdout:
callbacks.append(StreamingStdOutCallbackHandler())

if not settings:
settings = {
# "temperature": 0.1,
# "top_k": 1
# See https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/gpt4all.py
"stop": None
}

print("LLM settings:", settings)
return ChatGoogleGenerativeAI(model=model_name,
verbose = True,google_api_key=GOOGLE_API_KEY,
convert_system_message_to_human=True, **settings)


8 changes: 4 additions & 4 deletions 02-household-queries/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

def retrieval_call(llm, vectordb):
def retrieval_call(llm, vectordb, question):
# Create the retrieval chain
template = """
You are a helpful AI assistant.
Expand All @@ -23,11 +23,10 @@ def retrieval_call(llm, vectordb):

)

question = os.environ.get("USER_QUERY")
if not question:
# question = os.environ.get("USER_QUERY")
if question is None:
print("Please state your question here: ")
question = input()

# Invoke the retrieval chain
response=retrieval_chain.invoke({"query":question})
print("\n## QUERY: ", question)
Expand All @@ -36,3 +35,4 @@ def retrieval_call(llm, vectordb):
for d in response["source_documents"]:
print(d)
print()
return response

0 comments on commit 7284f1f

Please sign in to comment.