From c09377040fed0251e2934e4e8837ef229234c292 Mon Sep 17 00:00:00 2001 From: iusztinpaul Date: Tue, 10 Oct 2023 08:40:58 +0300 Subject: [PATCH] fix: Sharing keys between chains --- modules/financial_bot/Makefile | 2 -- modules/financial_bot/financial_bot/chains.py | 34 ++++++++----------- .../financial_bot/langchain_bot.py | 6 ++-- modules/financial_bot/tools/run_chain.py | 1 - 4 files changed, 17 insertions(+), 26 deletions(-) diff --git a/modules/financial_bot/Makefile b/modules/financial_bot/Makefile index cf26130..2728f90 100644 --- a/modules/financial_bot/Makefile +++ b/modules/financial_bot/Makefile @@ -25,8 +25,6 @@ add_dev: run: @echo "Running financial_bot..." - @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" - poetry run python -m tools.run_chain diff --git a/modules/financial_bot/financial_bot/chains.py b/modules/financial_bot/financial_bot/chains.py index ab46ccc..a00ab69 100644 --- a/modules/financial_bot/financial_bot/chains.py +++ b/modules/financial_bot/financial_bot/chains.py @@ -1,11 +1,12 @@ from typing import Any, Dict, List import qdrant_client -from financial_bot.embeddings import EmbeddingModelSingleton -from financial_bot.template import PromptTemplate from langchain.chains.base import Chain from langchain.llms import HuggingFacePipeline +from financial_bot.embeddings import EmbeddingModelSingleton +from financial_bot.template import PromptTemplate + class ContextExtractorChain(Chain): """ @@ -17,43 +18,37 @@ class ContextExtractorChain(Chain): embedding_model: EmbeddingModelSingleton vector_store: qdrant_client.QdrantClient vector_collection: str - output_key: str = "payload" + output_key: str = "context" @property def input_keys(self) -> List[str]: - return ["about_me", "question", "context"] + return ["about_me", "question"] @property def output_keys(self) -> List[str]: return [self.output_key] def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - # TODO: handle that None, without the need to enter chain - about_key, quest_key, 
contx_key = self.input_keys - question_str = inputs.get(quest_key, None) + _, quest_key = self.input_keys + question_str = inputs.get(quest_key) # TODO: maybe async embed? embeddings = self.embedding_model(question_str) - # TODO: get rid of hardcoded collection_name, specify 1 top_k or adjust multiple context insertions matches = self.vector_store.search( query_vector=embeddings, k=self.top_k, collection_name=self.vector_collection, ) - content = "" + context = "" for match in matches: - content += match.payload["summary"] + "\n" + context += match.payload["summary"] + "\n" - payload = { - about_key: inputs[about_key], - quest_key: inputs[quest_key], - contx_key: content, + return { + self.output_key: context, } - return {self.output_key: payload} - class FinancialBotQAChain(Chain): """This custom chain handles LLM generation upon given prompt""" @@ -64,20 +59,19 @@ class FinancialBotQAChain(Chain): @property def input_keys(self) -> List[str]: - return ["payload"] + return ["context"] @property def output_keys(self) -> List[str]: return [self.output_key] def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - # TODO: use .get and treat default value? 
about_me = inputs["about_me"] question = inputs["question"] - context = inputs["context"] + news_context = inputs.get("context") prompt = self.template.infer_raw_template.format( - user_context=about_me, news_context=context, question=question + user_context=about_me, news_context=news_context, question=question ) response = self.hf_pipeline(prompt) diff --git a/modules/financial_bot/financial_bot/langchain_bot.py b/modules/financial_bot/financial_bot/langchain_bot.py index 39ea66f..659294e 100644 --- a/modules/financial_bot/financial_bot/langchain_bot.py +++ b/modules/financial_bot/financial_bot/langchain_bot.py @@ -68,7 +68,7 @@ def build_chain(self) -> chains.SequentialChain: logger.info("Connecting chains into SequentialChain") seq_chain = chains.SequentialChain( chains=[context_retrieval_chain, llm_generator_chain], - input_variables=["about_me", "question", "context"], + input_variables=["about_me", "question"], output_variables=["response"], verbose=True, ) @@ -81,7 +81,7 @@ def build_chain(self) -> chains.SequentialChain: ) return seq_chain - def answer(self, about_me: str, question: str, context: str) -> str: + def answer(self, about_me: str, question: str) -> str: """ Given a short description about the user and a question make the LLM generate a response. @@ -99,7 +99,7 @@ def answer(self, about_me: str, question: str, context: str) -> str: LLM generated response. 
""" try: - inputs = {"about_me": about_me, "question": question, "context": context} + inputs = {"about_me": about_me, "question": question} response = self.finbot_chain.run(inputs) return response except KeyError as e: diff --git a/modules/financial_bot/tools/run_chain.py b/modules/financial_bot/tools/run_chain.py index 915ee06..c366c49 100644 --- a/modules/financial_bot/tools/run_chain.py +++ b/modules/financial_bot/tools/run_chain.py @@ -11,7 +11,6 @@ def main(): input_payload = { "about_me": "I'm a student and I have some money that I want to invest.", "question": "Should I consider investing in stocks from the Tech Sector?", - "context": "" } response = bot.answer(**input_payload)