feat(history): max tokens in the history provided #2487

Merged · 1 commit · Apr 24, 2024
42 changes: 39 additions & 3 deletions backend/modules/brain/rags/quivr_rag.py
@@ -14,7 +14,7 @@
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnablePassthrough
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
@@ -203,6 +203,40 @@ def _combine_documents(
     def get_retriever(self):
         return self.vector_store.as_retriever()
 
+    def filter_history(
+        self, chat_history, max_history: int = 10, max_tokens: int = 2000
+    ):
+        """
+        Filter the chat history down to the most recent messages so it fits the context budget for the current question.
+
+        Takes a chat_history such as [HumanMessage(content='Who is Chloé?'), AIMessage(content="Chloé is an employee working for Quivr as an AI Engineer, reporting to her manager, Stanislas Girard."), HumanMessage(content='Tell me more about her'), AIMessage(content=''), HumanMessage(content='Tell me more about her'), AIMessage(content="Sorry, I have no further information about Chloé in the provided files.")]
+        Returns the chat_history truncated at whichever limit is reached first: max_tokens, or max_history pairs, where one HumanMessage and one AIMessage count as a pair.
+        A token is approximated as 4 characters.
+        """
+        chat_history = chat_history[::-1]
+        total_tokens = 0
+        total_pairs = 0
+        filtered_chat_history = []
+        for i in range(0, len(chat_history), 2):
+            if i + 1 < len(chat_history):
+                human_message = chat_history[i]
+                ai_message = chat_history[i + 1]
+                message_tokens = (
+                    len(human_message.content) + len(ai_message.content)
+                ) // 4
+                if (
+                    total_tokens + message_tokens > max_tokens
+                    or total_pairs >= max_history
+                ):
+                    break
+                filtered_chat_history.append(human_message)
+                filtered_chat_history.append(ai_message)
+                total_tokens += message_tokens
+                total_pairs += 1
+        chat_history = filtered_chat_history[::-1]
+
+        return chat_history
+
     def get_chain(self):
         compressor = None
         if os.getenv("COHERE_API_KEY"):
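A note on the loop above: after the [::-1] reversal, the first element of each pair is actually the AIMessage and the second the HumanMessage, so the human_message/ai_message names are swapped; behavior is unaffected because only the combined character count feeds the token estimate. A minimal usage sketch of the truncation behavior, assuming langchain_core's message classes and an already-constructed QuivrRAG instance named rag (the tiny max_tokens value is made up for illustration):

from langchain_core.messages import AIMessage, HumanMessage

history = [
    HumanMessage(content="Who is Chloé?"),
    AIMessage(content="Chloé is an AI Engineer at Quivr."),
    HumanMessage(content="Tell me more about her"),
    AIMessage(content="Sorry, I have no further information."),
]

# max_tokens=20 is ~80 characters: the newest pair (59 chars ≈ 14 tokens)
# fits, but adding the older pair (≈ 11 more tokens) would exceed the budget.
filtered = rag.filter_history(history, max_history=10, max_tokens=20)

assert [m.content for m in filtered] == [
    "Tell me more about her",
    "Sorry, I have no further information.",
]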
@@ -216,7 +250,9 @@ def get_chain(self):
         )
 
         loaded_memory = RunnablePassthrough.assign(
-            chat_history=lambda x: x["chat_history"],
+            chat_history=RunnableLambda(
+                lambda x: self.filter_history(x["chat_history"]),
+            ),
             question=lambda x: x["question"],
         )
 
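With this change, loaded_memory rewrites the chat_history key through filter_history while passing question through untouched. A standalone sketch of the same RunnablePassthrough.assign pattern, assuming only langchain_core (truncate_history is a hypothetical stand-in for self.filter_history):

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

def truncate_history(x):
    # Stand-in for filter_history: keep only the two most recent messages.
    return x["chat_history"][-2:]

loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(truncate_history),
)

print(loaded_memory.invoke({"chat_history": [1, 2, 3, 4], "question": "hi"}))
# {'chat_history': [3, 4], 'question': 'hi'}

A bare lambda would be coerced to a RunnableLambda automatically; wrapping it explicitly just makes the intent visible in the chain definition.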
@@ -227,7 +263,7 @@
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: x["chat_history"],
+                "chat_history": itemgetter("chat_history"),
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
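Downstream of loaded_memory, itemgetter("chat_history") picks the already-filtered history out of the upstream mapping; it is equivalent to the previous lambda but is the idiomatic LCEL key selector. A small sketch of the equivalence, assuming from operator import itemgetter is already present at the top of quivr_rag.py:

from operator import itemgetter

state = {"question": "Tell me more", "chat_history": ["filtered", "history"]}

# Both callables select the same key from the chain's intermediate state.
assert itemgetter("chat_history")(state) == (lambda x: x["chat_history"])(state)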