From c09377040fed0251e2934e4e8837ef229234c292 Mon Sep 17 00:00:00 2001
From: iusztinpaul
Date: Tue, 10 Oct 2023 08:40:58 +0300
Subject: [PATCH] fix: Sharing keys between chains
---
modules/financial_bot/Makefile | 2 --
modules/financial_bot/financial_bot/chains.py | 34 ++++++++-----------
.../financial_bot/langchain_bot.py | 6 ++--
modules/financial_bot/tools/run_chain.py | 1 -
4 files changed, 17 insertions(+), 26 deletions(-)
diff --git a/modules/financial_bot/Makefile b/modules/financial_bot/Makefile
index cf26130..2728f90 100644
--- a/modules/financial_bot/Makefile
+++ b/modules/financial_bot/Makefile
@@ -25,8 +25,6 @@ add_dev:
run:
@echo "Running financial_bot..."
- @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)"
-
poetry run python -m tools.run_chain
diff --git a/modules/financial_bot/financial_bot/chains.py b/modules/financial_bot/financial_bot/chains.py
index ab46ccc..a00ab69 100644
--- a/modules/financial_bot/financial_bot/chains.py
+++ b/modules/financial_bot/financial_bot/chains.py
@@ -1,11 +1,12 @@
from typing import Any, Dict, List
import qdrant_client
-from financial_bot.embeddings import EmbeddingModelSingleton
-from financial_bot.template import PromptTemplate
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline
+from financial_bot.embeddings import EmbeddingModelSingleton
+from financial_bot.template import PromptTemplate
+
class ContextExtractorChain(Chain):
"""
@@ -17,43 +18,37 @@ class ContextExtractorChain(Chain):
embedding_model: EmbeddingModelSingleton
vector_store: qdrant_client.QdrantClient
vector_collection: str
- output_key: str = "payload"
+ output_key: str = "context"
@property
def input_keys(self) -> List[str]:
- return ["about_me", "question", "context"]
+ return ["about_me", "question"]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- # TODO: handle that None, without the need to enter chain
- about_key, quest_key, contx_key = self.input_keys
- question_str = inputs.get(quest_key, None)
+ _, quest_key = self.input_keys
+ question_str = inputs[quest_key]
# TODO: maybe async embed?
embeddings = self.embedding_model(question_str)
- # TODO: get rid of hardcoded collection_name, specify 1 top_k or adjust multiple context insertions
matches = self.vector_store.search(
query_vector=embeddings,
k=self.top_k,
collection_name=self.vector_collection,
)
- content = ""
+ context = ""
for match in matches:
- content += match.payload["summary"] + "\n"
+ context += match.payload["summary"] + "\n"
- payload = {
- about_key: inputs[about_key],
- quest_key: inputs[quest_key],
- contx_key: content,
+ return {
+ self.output_key: context,
}
- return {self.output_key: payload}
-
class FinancialBotQAChain(Chain):
"""This custom chain handles LLM generation upon given prompt"""
@@ -64,20 +59,19 @@ class FinancialBotQAChain(Chain):
@property
def input_keys(self) -> List[str]:
- return ["payload"]
+ return ["context"]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- # TODO: use .get and treat default value?
about_me = inputs["about_me"]
question = inputs["question"]
- context = inputs["context"]
+ news_context = inputs.get("context")
prompt = self.template.infer_raw_template.format(
- user_context=about_me, news_context=context, question=question
+ user_context=about_me, news_context=news_context, question=question
)
response = self.hf_pipeline(prompt)
diff --git a/modules/financial_bot/financial_bot/langchain_bot.py b/modules/financial_bot/financial_bot/langchain_bot.py
index 39ea66f..659294e 100644
--- a/modules/financial_bot/financial_bot/langchain_bot.py
+++ b/modules/financial_bot/financial_bot/langchain_bot.py
@@ -68,7 +68,7 @@ def build_chain(self) -> chains.SequentialChain:
logger.info("Connecting chains into SequentialChain")
seq_chain = chains.SequentialChain(
chains=[context_retrieval_chain, llm_generator_chain],
- input_variables=["about_me", "question", "context"],
+ input_variables=["about_me", "question"],
output_variables=["response"],
verbose=True,
)
@@ -81,7 +81,7 @@ def build_chain(self) -> chains.SequentialChain:
)
return seq_chain
- def answer(self, about_me: str, question: str, context: str) -> str:
+ def answer(self, about_me: str, question: str) -> str:
"""
Given a short description about the user and a question make the LLM
generate a response.
@@ -99,7 +99,7 @@ def answer(self, about_me: str, question: str, context: str) -> str:
LLM generated response.
"""
try:
- inputs = {"about_me": about_me, "question": question, "context": context}
+ inputs = {"about_me": about_me, "question": question}
response = self.finbot_chain.run(inputs)
return response
except KeyError as e:
diff --git a/modules/financial_bot/tools/run_chain.py b/modules/financial_bot/tools/run_chain.py
index 915ee06..c366c49 100644
--- a/modules/financial_bot/tools/run_chain.py
+++ b/modules/financial_bot/tools/run_chain.py
@@ -11,7 +11,6 @@ def main():
input_payload = {
"about_me": "I'm a student and I have some money that I want to invest.",
"question": "Should I consider investing in stocks from the Tech Sector?",
- "context": ""
}
response = bot.answer(**input_payload)