fix: Sharing keys between chains
iusztinpaul committed Oct 10, 2023
1 parent 0a5d0b9 commit c093770
Showing 4 changed files with 17 additions and 26 deletions.
2 changes: 0 additions & 2 deletions modules/financial_bot/Makefile
@@ -25,8 +25,6 @@ add_dev:
 run:
 	@echo "Running financial_bot..."
 
-	@echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)"
-
 	poetry run python -m tools.run_chain
 
 
34 changes: 14 additions & 20 deletions modules/financial_bot/financial_bot/chains.py
@@ -1,11 +1,12 @@
 from typing import Any, Dict, List
 
 import qdrant_client
-from financial_bot.embeddings import EmbeddingModelSingleton
-from financial_bot.template import PromptTemplate
 from langchain.chains.base import Chain
 from langchain.llms import HuggingFacePipeline
 
+from financial_bot.embeddings import EmbeddingModelSingleton
+from financial_bot.template import PromptTemplate
+
 
 class ContextExtractorChain(Chain):
     """
@@ -17,43 +18,37 @@ class ContextExtractorChain(Chain):
     embedding_model: EmbeddingModelSingleton
     vector_store: qdrant_client.QdrantClient
     vector_collection: str
-    output_key: str = "payload"
+    output_key: str = "context"
 
     @property
     def input_keys(self) -> List[str]:
-        return ["about_me", "question", "context"]
+        return ["about_me", "question"]
 
     @property
     def output_keys(self) -> List[str]:
         return [self.output_key]
 
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        # TODO: handle that None, without the need to enter chain
-        about_key, quest_key, contx_key = self.input_keys
-        question_str = inputs.get(quest_key, None)
+        _, quest_key = self.input_keys
+        question_str = inputs[quest_key]
 
         # TODO: maybe async embed?
         embeddings = self.embedding_model(question_str)
 
-        # TODO: get rid of hardcoded collection_name, specify 1 top_k or adjust multiple context insertions
         matches = self.vector_store.search(
             query_vector=embeddings,
             k=self.top_k,
             collection_name=self.vector_collection,
         )
 
-        content = ""
+        context = ""
         for match in matches:
-            content += match.payload["summary"] + "\n"
+            context += match.payload["summary"] + "\n"
 
-        payload = {
-            about_key: inputs[about_key],
-            quest_key: inputs[quest_key],
-            contx_key: content,
+        return {
+            self.output_key: context,
         }
-
-        return {self.output_key: payload}
 
 
 class FinancialBotQAChain(Chain):
     """This custom chain handles LLM generation upon given prompt"""
@@ -64,20 +59,19 @@ class FinancialBotQAChain(Chain):
 
     @property
     def input_keys(self) -> List[str]:
-        return ["payload"]
+        return ["context"]
 
     @property
     def output_keys(self) -> List[str]:
         return [self.output_key]
 
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        # TODO: use .get and treat default value?
         about_me = inputs["about_me"]
         question = inputs["question"]
-        context = inputs["context"]
+        news_context = inputs.get("context")
 
         prompt = self.template.infer_raw_template.format(
-            user_context=about_me, news_context=context, question=question
+            user_context=about_me, news_context=news_context, question=question
         )
         response = self.hf_pipeline(prompt)

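The net effect of the chains.py changes: ContextExtractorChain now publishes its retrieved summaries under output_key "context", which is exactly the key FinancialBotQAChain declares as an input, so a sequential chain can route the value between them. Below is a minimal, dependency-free sketch of that key-sharing contract; the class and key names mirror this commit, but the tiny runner only stands in for langchain's SequentialChain and is not its actual implementation.

from typing import Any, Dict, List


class MiniContextExtractor:
    """Stands in for ContextExtractorChain after this commit."""

    input_keys: List[str] = ["about_me", "question"]  # "context" is no longer an input
    output_key: str = "context"                       # was "payload" before the fix

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Stand-in for embedding the question and querying the vector store.
        return {self.output_key: f"summaries retrieved for: {inputs['question']}"}


class MiniQAChain:
    """Stands in for FinancialBotQAChain after this commit."""

    input_keys: List[str] = ["context"]  # now matches the extractor's output_key
    output_key: str = "response"

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        prompt = (
            f"user: {inputs['about_me']}\n"
            f"news: {inputs['context']}\n"
            f"question: {inputs['question']}"
        )
        return {self.output_key: f"LLM answer for -> {prompt!r}"}


def run_sequential(chain_list, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Accumulate each chain's outputs into shared state, as SequentialChain does."""
    state = dict(inputs)
    for chain in chain_list:
        missing = [key for key in chain.input_keys if key not in state]
        if missing:  # the key mismatch this commit fixes would surface here
            raise KeyError(f"{type(chain).__name__} is missing {missing}")
        state.update(chain._call(state))
    return state


if __name__ == "__main__":
    result = run_sequential(
        [MiniContextExtractor(), MiniQAChain()],
        {"about_me": "I'm a student.", "question": "Should I buy tech stocks?"},
    )
    print(result["response"])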
6 changes: 3 additions & 3 deletions modules/financial_bot/financial_bot/langchain_bot.py
@@ -68,7 +68,7 @@ def build_chain(self) -> chains.SequentialChain:
         logger.info("Connecting chains into SequentialChain")
         seq_chain = chains.SequentialChain(
             chains=[context_retrieval_chain, llm_generator_chain],
-            input_variables=["about_me", "question", "context"],
+            input_variables=["about_me", "question"],
             output_variables=["response"],
             verbose=True,
         )
@@ -81,7 +81,7 @@ def build_chain(self) -> chains.SequentialChain:
         )
         return seq_chain
 
-    def answer(self, about_me: str, question: str, context: str) -> str:
+    def answer(self, about_me: str, question: str) -> str:
         """
         Given a short description about the user and a question make the LLM
         generate a response.
@@ -99,7 +99,7 @@ def answer(self, about_me: str, question: str, context: str) -> str:
             LLM generated response.
         """
         try:
-            inputs = {"about_me": about_me, "question": question, "context": context}
+            inputs = {"about_me": about_me, "question": question}
             response = self.finbot_chain.run(inputs)
             return response
         except KeyError as e:
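With the keys aligned, callers of the bot supply only the two user-facing inputs; "context" is produced internally by the retrieval chain. A hypothetical session after this commit might look like the following (the FinancialBot class name is an assumption based on this module, not shown in the diff):

# Hypothetical usage after this fix; the imported class name is assumed.
from financial_bot.langchain_bot import FinancialBot

bot = FinancialBot()

# No "context" argument anymore -- ContextExtractorChain fills it from the vector store.
response = bot.answer(
    about_me="I'm a student and I have some money that I want to invest.",
    question="Should I consider investing in stocks from the Tech Sector?",
)
print(response)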
1 change: 0 additions & 1 deletion modules/financial_bot/tools/run_chain.py
@@ -11,7 +11,6 @@ def main():
     input_payload = {
         "about_me": "I'm a student and I have some money that I want to invest.",
         "question": "Should I consider investing in stocks from the Tech Sector?",
-        "context": ""
     }
 
     response = bot.answer(**input_payload)
