Skip to content

Commit

Permalink
♻️ refactor document retrieval and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 25, 2024
1 parent 7b03c1f commit a604d2a
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 137 deletions.
49 changes: 5 additions & 44 deletions backend/rag_1/chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""RAG chain for Option 1."""

from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
Expand All @@ -12,56 +10,19 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.utils.image import resize_base64_image
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever

from . import prompts


def split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE)
text_list.append(formatted_doc)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE)
text_list.append(formatted_doc)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}


def img_prompt_func(
data_dict: dict, document_separator: str = "\n\n"
) -> list[BaseMessage]:
r"""Join the context into a single string.
r"""Join the context into a single string with images and the question.
Args:
data_dict (dict): Dictionary containing the context and question.
Expand Down Expand Up @@ -127,7 +88,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
Expand Down
10 changes: 0 additions & 10 deletions backend/rag_1/prompts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
"""Prompts for RAG Option 1."""

DOCUMENT_TEMPLATE = """
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""

RAG_PROMPT = """You will be given a mixed of text, tables, and images usually of \
charts or graphs. Use this information to provide an answer to the user question.
User-provided question:
Expand Down
25 changes: 5 additions & 20 deletions backend/rag_2/chain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""RAG chain for Option 2."""

from langchain.prompts import PromptTemplate
from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.base import (
RunnableSequence,
RunnableSerializable,
Expand All @@ -16,27 +13,15 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.rag_components.chain_links.retrieve_and_format_text_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_text_llm
from backend.utils.retriever import get_retriever

from . import prompts


def _combine_documents(docs: list[Document], document_separator: str = "\n\n") -> str:
r"""_summary_.
Args:
docs (list[Document]): List of documents.
document_separator (str, optional): _description_. Defaults to "\n\n".
Returns:
str: _description_
"""
document_prompt = PromptTemplate.from_template(template=prompts.DOCUMENT_TEMPLATE)
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)


class Question(BaseModel):
"""Question to be answered."""

Expand Down Expand Up @@ -73,7 +58,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(_combine_documents),
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| prompt
Expand Down
10 changes: 0 additions & 10 deletions backend/rag_2/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@
for retrieval. These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""

DOCUMENT_TEMPLATE = """
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""

RAG_PROMPT = """You will be given a mixed of text, tables, and image summaries usually \
of charts or graphs. Use this information to provide an answer to the user question.
User-provided question:
Expand Down
49 changes: 5 additions & 44 deletions backend/rag_3/chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""RAG chain for Option 3."""

from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
Expand All @@ -15,56 +13,19 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.utils.image import resize_base64_image
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever

from . import prompts


def split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE)
text_list.append(formatted_doc)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE)
text_list.append(formatted_doc)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}


def img_prompt_func(
data_dict: dict, document_separator: str = "\n\n"
) -> list[BaseMessage]:
r"""Join the context into a single string.
r"""Join the context into a single string with images and the question.
Args:
data_dict (dict): Dictionary containing the context and question.
Expand Down Expand Up @@ -130,7 +91,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
Expand Down
9 changes: 0 additions & 9 deletions backend/rag_3/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
for retrieval. These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""

DOCUMENT_TEMPLATE = """
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""

RAG_PROMPT = """You will be given a mixed of text, tables, and images usually of \
charts or graphs. Use this information to provide an answer to the user question.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""This chain fetches multimodal documents."""

from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableLambda, RunnableSequence
from pydantic import BaseModel

from backend.utils.image import resize_base64_image

DOCUMENT_TEMPLATE = """\
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""
DOCUMENT_PROMPT = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE)


class Question(BaseModel):
"""Question to be answered."""

question: str


class Documents(BaseModel):
"""Multimodal documents."""

images: list[str]
mime_types: list[str]
texts: list[str]


def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence:
"""Creates a chain that retrieves and processes multimodal documents.
This chain first retrieves documents and then splits them into images and texts
based on their metadata. It then formats the documents into a structure that
separates base64-encoded images, their mime types, and text content.
Args:
retriever (BaseRetriever): Retriever that fetches documents.
Returns:
RunnableSequence: Langchain sequence.
"""
relevant_documents = retriever | RunnableLambda(_split_image_text_types)
typed_chain = relevant_documents.with_types(
input_type=Question, output_type=Documents
)
return typed_chain


def _split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
formatted_doc = format_document(doc, DOCUMENT_PROMPT)
text_list.append(formatted_doc)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
formatted_doc = format_document(doc, DOCUMENT_PROMPT)
text_list.append(formatted_doc)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}
Loading

0 comments on commit a604d2a

Please sign in to comment.