diff --git a/02-household-queries/.gitignore b/02-household-queries/.gitignore
index b7e09ac..30f7162 100644
--- a/02-household-queries/.gitignore
+++ b/02-household-queries/.gitignore
@@ -4,4 +4,4 @@ tanf.pdf
 guru_cards_for_nava.json
 chroma_db/
 *cache/
-*.log
\ No newline at end of file
+*.log
diff --git a/02-household-queries/chainlit-household-bot.py b/02-household-queries/chainlit-household-bot.py
index 1a341fe..92e0123 100755
--- a/02-household-queries/chainlit-household-bot.py
+++ b/02-household-queries/chainlit-household-bot.py
@@ -6,11 +6,28 @@ import chainlit as cl
 from chainlit.input_widget import Select, Switch, Slider
 
-from llm import ollama_client
+import chromadb
+from chromadb.config import Settings
+
+from langchain.chains import ConversationalRetrievalChain
+from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain.memory import ChatMessageHistory, ConversationBufferMemory
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+
+import os
+
+from ingest import add_json_html_data_to_vector_db, add_pdf_to_vector_db, ingest_call
+from llm import google_gemini_client, ollama_client  # , gpt4all_client
+from retrieval import retrieval_call
 
 OLLAMA_LLMS = ["openhermes", "llama2", "mistral"]
-OTHER_LLMS = ["someOtherLLM"]
+GOOGLE_LLMS = ["gemini-pro"]
+# GPT4ALL_LLMS = ["gpt4all"]
+GOOGLE_EMBEDDINGS = ["Google::models/embedding-001"]
+OPEN_SOURCE_EMBEDDINGS = ["all-MiniLM-L6-v2"]
+HUGGING_FACE_EMBEDDINGS = ["HuggingFace::all-MiniLM-L6-v2", "HuggingFace::all-mpnet-base-v2"]
 
 @cl.on_chat_start
 async def init_chat():
@@ -25,19 +42,28 @@ async def init_chat():
                 value="chooseBetter",
                 label="Demo choosing better response",
             ),
+            cl.Action(name="uploadDefaultFiles", value="upload_default_files", label="Load default files into vector DB"),
+            cl.Action(name="uploadFilesToVectorAct", value="upload_files_to_vector", label="Upload files for vector DB"),
+            cl.Action(name="resetDB", value="reset_db", label="Reset DB"),
         ],
     ).send()
 
-    # memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", output_key="output", return_messages=True)
-
     settings = await cl.ChatSettings(
         [
             Select(
                 id="model",
-                label="LM Model",
-                values=OLLAMA_LLMS + OTHER_LLMS,
+                label="LLM Model",
+                values=OLLAMA_LLMS + GOOGLE_LLMS,
+                # values=OLLAMA_LLMS + GOOGLE_LLMS + GPT4ALL_LLMS,
                 initial_index=0,
             ),
+            Select(
+                id="embedding",
+                label="Embeddings",
+                values=GOOGLE_EMBEDDINGS + OPEN_SOURCE_EMBEDDINGS + HUGGING_FACE_EMBEDDINGS,
+                initial_index=0,
+            ),
+            Switch(id="use_vector_db", label="Use vector db sources", initial=os.environ.get("USE_VECTOR_DB", False)),
             Slider(
                 id="temperature",
                 label="LLM Temperature",
@@ -72,6 +98,12 @@ async def on_click_settings(action: cl.Action):
     E.g. **bold**, *italic*, `code`, [links](https://www.example.com), etc.
""" +@cl.action_callback("resetDB") +async def on_click_resetDB(action: cl.Action): + # reset db after changing to avoid error: embedding dimension does not match collection dimensionality + await init_persistent_client_if_needed() + persistent_client = cl.user_session.get("persistent_client") + persistent_client.reset() @cl.action_callback("stepsDemoAct") async def on_click_stepsDemo(action: cl.Action): @@ -113,6 +145,9 @@ async def update_settings(settings): print("Settings updated:", pprint.pformat(settings, indent=4)) cl.user_session.set("settings", settings) await set_llm_model() + await set_embeddings() + if settings["use_vector_db"]: + await set_vector_db() async def set_llm_model(): @@ -126,9 +161,10 @@ async def set_llm_model(): client = None if llm_name in OLLAMA_LLMS: client = ollama_client(llm_name, settings=llm_settings) - elif llm_name in OTHER_LLMS: - await cl.Message(content=f"TODO: Initialize {llm_name} client").send() - client = None # TODO: Initialize LLM client here... + elif llm_name in GOOGLE_LLMS: + client = google_gemini_client(llm_name, settings=llm_settings) + # elif llm_name in GPT4ALL_LLMS: + # client = gpt4all_client() else: await cl.Message(content=f"Could not initialize model: {llm_name}").send() return @@ -137,18 +173,70 @@ async def set_llm_model(): await msg.stream_token(f"Done setting up {llm_name} LLM") await msg.send() +async def set_embeddings(): + settings = cl.user_session.get("settings") + embeddings = settings["embedding"] + msg = cl.Message( + author="backend", + content=f"Setting up embedding: `{embeddings}`...\n", + ) + embedding = None + if embeddings in GOOGLE_EMBEDDINGS: + GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY') + model_name= embeddings.split('::')[1] + embedding = GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=GOOGLE_API_KEY) + elif embeddings in OPEN_SOURCE_EMBEDDINGS: + embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2") + elif embeddings in HUGGING_FACE_EMBEDDINGS: + model_name= embeddings.split('::')[1] + embeddings = HuggingFaceEmbeddings(model_name=model_name) + else: + await cl.Message(content=f"Could not initialize embedding: {embeddings}").send() + return + cl.user_session.set("embedding", embedding) + await msg.stream_token(f"Done setting up {embeddings} embedding") + await msg.send() + +async def set_vector_db(): + await init_embedding_function_if_needed() + embeddings = cl.user_session.get("embedding") + msg = cl.Message( + author="backend", + content=f"Setting up Chroma DB with `{embeddings}`...\n", + ) + persistent_client = chromadb.PersistentClient(settings=Settings(allow_reset=True), path="./chroma_db") + cl.user_session.set("persistent_client", persistent_client) + vectordb=Chroma( + client=persistent_client, + collection_name="resources", + persist_directory="./chroma_db", + embedding_function=embeddings + ) + + cl.user_session.set("vectordb", vectordb) + await msg.stream_token("Done setting up vector db") + await msg.send() async def init_llm_client_if_needed(): client = cl.user_session.get("client") if not client: await set_llm_model() - +async def init_embedding_function_if_needed(): + embedding = cl.user_session.get("embedding") + if not embedding: + await set_embeddings() +async def init_persistent_client_if_needed(): + persistent_client=cl.user_session.get("persistent_client") + if persistent_client is None: + await set_vector_db() @cl.on_message async def message_submitted(message: cl.Message): await init_llm_client_if_needed() + await init_embedding_function_if_needed() 
     settings = cl.user_session.get("settings")
-
+    client = cl.user_session.get("client")
+    vectordb = cl.user_session.get("vectordb")
     # 3 ways to manage history for LLM:
     # 1. Use Chainlit
     # message_history = cl.user_session.get("message_history")
@@ -157,13 +245,22 @@ async def message_submitted(message: cl.Message):
     # 3. Use LlmPrompts lp.register_answer
 
     # Reminder to use make_async for long running tasks: https://docs.chainlit.io/guides/sync-async#long-running-synchronous-tasks
-
+    # If the `streaming` option is set, or `use_vector_db` is not set, the RAG chain will not be called
     if settings["streaming"]:
-        await call_llm_async(message)
+        if settings["use_vector_db"]:
+            await cl.Message("Change the setting to use non-streaming instead").send()
+        else:
+            await call_llm_async(message)
+
     else:
-        response = call_llm(message)
-        await cl.Message(content=f"*Response*: {response}").send()
-
+        if settings["use_vector_db"] and vectordb:
+            await retrieval_function(vectordb=vectordb, llm=client)
+            response = retrieval_call(client, vectordb, message.content)
+            answer = f"Result: {response['result']} \nSources: \n" + "\n".join(str(doc.metadata) for doc in response['source_documents'])
+            await cl.Message(content=answer).send()
+        else:
+            response = call_llm(message)
+            await cl.Message(content=f"*Response*: {response}").send()
 
 @cl.step(type="llm", show_input=True)
 async def call_llm_async(message: cl.Message):
@@ -182,3 +279,63 @@ def call_llm(message: cl.Message):
     client = cl.user_session.get("client")
     response = client.invoke(message.content)
     return response
+
+
+@cl.action_callback("uploadDefaultFiles")
+async def on_click_upload_default_files(action: cl.Action):
+    await set_vector_db()
+    vectordb = cl.user_session.get("vectordb")
+    msg = cl.Message(content="Processing files...", disable_feedback=True)
+    await msg.send()
+
+    ingest_call(vectordb)
+    msg.content = "Processing default files done. You can now ask questions!"
+    await msg.update()
+
+@cl.action_callback("uploadFilesToVectorAct")
+async def on_click_upload_file_query(action: cl.Action):
+    files = None
+    # Wait for the user to upload a file
+    while files is None:
+        files = await cl.AskFileMessage(
+            content="Please upload a pdf or json file to begin!",
+            accept=["text/plain", "application/pdf", "application/json"],
+            max_size_mb=20,
+            timeout=180,
+        ).send()
+    file = files[0]
+    # Initialize the vector DB
+    await set_vector_db()
+    vectordb = cl.user_session.get("vectordb")
+    if file.type == "application/pdf":
+        add_pdf_to_vector_db(vectordb=vectordb, file_path=file.path)
+    elif file.type == "application/json":
+        add_json_html_data_to_vector_db(vectordb=vectordb, file_path=file.path, content_key="content", index_key="preferredPhrase")
+    msg = cl.Message(content=f"Processing `{file.name}`...", disable_feedback=True)
+    await msg.send()
+    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
+    await msg.update()
+
+
+async def retrieval_function(vectordb, llm):
+    retriever = vectordb.as_retriever(search_kwargs={"k": 1})
+    message_history = ChatMessageHistory()
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key="answer",
+        chat_memory=message_history,
+        return_messages=True,
+    )
+
+    # Create a chain that uses the Chroma vector store
+    chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        chain_type="stuff",
+        retriever=retriever,
+        memory=memory,
+        return_source_documents=True,
+    )
+
+    # Store the chain in the user session for later use
+    cl.user_session.set("chain", chain)
diff --git a/02-household-queries/llm.py b/02-household-queries/llm.py
index 49fa831..0c14475 100644
--- a/02-household-queries/llm.py
+++ b/02-household-queries/llm.py
@@ -1,6 +1,12 @@
+import dotenv
-from langchain_community.llms.ollama import Ollama
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain_community.llms.ollama import Ollama
+from langchain_community.llms import GPT4All
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+import os
+dotenv.load_dotenv()
 
 def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdout=False):
     if not callbacks:
@@ -13,7 +19,7 @@ def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdou
             # "temperature": 0.1,
             # "system": "",
             # "template": "",
-            # See langchain_community/llms/ollama.py
+            # See https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/ollama.py
             "stop": None
         }
 
@@ -22,3 +28,45 @@ def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdou
     return Ollama(model=model_name, callbacks=callbacks, **settings)
 
 # Add LLM client for other LLMs here...
+def gpt4all_client(model_path="./models/mistral-7b-instruct-v0.1.Q4_0.gguf", callbacks=None, settings=None, print_to_stdout=False):
+    # Open source option
+    # Download Mistral at https://mistral.ai/news/announcing-mistral-7b/
+    if not callbacks:
+        callbacks = []
+    if print_to_stdout:
+        callbacks.append(StreamingStdOutCallbackHandler())
+
+    if not settings:
+        settings = {
+            # "temp": 0.1,
+            # See https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/gpt4all.py
+            "stop": None
+        }
+
+    print("LLM settings:", settings)
+
+    return GPT4All(model=model_path, max_tokens=1000, verbose=True, repeat_last_n=0, **settings)
+
+def google_gemini_client(model_name="gemini-pro", callbacks=None, settings=None, print_to_stdout=False):
+    # Get a Google API key by following the steps after clicking on the "Get an API key" button
+    # at https://ai.google.dev/tutorials/setup
+    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
+    if not callbacks:
+        callbacks = []
+    if print_to_stdout:
+        callbacks.append(StreamingStdOutCallbackHandler())
+
+    if not settings:
+        settings = {
+            # "temperature": 0.1,
+            # "top_k": 1,
+            # See ChatGoogleGenerativeAI in langchain_google_genai for available settings
+            "stop": None
+        }
+
+    print("LLM settings:", settings)
+    return ChatGoogleGenerativeAI(model=model_name,
+                                  verbose=True, google_api_key=GOOGLE_API_KEY,
+                                  convert_system_message_to_human=True, **settings)
+
+
diff --git a/02-household-queries/retrieval.py b/02-household-queries/retrieval.py
index 75396ed..46d91ea 100644
--- a/02-household-queries/retrieval.py
+++ b/02-household-queries/retrieval.py
@@ -2,7 +2,7 @@ from langchain.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
 
-def retrieval_call(llm, vectordb):
+def retrieval_call(llm, vectordb, question):
     # Create the retrieval chain
     template = """
     You are a helpful AI assistant.
@@ -23,11 +23,10 @@ def retrieval_call(llm, vectordb):
     )
 
-    question = os.environ.get("USER_QUERY")
-    if not question:
+    # question = os.environ.get("USER_QUERY")
+    if question is None:
         print("Please state your question here: ")
         question = input()
-
     # Invoke the retrieval chain
     response=retrieval_chain.invoke({"query":question})
     print("\n## QUERY: ", question)
@@ -36,3 +35,4 @@ def retrieval_call(llm, vectordb):
     for d in response["source_documents"]:
         print(d)
         print()
+    return response