From 1b30652ba311ecf8297de5abdcbd88013fd4bdcc Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Thu, 4 Apr 2024 13:03:56 -0500
Subject: [PATCH] Add ruff format check; Fix existing formatting

---
 .github/workflows/ci-linter-ruff.yml       |   3 +
 01-resource-referral/debugging.py          |   2 +
 01-resource-referral/langgraph-workflow.py |  12 +-
 01-resource-referral/my_tools.py           |  83 ++++++++++----
 02-household-queries/api_helpers.py        |   7 +-
 .../chainlit-household-bot.py              | 108 ++++++++++++------
 02-household-queries/ingest.py             |  23 +++-
 02-household-queries/llm.py                |  41 +++++--
 02-household-queries/retrieval.py          |   4 +-
 02-household-queries/run.py                |  30 +++--
 10 files changed, 219 insertions(+), 94 deletions(-)

diff --git a/.github/workflows/ci-linter-ruff.yml b/.github/workflows/ci-linter-ruff.yml
index 85f3253..647de35 100644
--- a/.github/workflows/ci-linter-ruff.yml
+++ b/.github/workflows/ci-linter-ruff.yml
@@ -6,3 +6,6 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - uses: chartboost/ruff-action@v1
+      - uses: chartboost/ruff-action@v1
+        with:
+          args: 'format --check --diff'
diff --git a/01-resource-referral/debugging.py b/01-resource-referral/debugging.py
index 3a59b03..3572e02 100644
--- a/01-resource-referral/debugging.py
+++ b/01-resource-referral/debugging.py
@@ -6,6 +6,7 @@
 from langchain_core.runnables import RunnableLambda
 from langchain_core.prompt_values import PromptValue
 
+
 def stacktrace():
     traceback.print_stack()
 
@@ -20,6 +21,7 @@ def debug_here(local_vars):
 
 def debug_runnable(prefix: str):
     """Useful to see output/input between Runnables in a LangChain"""
+
     def debug_chainlink(x):
         print(f"{prefix if prefix else 'DEBUG_CHAINLINK'}")
         if isinstance(x, PromptValue):
diff --git a/01-resource-referral/langgraph-workflow.py b/01-resource-referral/langgraph-workflow.py
index 710b9c2..5d81c20 100755
--- a/01-resource-referral/langgraph-workflow.py
+++ b/01-resource-referral/langgraph-workflow.py
@@ -4,6 +4,7 @@
 from dotenv import main
 import operator
 from typing import TypedDict, Annotated, Sequence
+
 # import os
 
 import graphviz  # type: ignore
@@ -32,7 +33,6 @@ class WorkflowState(TypedDict):
 
 
 class MyWorkflow:
-
     def __init__(self, model_name: str, tools: list):
         main.load_dotenv()
         self.graph = self._init_graph()
@@ -42,7 +42,6 @@ def __init__(self, model_name: str, tools: list):
         # tool_executor will be used to call the tool specified by the LLM in the llm_chain
         self.tool_executor = ToolExecutor(tools)
 
-
     def _init_graph(self):
         graph = StateGraph(WorkflowState)
         graph.add_node("decision_node", self.check_for_final_answer)
@@ -92,7 +91,7 @@ def _draw_graph(self, graph: StateGraph):
 
     # Determines next node to call
     def decide_next_node(self, state):
-        print("\nNEXT_EDGE") # , json.dumps(state, indent=2))
+        print("\nNEXT_EDGE")  # , json.dumps(state, indent=2))
 
         if state["final_answer"]:
             return END
@@ -120,7 +119,7 @@ def check_for_final_answer(self, state):
             print("\nHAS_FINAL_ANSWER node: Waiting for more responses")
 
     def _got_responses_from_all_tools(self, state):
-        expected_tools = [ "spreadsheet", "211_api" ]
+        expected_tools = ["spreadsheet", "211_api"]
         return all(key in state["tool_responses"] for key in expected_tools)
 
     def run_llms(self, state):
@@ -148,7 +147,7 @@ def llm_spreadsheet_query(self, state):
         llm_response = self.invoke_user_message("query_spreadsheet", user_message)
         return {"messages": [llm_response]}
 
-    def invoke_user_message(self, tool, user_message): 
+    def invoke_user_message(self, tool, user_message):
         return self.llm_chain[tool].invoke(user_message)
 
     def call_211_tool(self, state):
@@ -235,5 +234,4 @@ def merge_results(self, state):
 final_state = runnable_graph.invoke(inputs)
 # print("\nFINAL_STATE", type(final_state), final_state)
 print("\nFINAL_ANSWER")
-print(final_state['final_answer'])
-
+print(final_state["final_answer"])
diff --git a/01-resource-referral/my_tools.py b/01-resource-referral/my_tools.py
index 0e8e6bb..84b368d 100644
--- a/01-resource-referral/my_tools.py
+++ b/01-resource-referral/my_tools.py
@@ -11,15 +11,18 @@ from debugging import debug_runnable
 TWO_ONE_ONE_BASE_SEARCH_ENDPOINT = "https://api.211.org/search/v1/api"
+
+
 # TODO: adjust parameters and get LLM to set correct parameters
 @tool
-def call_211_api(city: str, service_type:str | list[str]) -> str:
+def call_211_api(city: str, service_type: str | list[str]) -> str:
     """Calls National 211 API for the given city and service type, such as 'Consumer Services'"""
     print(f"211 args: city={city}; service_type={service_type}")
     return directly_call_211_api(city, service_type)
 
-def directly_call_211_api(city:str, keyword:str | list[str]) -> str:
+
+def directly_call_211_api(city: str, keyword: str | list[str]) -> str:
     if isinstance(keyword, str):
         return get_services_from_211(city, keyword)
     if isinstance(keyword, list):
@@ -32,19 +35,19 @@ def directly_call_211_api(city:str, keyword:str | list[str]) -> str:
     raise ValueError(f"Invalid keyword type: {type(keyword)}")
 
 
-def get_services_from_211(city:str, keyword:str | list[str]):
+def get_services_from_211(city: str, keyword: str | list[str]):
     location_endpoint = f"{TWO_ONE_ONE_BASE_SEARCH_ENDPOINT}/Search/Keyword?Keyword={keyword}&Location={city}&Top=10&OrderBy=Relevance&SearchMode=Any&IncludeStateNationalRecords=true&ReturnTaxonomyTermsIfNoResults=false"
 
-    TWO_ONE_ONE_API_KEY = os.environ.get('TWO_ONE_ONE_API_KEY')
+    TWO_ONE_ONE_API_KEY = os.environ.get("TWO_ONE_ONE_API_KEY")
     headers = {
-        'Accept': 'application/json',
-        'Api-Key': TWO_ONE_ONE_API_KEY,
+        "Accept": "application/json",
+        "Api-Key": TWO_ONE_ONE_API_KEY,
     }
     location_search = requests.get(location_endpoint, headers=headers)
 
     # From Search: /api/Filters/ServiceAreas?StateProvince=MI, returns []
-    try: 
+    try:
         first_result = location_search.json()["results"][0]["document"]
         # difficult to find param {location_id}, location_id returns dataowner
         location_id = first_result["idLocation"]
@@ -60,12 +63,16 @@ def get_services_from_211(city:str, keyword:str | list[str]):
         print("Failed to get services at location")
         return "[]"
 
+
 # Check for csv file
 csv_file = "nyc_referral_csv.csv"
 if not os.path.exists(csv_file):
-    print(f"Optionally download {csv_file} from google drive: https://drive.google.com/file/d/1YHgJvZCDF5VtTO-AQ4-I3_1jGzrcOHjY/view?usp=sharing")
+    print(
+        f"Optionally download {csv_file} from google drive: https://drive.google.com/file/d/1YHgJvZCDF5VtTO-AQ4-I3_1jGzrcOHjY/view?usp=sharing"
+    )
     input(f"Press Enter to continue without {csv_file}...")
 
+
 @tool
 def query_spreadsheet(city: str, service_type: str | list[str]) -> str:
     """Search spreadsheet for support resources given the city and service type, such as 'Food Assistance'."""
@@ -76,20 +83,24 @@ def query_spreadsheet(city: str, service_type: str | list[str]) -> str:
     # base implementation
     df = pandas.read_csv(csv_file)
-    separated_locations = city.split(',')
+    separated_locations = city.split(",")
     city_to_search = separated_locations[0]
-    query = service_type if isinstance(service_type, str) else '|'.join(service_type)
+    query = service_type if isinstance(service_type, str) else "|".join(service_type)
     print(query)
-    results = df.query(f'needs.str.contains("{query}", case=False) & counties_served.str.contains("{city_to_search}", case=False)', engine='python')
+    results = df.query(
+        f'needs.str.contains("{query}", case=False) & counties_served.str.contains("{city_to_search}", case=False)',
+        engine="python",
+    )
     if results.to_numpy().size == 0:
         return "[]"
-
-    csv_with_header_to_json = results.replace(np.nan, None).to_dict('records')
+
+    csv_with_header_to_json = results.replace(np.nan, None).to_dict("records")
     dict_json = json.dumps(csv_with_header_to_json, indent=2)
-
+
     return dict_json
 
+
 @tool
 def merge_json_results(user_query, result_211_api, result_spreadsheet):
     """Merge JSON results from 211 API and spreadsheet"""
@@ -104,17 +115,32 @@ def merge_json_results(user_query, result_211_api, result_spreadsheet):
     relevance_agent = create_relevance_agent()
     # A long list of resources confuses the LLM, so sample only 40
-    sampled_resources : list[dict] = random.sample(list(deduplicated_dict.values()), min(40, len(deduplicated_dict)))
+    sampled_resources: list[dict] = random.sample(
+        list(deduplicated_dict.values()), min(40, len(deduplicated_dict))
+    )
     formatted_resources = _format_for_prompt(sampled_resources)
-    prioritized_list = relevance_agent.invoke({"resources": formatted_resources, "user_query": user_query})
+    prioritized_list = relevance_agent.invoke(
+        {"resources": formatted_resources, "user_query": user_query}
+    )
     return prioritized_list
 
+
 ALTERNATIVE_KEYS = [
     # from spreadsheet
-    "id", "alternate_name", "url", "website", "email", "tax_id",
+    "id",
+    "alternate_name",
+    "url",
+    "website",
+    "email",
+    "tax_id",
     # from 211
-    "idService", "idOrganization", "name", "alternateName"
-    ]
+    "idService",
+    "idOrganization",
+    "name",
+    "alternateName",
+]
+
+
 def _merge_and_deduplicate(dict_listA, dict_listB):
     """Merge 2 list of objects, removing duplicates based on object's 'name'.
     If 'name' is not present, use one of ALTERNATIVE_KEYS to deduplicate."""
@@ -130,9 +156,12 @@ def _merge_and_deduplicate(dict_listA, dict_listB):
             deduplicated_dict[obj["name"]] = obj
             continue
 
-        deduplicated_dict[obj["name"]] = _merge_objects(deduplicated_dict[obj["name"]], obj)
+        deduplicated_dict[obj["name"]] = _merge_objects(
+            deduplicated_dict[obj["name"]], obj
+        )
     return deduplicated_dict
 
+
 def _merge_objects(objA: dict, objB: dict):
     """Merge 2 objects, concatenating values if same key are in both objects"""
     merged_obj = {}
@@ -141,7 +170,10 @@ def _merge_objects(objA: dict, objB: dict):
         merged_obj[key] = ";; ".join(set(values_list))
     return merged_obj
 
+
 APPROVED_RESOURCE_NAMES = ["Alpena CAO"]
+
+
 def _filter_approved(deduplicated_dict):
     """Filter collection to only include approved resource names"""
     for key in list(deduplicated_dict):
@@ -150,10 +182,16 @@ def _filter_approved(deduplicated_dict):
             del deduplicated_dict[key]
     return deduplicated_dict
 
+
 def _format_for_prompt(resources: dict):
     """Format resources for prompt"""
-    return "\n".join([f"- {resource['name']} ({resource['phone']}): provides {resource.get('needs')} for counties {resource.get('counties_served')}. {resource.get('description', '')}"
-                      for resource in resources])
+    return "\n".join(
+        [
+            f"- {resource['name']} ({resource['phone']}): provides {resource.get('needs')} for counties {resource.get('counties_served')}. {resource.get('description', '')}"
+            for resource in resources
+        ]
+    )
 
+
 def create_relevance_agent():
     return (
         | create_llm(model_name="openhermes", settings={"temperature": 0, "top_p": 0.8})
     )
 
+
 def _agent_prompt_template():
     template = """You are a helpful automated agent that filters and prioritizes benefits services. \
 Downselect to less than 10 services total and prioritize the following list of services based on the user's query. \
diff --git a/02-household-queries/api_helpers.py b/02-household-queries/api_helpers.py
index 8be548f..5652173 100644
--- a/02-household-queries/api_helpers.py
+++ b/02-household-queries/api_helpers.py
@@ -7,9 +7,10 @@
 # Fetches data from Guru, currently not used as we're pulling data from static json files
 # Eventually we would like to pull the latest updated data from the Guru API
 GURU_ENDPOINT = "https://api.getguru.com/api/v1/"
+
+
 def get_guru_data():
     url = f"{GURU_ENDPOINT}cards/3fbff9c4-56a8-4561-a7d1-09727f1b4703"
-    headers = {
-        'Authorization': os.environ.get('GURU_TOKEN')}
+    headers = {"Authorization": os.environ.get("GURU_TOKEN")}
     response = requests.request("GET", url, headers=headers)
-    return response.json()
\ No newline at end of file
+    return response.json()
diff --git a/02-household-queries/chainlit-household-bot.py b/02-household-queries/chainlit-household-bot.py
index 3668cfe..212bd09 100755
--- a/02-household-queries/chainlit-household-bot.py
+++ b/02-household-queries/chainlit-household-bot.py
@@ -10,7 +10,10 @@
 from chromadb.config import Settings
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
+from langchain_community.embeddings import (
+    SentenceTransformerEmbeddings,
+    HuggingFaceEmbeddings,
+)
 from langchain_community.vectorstores import Chroma
 from langchain.memory import ChatMessageHistory, ConversationBufferMemory
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
@@ -18,16 +21,20 @@
 import os
 
 from ingest import add_json_html_data_to_vector_db, add_pdf_to_vector_db, ingest_call
-from llm import google_gemini_client, ollama_client #, gpt4all_client
+from llm import google_gemini_client, ollama_client  # , gpt4all_client
 from retrieval import retrieval_call
 
 OLLAMA_LLMS = ["openhermes", "llama2", "mistral"]
 GOOGLE_LLMS = ["gemini-pro"]
 # GPT4ALL_LLMS = ["gpt4all"]
-GOOGLE_EMBEDDINGS=["Google::models/embedding-001"]
-OPEN_SOURCE_EMBEDDINGS=["all-MiniLM-L6-v2"]
-HUGGING_FACE_EMBEDDINGS=["HuggingFace::all-MiniLM-L6-v2", "HuggingFace::all-mpnet-base-v2"]
+GOOGLE_EMBEDDINGS = ["Google::models/embedding-001"]
+OPEN_SOURCE_EMBEDDINGS = ["all-MiniLM-L6-v2"]
+HUGGING_FACE_EMBEDDINGS = [
+    "HuggingFace::all-MiniLM-L6-v2",
+    "HuggingFace::all-mpnet-base-v2",
+]
+
 
 @cl.on_chat_start
 async def init_chat():
@@ -42,8 +49,16 @@ async def init_chat():
                 value="chooseBetter",
                 label="Demo choosing better response",
             ),
-            cl.Action(name="uploadDefaultFiles", value="upload_default_files", label="Load default files into vector DB"),
-            cl.Action(name="uploadFilesToVectorAct", value="upload_files_to_vector", label="Upload files for vector DB"),
+            cl.Action(
+                name="uploadDefaultFiles",
+                value="upload_default_files",
+                label="Load default files into vector DB",
+            ),
+            cl.Action(
+                name="uploadFilesToVectorAct",
+                value="upload_files_to_vector",
+                label="Upload files for vector DB",
+            ),
             cl.Action(name="resetDB", value="reset_db", label="Reset DB"),
         ],
     ).send()
@@ -60,10 +75,16 @@ async def init_chat():
             Select(
                 id="embedding",
                 label="Embeddings",
-                values= GOOGLE_EMBEDDINGS + OPEN_SOURCE_EMBEDDINGS + HUGGING_FACE_EMBEDDINGS,
+                values=GOOGLE_EMBEDDINGS
+                + OPEN_SOURCE_EMBEDDINGS
+                + HUGGING_FACE_EMBEDDINGS,
                 initial_index=0,
             ),
-            Switch(id="use_vector_db", label="Use vector db sources", initial=os.environ.get("USE_VECTOR_DB", False)),
+            Switch(
+                id="use_vector_db",
+                label="Use vector db sources",
+                initial=os.environ.get("USE_VECTOR_DB", False),
+            ),
             Slider(
                 id="temperature",
                 label="LLM Temperature",
@@ -98,6 +119,7 @@ async def on_click_settings(action: cl.Action):
     E.g. **bold**, *italic*, `code`, [links](https://www.example.com), etc.
     """
 
+
 @cl.action_callback("resetDB")
 async def on_click_resetDB(action: cl.Action):
     # reset db after changing to avoid error: embedding dimension does not match collection dimensionality
@@ -105,6 +127,7 @@ async def on_click_resetDB(action: cl.Action):
     persistent_client = cl.user_session.get("persistent_client")
     persistent_client.reset()
 
+
 @cl.action_callback("stepsDemoAct")
 async def on_click_stepsDemo(action: cl.Action):
     async with cl.Step(name="Child step A", disable_feedback=False) as child_step:
@@ -173,6 +196,7 @@ async def set_llm_model():
     await msg.stream_token(f"Done setting up {llm_name} LLM")
     await msg.send()
 
+
 async def set_embeddings():
     settings = cl.user_session.get("settings")
     embeddings = settings["embedding"]
@@ -182,13 +206,15 @@ async def set_embeddings():
     )
     embedding = None
     if embeddings in GOOGLE_EMBEDDINGS:
-        GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
-        model_name= embeddings.split('::')[1]
-        embedding = GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=GOOGLE_API_KEY)
+        GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+        model_name = embeddings.split("::")[1]
+        embedding = GoogleGenerativeAIEmbeddings(
+            model=model_name, google_api_key=GOOGLE_API_KEY
+        )
     elif embeddings in OPEN_SOURCE_EMBEDDINGS:
         embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
     elif embeddings in HUGGING_FACE_EMBEDDINGS:
-        model_name= embeddings.split('::')[1]
+        model_name = embeddings.split("::")[1]
         embeddings = HuggingFaceEmbeddings(model_name=model_name)
     else:
         await cl.Message(content=f"Could not initialize embedding: {embeddings}").send()
@@ -197,6 +223,7 @@ async def set_embeddings():
     await msg.stream_token(f"Done setting up {embeddings} embedding")
     await msg.send()
 
+
 async def set_vector_db():
     await init_embedding_function_if_needed()
     embeddings = cl.user_session.get("embedding")
@@ -204,39 +231,47 @@ async def set_vector_db():
         author="backend",
         content=f"Setting up Chroma DB with `{embeddings}`...\n",
     )
-    persistent_client = chromadb.PersistentClient(settings=Settings(allow_reset=True), path="./chroma_db")
+    persistent_client = chromadb.PersistentClient(
+        settings=Settings(allow_reset=True), path="./chroma_db"
+    )
     cl.user_session.set("persistent_client", persistent_client)
-    vectordb=Chroma(
+    vectordb = Chroma(
         client=persistent_client,
-        collection_name="resources", 
+        collection_name="resources",
         persist_directory="./chroma_db",
-        embedding_function=embeddings
+        embedding_function=embeddings,
     )
     cl.user_session.set("vectordb", vectordb)
     await msg.stream_token("Done setting up vector db")
     await msg.send()
 
+
 async def init_llm_client_if_needed():
     client = cl.user_session.get("client")
     if not client:
         await set_llm_model()
+
+
 async def init_embedding_function_if_needed():
     embedding = cl.user_session.get("embedding")
     if not embedding:
         await set_embeddings()
+
+
 async def init_persistent_client_if_needed():
-    persistent_client=cl.user_session.get("persistent_client")
+    persistent_client = cl.user_session.get("persistent_client")
     if persistent_client is None:
         await set_vector_db()
 
+
 @cl.on_message
 async def message_submitted(message: cl.Message):
     await init_llm_client_if_needed()
     await init_embedding_function_if_needed()
     settings = cl.user_session.get("settings")
-    client=cl.user_session.get("client")
-    vectordb=cl.user_session.get("vectordb")
+    client = cl.user_session.get("client")
+    vectordb = cl.user_session.get("vectordb")
     # 3 ways to manage history for LLM:
     # 1. Use Chainlit
     # message_history = cl.user_session.get("message_history")
@@ -245,25 +280,28 @@ async def message_submitted(message: cl.Message):
     # 3. Use LlmPrompts lp.register_answer
 
     # Reminder to use make_async for long running tasks: https://docs.chainlit.io/guides/sync-async#long-running-synchronous-tasks
-    # If options `streaming` is set, or `use_vector_db` is not set, the RAG chain will not be called 
+    # If options `streaming` is set, or `use_vector_db` is not set, the RAG chain will not be called
     if settings["streaming"]:
         if settings["use_vector_db"]:
             await cl.Message("Change the setting to use non-streaming instead").send()
         else:
             await call_llm_async(message)
-    
+
     else:
         if settings["use_vector_db"] and vectordb:
             await retrieval_function(vectordb=vectordb, llm=client)
             response = retrieval_call(client, vectordb, message.content)
-            source_list = [doc.metadata for doc in response['source_documents']]
-            sources = ', '.join([sources_item['source'] for sources_item in source_list])
+            source_list = [doc.metadata for doc in response["source_documents"]]
+            sources = ", ".join(
+                [sources_item["source"] for sources_item in source_list]
+            )
             answer = f"Result:\n{response['result']} \nSources: \n" + sources
             await cl.Message(content=answer).send()
         else:
             response = call_llm(message)
             await cl.Message(content=f"*Response*: {response}").send()
 
+
 @cl.step(type="llm", show_input=True)
 async def call_llm_async(message: cl.Message):
     client = cl.user_session.get("client")
@@ -286,7 +324,7 @@ def call_llm(message: cl.Message):
 @cl.action_callback("uploadDefaultFiles")
 async def on_click_upload_default_files(action: cl.Action):
     await set_vector_db()
-    vectordb= cl.user_session.get("vectordb")
+    vectordb = cl.user_session.get("vectordb")
 
     msg = cl.Message(content="Processing files...", disable_feedback=True)
     await msg.send()
@@ -294,6 +332,7 @@ async def on_click_upload_default_files(action: cl.Action):
     msg.content = "Processing default files done. You can now ask questions!"
     await msg.update()
 
+
 @cl.action_callback("uploadFilesToVectorAct")
 async def on_click_upload_file_query(action: cl.Action):
     files = None
@@ -308,18 +347,23 @@ async def on_click_upload_file_query(action: cl.Action):
     file = files[0]
     # initialize db
     await set_vector_db()
-    vectordb=cl.user_session.get("vectordb")
-    if(file.type == "application/pdf"):
+    vectordb = cl.user_session.get("vectordb")
+    if file.type == "application/pdf":
         add_pdf_to_vector_db(vectordb=vectordb, file_path=file.path)
-    elif(file.type == "application/json"):
-        add_json_html_data_to_vector_db(vectordb=vectordb, file_path=file.path, content_key="content", index_key="preferredPhrase")
+    elif file.type == "application/json":
+        add_json_html_data_to_vector_db(
+            vectordb=vectordb,
+            file_path=file.path,
+            content_key="content",
+            index_key="preferredPhrase",
+        )
     msg = cl.Message(content=f"Processing `{file.name}`...", disable_feedback=True)
     await msg.send()
 
     msg.content = f"Processing `{file.name}` done. You can now ask questions!"
     await msg.update()
 
-
-async def retrieval_function(vectordb, llm): 
+
+async def retrieval_function(vectordb, llm):
     retriever = vectordb.as_retriever(search_kwargs={"k": 1})
 
     message_history = ChatMessageHistory()
@@ -338,6 +382,6 @@ async def retrieval_function(vectordb, llm):
         memory=memory,
         return_source_documents=True,
     )
-    
+
     # Let the user know that the system is ready
     cl.user_session.set("chain", chain)
diff --git a/02-household-queries/ingest.py b/02-household-queries/ingest.py
index d4ca553..3bf390f 100644
--- a/02-household-queries/ingest.py
+++ b/02-household-queries/ingest.py
@@ -4,22 +4,27 @@
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 import json
 
+
 # split text into chunks
 def get_text_chunks_langchain(text, source):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
     texts = text_splitter.split_text(text)
-    docs = [Document(page_content=t, metadata={"source":source}) for t in texts]
+    docs = [Document(page_content=t, metadata={"source": source}) for t in texts]
     return docs
 
+
 # Chunk the pdf and load into vector db
 def add_pdf_to_vector_db(vectordb, file_path, chunk_size=500, chunk_overlap=100):
     # PDFMinerLoader only gives metadata when extract_images=True due to default using lazy_loader
     loader = PDFMinerLoader(file_path, extract_images=True)
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size, chunk_overlap=chunk_overlap
+    )
     pdf_pages = loader.load_and_split(text_splitter)
     print("Loading PDF chunks into vector db")
     vectordb.add_documents(documents=pdf_pages)
 
+
 # Chunk the json data and load into vector db
 def add_json_html_data_to_vector_db(vectordb, file_path, content_key, index_key):
     data_file = open(file_path, encoding="utf-8")
@@ -27,17 +32,23 @@ def add_json_html_data_to_vector_db(vectordb, file_path, content_key, index_key)
 
     for content in json_data:
         soup = BeautifulSoup(content[content_key], "html.parser")
-        text = soup.get_text(separator='\n', strip=True)
+        text = soup.get_text(separator="\n", strip=True)
         chunks = get_text_chunks_langchain(text, content[index_key])
         print(f"Loading Document {content[index_key]} chunk into vector db")
         vectordb.add_documents(documents=chunks)
 
-def ingest_call(vectordb): 
+
+def ingest_call(vectordb):
     # Load the PDF and create chunks
     # download from https://drive.google.com/file/d/1--qDjraIk1WGxwuCGBP-nfxzOr9IHvcZ/view?usp=drive_link
     pdf_path = "./tanf.pdf"
     add_pdf_to_vector_db(vectordb=vectordb, file_path=pdf_path)
 
     # download from https://drive.google.com/file/d/1UoWmktXS5nqgIWj2x_O5hgzwU0yVuaJc/view?usp=drive_link
-    guru_file_path='./guru_cards_for_nava.json'
-    add_json_html_data_to_vector_db(vectordb=vectordb, file_path=guru_file_path, content_key="content", index_key="preferredPhrase")
+    guru_file_path = "./guru_cards_for_nava.json"
+    add_json_html_data_to_vector_db(
+        vectordb=vectordb,
+        file_path=guru_file_path,
+        content_key="content",
+        index_key="preferredPhrase",
+    )
diff --git a/02-household-queries/llm.py b/02-household-queries/llm.py
index 0c14475..bd621c5 100644
--- a/02-household-queries/llm.py
+++ b/02-household-queries/llm.py
@@ -6,9 +6,13 @@
 from langchain_google_genai import ChatGoogleGenerativeAI
 import os
 
+
 dotenv.load_dotenv()
 
-def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdout=False):
+
+def ollama_client(
+    model_name=None, callbacks=None, settings=None, print_to_stdout=False
+):
     if not callbacks:
         callbacks = []
         if print_to_stdout:
@@ -27,8 +31,14 @@ def ollama_client(model_name=None, callbacks=None, settings=None, print_to_stdou
     # To connect via another URL: Ollama(base_url='http://localhost:11434', ...)
     return Ollama(model=model_name, callbacks=callbacks, **settings)
 
+
 # Add LLM client for other LLMs here...
-def gpt4all_client(model_path="./models/mistral-7b-instruct-v0.1.Q4_0.gguf", callbacks=None, settings=None,print_to_stdout=False):
+def gpt4all_client(
+    model_path="./models/mistral-7b-instruct-v0.1.Q4_0.gguf",
+    callbacks=None,
+    settings=None,
+    print_to_stdout=False,
+):
     # Open source option
     # download Mistral at https://mistral.ai/news/announcing-mistral-7b/
     if not callbacks:
@@ -44,13 +54,18 @@ def gpt4all_client(model_path="./models/mistral-7b-instruct-v0.1.Q4_0.gguf", cal
     }
     print("LLM settings:", settings)
 
-
-    return GPT4All(model=model_path,max_tokens=1000, verbose=True,repeat_last_n=0, **settings)
-def google_gemini_client(model_name="gemini-pro", callbacks=None, settings=None, print_to_stdout=False):
-    # Get a Google API key by following the steps after clicking on Get an API key button 
+    return GPT4All(
+        model=model_path, max_tokens=1000, verbose=True, repeat_last_n=0, **settings
+    )
+
+
+def google_gemini_client(
+    model_name="gemini-pro", callbacks=None, settings=None, print_to_stdout=False
+):
+    # Get a Google API key by following the steps after clicking on Get an API key button
     # at https://ai.google.dev/tutorials/setup
-    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
+    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
     if not callbacks:
         callbacks = []
         if print_to_stdout:
@@ -65,8 +80,10 @@ def google_gemini_client(model_name="gemini-pro", callbacks=None, settings=None,
     }
     print("LLM settings:", settings)
 
-    return ChatGoogleGenerativeAI(model=model_name,
-        verbose = True,google_api_key=GOOGLE_API_KEY,
-        convert_system_message_to_human=True, **settings)
-
-
+    return ChatGoogleGenerativeAI(
+        model=model_name,
+        verbose=True,
+        google_api_key=GOOGLE_API_KEY,
+        convert_system_message_to_human=True,
+        **settings,
+    )
diff --git a/02-household-queries/retrieval.py b/02-household-queries/retrieval.py
index 46d91ea..0cc2ed5 100644
--- a/02-household-queries/retrieval.py
+++ b/02-household-queries/retrieval.py
@@ -2,6 +2,7 @@
 from langchain.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
 
+
 def retrieval_call(llm, vectordb, question):
     # Create the retrieval chain
     template = """
@@ -20,7 +21,6 @@ def retrieval_call(llm, vectordb, question):
         retriever=retriever,
         return_source_documents=True,
         chain_type_kwargs={"prompt": prompt},
-
     )
 
     # question = os.environ.get("USER_QUERY")
@@ -28,7 +28,7 @@ def retrieval_call(llm, vectordb, question):
     print("Please state your question here: ")
     question = input()
     # Invoke the retrieval chain
-    response=retrieval_chain.invoke({"query":question})
+    response = retrieval_chain.invoke({"query": question})
     print("\n## QUERY: ", question)
     print("\n## RESULT: ", response["result"])
     print("\n## SOURCE DOC: ")
diff --git a/02-household-queries/run.py b/02-household-queries/run.py
index 3cd946b..d3c016e 100644
--- a/02-household-queries/run.py
+++ b/02-household-queries/run.py
@@ -1,6 +1,9 @@
 import os
 import dotenv
-from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceEmbeddings
+from langchain_community.embeddings import (
+    SentenceTransformerEmbeddings,
+    HuggingFaceEmbeddings,
+)
 from langchain_community.llms import GPT4All
 from langchain_community.vectorstores import Chroma
 from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
@@ -42,22 +45,29 @@ if llm_choice == "2" or llm_choice == "Mistral":
     # Open source option
     # download Mistral at https://mistral.ai/news/announcing-mistral-7b/
-    gpt4all_path= "./mistral-7b-instruct-v0.1.Q4_0.gguf"
-    llm = GPT4All(model=gpt4all_path,max_tokens=1000, verbose=True,repeat_last_n=0)
+    gpt4all_path = "./mistral-7b-instruct-v0.1.Q4_0.gguf"
+    llm = GPT4All(model=gpt4all_path, max_tokens=1000, verbose=True, repeat_last_n=0)
 elif llm_choice == "3":
     # _llm_model_name = "mistral" # "openhermes", "llama2", "mistral"
     llm_settings = {"temperature": 0.1}
     llm = ollama_client(_llm_model_name, settings=llm_settings)
 else:
-    # Get a Google API key by following the steps after clicking on Get an API key button 
+    # Get a Google API key by following the steps after clicking on Get an API key button
     # at https://ai.google.dev/tutorials/setup
-    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
-    llm = ChatGoogleGenerativeAI(model="gemini-pro",
-        verbose = True,google_api_key=GOOGLE_API_KEY,
-        convert_system_message_to_human=True)
-
+    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
+    llm = ChatGoogleGenerativeAI(
+        model="gemini-pro",
+        verbose=True,
+        google_api_key=GOOGLE_API_KEY,
+        convert_system_message_to_human=True,
+    )
+
 # initialize chroma db
-vectordb=Chroma(embedding_function=embeddings, collection_name="resources", persist_directory="./chroma_db")
+vectordb = Chroma(
+    embedding_function=embeddings,
+    collection_name="resources",
+    persist_directory="./chroma_db",
+)
 
 print("""
 Initialize DB and retrieve?