Skip to content

Commit

Permalink
feat(reranker): Add flashrank and contextual compression retriever (#…
Browse files Browse the repository at this point in the history
…2480)

This pull request adds the flashrank and contextual compression
retriever to the codebase. The flashrank reranker model is used for
compression, and the contextual compression retriever combines the base
compressor and base retriever to improve document retrieval.
  • Loading branch information
StanGirard authored Apr 24, 2024
1 parent 7ead787 commit f656dbc
Show file tree
Hide file tree
Showing 7 changed files with 501 additions and 404 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ TELEMETRY_ENABLED=true
CELERY_BROKER_URL=redis://redis:6379/0
CELEBRY_BROKER_QUEUE_NAME=quivr-preview.fifo
QUIVR_DOMAIN=http://localhost:3000/
#COHERE_API_KEY=CHANGE_ME




Expand Down
2 changes: 2 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ datasets = "*"
pytest-dotenv = "*"
fpdf2 = "*"
unidecode = "*"
flashrank = "*"
langchain-cohere = "*"

[dev-packages]
black = "*"
Expand Down
786 changes: 432 additions & 354 deletions Pipfile.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions backend/models/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class File(BaseModel):
vectors_ids: Optional[list] = []
file_extension: Optional[str] = ""
content: Optional[Any] = None
chunk_size: int = 250
chunk_overlap: int = 0
chunk_size: int = 800
chunk_overlap: int = 300
documents: Optional[Document] = None

@property
Expand Down
30 changes: 19 additions & 11 deletions backend/modules/brain/rags/quivr_rag.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import os
from operator import itemgetter
from typing import Optional
from uuid import UUID

from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.llms.base import BaseLLM
from langchain.memory import ConversationBufferMemory
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.schema import format_document
from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import OpenAIEmbeddings
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from supabase.client import Client, create_client
Expand All @@ -28,7 +30,7 @@


# First step is to create the Rephrasing Prompt
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.
Chat History:
{chat_history}
Expand Down Expand Up @@ -202,14 +204,20 @@ def get_retriever(self):
return self.vector_store.as_retriever()

def get_chain(self):
compressor = None
if os.getenv("COHERE_API_KEY"):
compressor = CohereRerank(top_n=5)
else:
compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)

retriever_doc = self.get_retriever()
memory = ConversationBufferMemory(
return_messages=True, output_key="answer", input_key="question"
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever_doc
)

loaded_memory = RunnablePassthrough.assign(
chat_history=RunnableLambda(memory.load_memory_variables)
| itemgetter("history"),
chat_history=lambda x: x["chat_history"],
question=lambda x: x["question"],
)

api_base = None
Expand All @@ -219,7 +227,7 @@ def get_chain(self):
standalone_question = {
"standalone_question": {
"question": lambda x: x["question"],
"chat_history": lambda x: get_buffer_string(x["chat_history"]),
"chat_history": lambda x: x["chat_history"],
}
| CONDENSE_QUESTION_PROMPT
| ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
Expand All @@ -233,7 +241,7 @@ def get_chain(self):

# Now we retrieve the documents
retrieved_documents = {
"docs": itemgetter("standalone_question") | retriever_doc,
"docs": itemgetter("standalone_question") | compression_retriever,
"question": lambda x: x["standalone_question"],
"custom_instructions": lambda x: prompt_to_use,
}
Expand Down
77 changes: 42 additions & 35 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ backoff==2.2.1; python_version >= '3.7' and python_version < '4.0'
beautifulsoup4==4.12.3; python_full_version >= '3.6.0'
billiard==4.2.0; python_version >= '3.7'
black==24.4.0; python_version >= '3.8'
boto3==1.34.86; python_version >= '3.8'
botocore==1.34.86; python_version >= '3.8'
boto3==1.34.90; python_version >= '3.8'
botocore==1.34.90; python_version >= '3.8'
celery[redis,sqs]==5.4.0; python_version >= '3.8'
certifi==2024.2.2; python_version >= '3.6'
cffi==1.16.0; platform_python_implementation != 'PyPy'
Expand All @@ -28,17 +28,18 @@ click==8.1.7; python_version >= '3.7'
click-didyoumean==0.3.1; python_full_version >= '3.6.2'
click-plugins==1.1.1
click-repl==0.3.0; python_version >= '3.6'
cohere==5.3.3; python_version >= '3.8' and python_version < '4.0'
coloredlogs==15.0.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
colorlog==6.8.2; python_version >= '3.6'
contourpy==1.2.1; python_version >= '3.9'
cryptography==42.0.5; python_version >= '3.7'
cssselect==1.2.0; python_version >= '3.7'
cycler==0.12.1; python_version >= '3.8'
dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0'
dataclasses-json-speakeasy==0.5.11; python_version >= '3.7' and python_version < '4.0'
datasets==2.18.0; python_full_version >= '3.8.0'
datasets==2.19.0; python_full_version >= '3.8.0'
debugpy==1.8.1; python_version >= '3.8'
decorator==5.1.1; python_version >= '3.5'
deepdiff==7.0.1; python_version >= '3.8'
defusedxml==0.7.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
deprecated==1.2.14; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
deprecation==2.1.0
Expand All @@ -50,22 +51,24 @@ docx2txt==0.8
duckdb==0.10.2; python_full_version >= '3.7.0'
ecdsa==0.19.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
effdet==0.4.1
emoji==2.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
emoji==2.11.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
et-xmlfile==1.1.0; python_version >= '3.6'
faker==19.13.0; python_version >= '3.8'
fastapi==0.110.1; python_version >= '3.8'
fastapi==0.110.2; python_version >= '3.8'
fastavro==1.9.4; python_version >= '3.8'
feedfinder2==0.0.4
feedparser==6.0.11; python_version >= '3.6'
filelock==3.13.4; python_version >= '3.8'
filetype==1.2.0
flake8==7.0.0; python_full_version >= '3.8.1'
flake8-black==0.3.6; python_version >= '3.7'
flashrank==0.2.0; python_version >= '3.6'
flatbuffers==24.3.25
flower==2.0.1; python_version >= '3.7'
fonttools==4.51.0; python_version >= '3.8'
fpdf2==2.7.8; python_version >= '3.7'
frozenlist==1.4.1; python_version >= '3.8'
fsspec[http]==2024.2.0; python_version >= '3.8'
fsspec[http]==2024.3.1; python_version >= '3.8'
gitdb==4.0.11; python_version >= '3.7'
gitpython==3.1.43; python_version >= '3.7'
gotrue==2.4.2; python_version >= '3.8' and python_version < '4.0'
Expand All @@ -74,6 +77,7 @@ h11==0.14.0; python_version >= '3.7'
html5lib==1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
httpcore==1.0.5; python_version >= '3.8'
httpx==0.27.0; python_version >= '3.8'
httpx-sse==0.4.0; python_version >= '3.8'
huggingface-hub==0.22.2; python_full_version >= '3.8.0'
humanfriendly==10.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
humanize==4.9.0; python_version >= '3.8'
Expand All @@ -92,25 +96,26 @@ jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3
kiwisolver==1.4.5; python_version >= '3.7'
kombu[sqs]==5.3.7; python_version >= '3.8'
langchain==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-community==0.0.33; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-core==0.1.44; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-cohere==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-community==0.0.34; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-core==0.1.45; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
langchain-text-splitters==0.0.1; python_version < '4.0' and python_full_version >= '3.8.1'
langdetect==1.0.9
langfuse==2.26.3; python_version < '4.0' and python_full_version >= '3.8.1'
langsmith==0.1.48; python_version < '4.0' and python_full_version >= '3.8.1'
langfuse==2.27.1; python_version < '4.0' and python_full_version >= '3.8.1'
langsmith==0.1.50; python_version < '4.0' and python_full_version >= '3.8.1'
layoutparser[layoutmodels,tesseract]==0.3.4; python_version >= '3.6'
litellm==1.35.10; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
llama-index==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-agent-openai==0.2.2; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-cli==0.1.11; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-core==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-embeddings-openai==0.1.7; python_version < '4.0' and python_full_version >= '3.8.1'
litellm==1.35.21; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
llama-index==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-agent-openai==0.2.3; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-cli==0.1.12; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-core==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-embeddings-openai==0.1.8; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-indices-managed-llama-cloud==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-legacy==0.9.48; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-llms-openai==0.1.15; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-llms-openai==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-multi-modal-llms-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-program-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-program-openai==0.1.6; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-question-gen-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-readers-file==0.1.19; python_version < '4.0' and python_full_version >= '3.8.1'
llama-index-readers-llama-parse==0.1.4; python_version < '4.0' and python_full_version >= '3.8.1'
Expand Down Expand Up @@ -138,23 +143,24 @@ numpy==1.26.4; python_version >= '3.9'
olefile==0.47; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
omegaconf==2.3.0; python_version >= '3.6'
onnx==1.16.0
onnxruntime==1.15.1
openai==1.21.1; python_full_version >= '3.7.1'
onnxruntime==1.17.3
openai==1.23.3; python_full_version >= '3.7.1'
opencv-python==4.9.0.80; python_version >= '3.6'
openpyxl==3.1.2
ordered-set==4.1.0; python_version >= '3.7'
orjson==3.10.1; python_version >= '3.8'
packaging==23.2; python_version >= '3.7'
pandas==1.5.3; python_version >= '3.8'
pandasai==2.0.33; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
pandasai==2.0.35; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
pathspec==0.12.1; python_version >= '3.8'
pdf2image==1.17.0
pdfminer.six==20231228
pdfplumber==0.11.0; python_version >= '3.8'
pikepdf==8.15.1
pillow==10.3.0; python_version >= '3.8'
pillow-heif==0.16.0
platformdirs==4.2.0; python_version >= '3.8'
pluggy==1.4.0; python_version >= '3.8'
platformdirs==4.2.1; python_version >= '3.8'
pluggy==1.5.0; python_version >= '3.8'
portalocker==2.8.2; python_version >= '3.8'
postgrest==0.16.3; python_version >= '3.8' and python_version < '4.0'
posthog==3.5.0
Expand All @@ -165,22 +171,22 @@ psutil==5.9.8; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2,
psycopg2==2.9.9; python_version >= '3.7'
psycopg2-binary==2.9.9; python_version >= '3.7'
py==1.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
pyarrow==15.0.2; python_version >= '3.8'
pyarrow==16.0.0; python_version >= '3.8'
pyarrow-hotfix==0.6; python_version >= '3.5'
pyasn1==0.6.0; python_version >= '3.8'
pycocotools==2.0.7; python_version >= '3.5'
pycodestyle==2.11.1; python_version >= '3.8'
pycparser==2.22; python_version >= '3.8'
pycurl==7.45.3
pydantic==2.7.0; python_version >= '3.8'
pydantic-core==2.18.1; python_version >= '3.8'
pydantic==2.7.1; python_version >= '3.8'
pydantic-core==2.18.2; python_version >= '3.8'
pydantic-settings==2.2.1; python_version >= '3.8'
pyflakes==3.2.0; python_version >= '3.8'
pypandoc==1.13; python_version >= '3.6'
pyparsing==3.1.2; python_full_version >= '3.6.8'
pypdf==4.2.0; python_version >= '3.6'
pypdfium2==4.29.0; python_version >= '3.6'
pyright==1.1.359; python_version >= '3.7'
pyright==1.1.360; python_version >= '3.7'
pysbd==0.3.4; python_version >= '3'
pytesseract==0.3.10; python_version >= '3.7'
pytest==8.1.1; python_version >= '3.8'
Expand All @@ -201,7 +207,7 @@ pyyaml==6.0.1; python_version >= '3.6'
ragas==0.1.7
rapidfuzz==3.8.1; python_version >= '3.8'
realtime==1.0.4; python_version >= '3.8' and python_version < '4.0'
redis==5.0.3; python_version >= '3.7'
redis==5.0.4; python_version >= '3.7'
regex==2024.4.16; python_version >= '3.7'
requests==2.31.0; python_version >= '3.7'
requests-file==2.0.0
Expand Down Expand Up @@ -232,21 +238,22 @@ tiktoken==0.6.0; python_version >= '3.8'
timm==0.9.16; python_version >= '3.8'
tinysegmenter==0.3
tldextract==5.1.2; python_version >= '3.8'
tokenizers==0.15.2; python_version >= '3.7'
tokenizers==0.19.1; python_version >= '3.7'
torch==2.2.2
torchvision==0.17.2
tornado==6.4; python_version >= '3.8'
tqdm==4.66.2; python_version >= '3.7'
transformers==4.39.3; python_full_version >= '3.8.0'
transformers==4.40.1; python_full_version >= '3.8.0'
types-requests==2.31.0.20240406; python_version >= '3.8'
typing-extensions==4.11.0; python_version >= '3.8'
typing-inspect==0.9.0
tzdata==2024.1; python_version >= '2'
unidecode==1.3.8; python_version >= '3.5'
unstructured[all-docs]==0.13.2; python_version < '3.12' and python_full_version >= '3.9.0'
unstructured-client==0.18.0; python_version >= '3.8'
unstructured-inference==0.7.25
unstructured[all-docs]==0.13.3; python_version < '3.12' and python_full_version >= '3.9.0'
unstructured-client==0.22.0; python_version >= '3.8'
unstructured-inference==0.7.27
unstructured.pytesseract==0.3.12
urllib3==2.2.1; python_version >= '3.8'
urllib3==2.2.1; python_version >= '3.10'
uvicorn==0.29.0; python_version >= '3.8'
vine==5.1.0; python_version >= '3.6'
watchdog==4.0.0; python_version >= '3.8'
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/ragas_evaluation/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from modules.brain.service.brain_service import BrainService
from modules.knowledge.dto.inputs import CreateKnowledgeProperties
from modules.knowledge.service.knowledge_service import KnowledgeService
from modules.upload.service.upload_file import upload_file_storage
from ragas import evaluate
from ragas.embeddings.base import LangchainEmbeddingsWrapper
from modules.upload.service.upload_file import upload_file_storage


def main(
Expand Down Expand Up @@ -176,7 +176,7 @@ def generate_replies(
"--model", type=str, default="gpt-3.5-turbo-0125", help="Model to use"
)
parser.add_argument(
"--context_size", type=int, default=4000, help="Context size for the model"
"--context_size", type=int, default=10000, help="Context size for the model"
)
parser.add_argument(
"--metrics",
Expand Down

0 comments on commit f656dbc

Please sign in to comment.