Skip to content

Commit

Permalink
Merge pull request #12 from artefactory/feature/format-document
Browse files Browse the repository at this point in the history
add document formatting
  • Loading branch information
baptiste-pasquier authored Mar 25, 2024
2 parents dc32333 + 8ce8b92 commit 290f3b2
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 87 deletions.
53 changes: 10 additions & 43 deletions backend/rag_1/chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""RAG chain for Option 1."""

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
Expand All @@ -11,60 +10,28 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.utils.image import resize_base64_image
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever

from . import prompts


def split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
text_list.append(doc.page_content)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
text_list.append(doc.page_content)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}


def img_prompt_func(data_dict: dict) -> list[BaseMessage]:
"""Join the context into a single string.
def img_prompt_func(
data_dict: dict, document_separator: str = "\n\n"
) -> list[BaseMessage]:
r"""Join the context into a single string with images and the question.
Args:
data_dict (dict): Dictionary containing the context and question.
document_separator (str, optional): String inserted between consecutive text documents. Defaults to "\n\n".
Returns:
list[BaseMessage]: List of messages to be sent to the model.
"""
formatted_texts = "\n".join(data_dict["context"]["texts"])
formatted_texts = document_separator.join(data_dict["context"]["texts"])
messages = []

# Adding image(s) to the messages if present
Expand Down Expand Up @@ -121,7 +88,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
Expand Down
5 changes: 4 additions & 1 deletion backend/rag_2/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.rag_components.chain_links.retrieve_and_format_text_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_text_llm
from backend.utils.retriever import get_retriever

Expand Down Expand Up @@ -55,7 +58,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever,
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| prompt
Expand Down
53 changes: 10 additions & 43 deletions backend/rag_3/chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""RAG chain for Option 3."""

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
Expand All @@ -14,60 +13,28 @@
from backend.rag_components.chain_links.rag_with_history import (
construct_rag_with_history,
)
from backend.utils.image import resize_base64_image
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever

from . import prompts


def split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
text_list.append(doc.page_content)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
text_list.append(doc.page_content)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}


def img_prompt_func(data_dict: dict) -> list[BaseMessage]:
"""Join the context into a single string.
def img_prompt_func(
data_dict: dict, document_separator: str = "\n\n"
) -> list[BaseMessage]:
r"""Join the context into a single string with images and the question.
Args:
data_dict (dict): Dictionary containing the context and question.
document_separator (str, optional): String inserted between consecutive text documents. Defaults to "\n\n".
Returns:
list[BaseMessage]: List of messages to be sent to the model.
"""
formatted_texts = "\n".join(data_dict["context"]["texts"])
formatted_texts = document_separator.join(data_dict["context"]["texts"])
messages = []

# Adding image(s) to the messages if present
Expand Down Expand Up @@ -124,7 +91,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence:
# Define the RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"context": fetch_docs_chain(retriever),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""This chain fetches multimodal documents."""

from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableLambda, RunnableSequence
from pydantic import BaseModel

from backend.utils.image import resize_base64_image

# Template used to render a single document as text: filename and page-number
# metadata first, then the page content delimited by "###" lines.
DOCUMENT_TEMPLATE = """\
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""
# Prompt built once at import time from DOCUMENT_TEMPLATE; it is passed to
# `format_document` whenever a text document is formatted in this module.
DOCUMENT_PROMPT = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE)


class Question(BaseModel):
    """Input schema declared for the document-fetching chain."""

    # The question to be answered.
    question: str


class Documents(BaseModel):
    """Output schema declared for the multimodal document-fetching chain."""

    # Images as base64-encoded strings.
    images: list[str]
    # Mime types recorded for the images.
    mime_types: list[str]
    # Formatted text documents.
    texts: list[str]


def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence:
    """Build the chain that retrieves multimodal documents and splits them by kind.

    The retriever output is piped into ``_split_image_text_types``, which separates
    base64-encoded images, their mime types, and text content. The resulting
    sequence is declared with ``Question`` as input type and ``Documents`` as
    output type.

    Args:
        retriever (BaseRetriever): Retriever that fetches documents.

    Returns:
        RunnableSequence: Langchain sequence.
    """
    splitter = RunnableLambda(_split_image_text_types)
    return (retriever | splitter).with_types(
        input_type=Question,
        output_type=Documents,
    )


def _split_image_text_types(docs: list[Document]) -> dict[str, list]:
"""Split base64-encoded images and texts.
Args:
docs (list[Document]): List of documents.
Returns:
dict[str, list]: Dictionary containing lists of images, mime types, and texts.
"""
img_base64_list = []
img_mime_type_list = []
text_list = []
for doc in docs:
match doc.metadata["type"]:
case "text":
formatted_doc = format_document(doc, DOCUMENT_PROMPT)
text_list.append(formatted_doc)
case "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(resize_base64_image(img))
img_mime_type_list.append(doc.metadata["mime_type"])
case "table":
if doc.metadata["format"] == "image":
img = doc.page_content
img = resize_base64_image(img)
img_base64_list.append(img)
img_mime_type_list.append(doc.metadata["mime_type"])
else:
formatted_doc = format_document(doc, DOCUMENT_PROMPT)
text_list.append(formatted_doc)

return {
"images": img_base64_list,
"mime_types": img_mime_type_list,
"texts": text_list,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""This chain fetches text documents and combines them into a single string."""

from langchain.schema import format_document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.base import RunnableLambda, RunnableSequence
from pydantic import BaseModel

# Template used to render a single document as text: filename and page-number
# metadata first, then the page content delimited by "###" lines.
DOCUMENT_TEMPLATE = """\
Document metadata:
- Filename: {filename}
- Page number: {page_number}
Document content:
###
{page_content}
###
"""


class Question(BaseModel):
    """Input schema declared for the document-fetching chain."""

    # The question to be answered.
    question: str


class Documents(BaseModel):
    """Output schema declared for the text document-fetching chain."""

    # Formatted text documents as a single string.
    documents: str


def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence:
    """Build the chain that retrieves text documents and joins them into one string.

    The retriever output is piped into ``_combine_documents``, which formats each
    document with the module's template (metadata plus content) and concatenates
    the results. The resulting sequence is declared with ``Question`` as input type
    and ``Documents`` as output type.

    Args:
        retriever (BaseRetriever): Retriever that fetches documents.

    Returns:
        RunnableSequence: Langchain sequence.
    """
    combiner = RunnableLambda(_combine_documents)
    return (retriever | combiner).with_types(
        input_type=Question,
        output_type=Documents,
    )


def _combine_documents(docs: list[Document], document_separator: str = "\n\n") -> str:
    r"""Format every document with the module template and join the results.

    Args:
        docs (list[Document]): List of documents.
        document_separator (str, optional): String placed between consecutive
            formatted documents. Defaults to "\n\n".

    Returns:
        str: All formatted documents concatenated into one string.
    """
    template_prompt = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE)
    rendered = (format_document(document, template_prompt) for document in docs)
    return document_separator.join(rendered)

0 comments on commit 290f3b2

Please sign in to comment.