-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from artefactory/feature/format-document
add document formatting
- Loading branch information
Showing
5 changed files
with
184 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
94 changes: 94 additions & 0 deletions
94
backend/rag_components/chain_links/retrieve_and_format_multimodal_docs.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""This chain fetches multimodal documents.""" | ||
|
||
from langchain.schema import format_document | ||
from langchain_core.documents import Document | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.retrievers import BaseRetriever | ||
from langchain_core.runnables import RunnableLambda, RunnableSequence | ||
from pydantic import BaseModel | ||
|
||
from backend.utils.image import resize_base64_image | ||
|
||
DOCUMENT_TEMPLATE = """\ | ||
Document metadata: | ||
- Filename: {filename} | ||
- Page number: {page_number} | ||
Document content: | ||
### | ||
{page_content} | ||
### | ||
""" | ||
DOCUMENT_PROMPT = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE) | ||
|
||
|
||
class Question(BaseModel): | ||
"""Question to be answered.""" | ||
|
||
question: str | ||
|
||
|
||
class Documents(BaseModel): | ||
"""Multimodal documents.""" | ||
|
||
images: list[str] | ||
mime_types: list[str] | ||
texts: list[str] | ||
|
||
|
||
def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence: | ||
"""Creates a chain that retrieves and processes multimodal documents. | ||
This chain first retrieves documents and then splits them into images and texts | ||
based on their metadata. It then formats the documents into a structure that | ||
separates base64-encoded images, their mime types, and text content. | ||
Args: | ||
retriever (BaseRetriever): Retriever that fetches documents. | ||
Returns: | ||
RunnableSequence: Langchain sequence. | ||
""" | ||
relevant_documents = retriever | RunnableLambda(_split_image_text_types) | ||
typed_chain = relevant_documents.with_types( | ||
input_type=Question, output_type=Documents | ||
) | ||
return typed_chain | ||
|
||
|
||
def _split_image_text_types(docs: list[Document]) -> dict[str, list]: | ||
"""Split base64-encoded images and texts. | ||
Args: | ||
docs (list[Document]): List of documents. | ||
Returns: | ||
dict[str, list]: Dictionary containing lists of images, mime types, and texts. | ||
""" | ||
img_base64_list = [] | ||
img_mime_type_list = [] | ||
text_list = [] | ||
for doc in docs: | ||
match doc.metadata["type"]: | ||
case "text": | ||
formatted_doc = format_document(doc, DOCUMENT_PROMPT) | ||
text_list.append(formatted_doc) | ||
case "image": | ||
img = doc.page_content | ||
img = resize_base64_image(img) | ||
img_base64_list.append(resize_base64_image(img)) | ||
img_mime_type_list.append(doc.metadata["mime_type"]) | ||
case "table": | ||
if doc.metadata["format"] == "image": | ||
img = doc.page_content | ||
img = resize_base64_image(img) | ||
img_base64_list.append(img) | ||
img_mime_type_list.append(doc.metadata["mime_type"]) | ||
else: | ||
formatted_doc = format_document(doc, DOCUMENT_PROMPT) | ||
text_list.append(formatted_doc) | ||
|
||
return { | ||
"images": img_base64_list, | ||
"mime_types": img_mime_type_list, | ||
"texts": text_list, | ||
} |
66 changes: 66 additions & 0 deletions
66
backend/rag_components/chain_links/retrieve_and_format_text_docs.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""This chain fetches text documents and combines them into a single string.""" | ||
|
||
from langchain.schema import format_document | ||
from langchain_core.documents import Document | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.retrievers import BaseRetriever | ||
from langchain_core.runnables.base import RunnableLambda, RunnableSequence | ||
from pydantic import BaseModel | ||
|
||
DOCUMENT_TEMPLATE = """\ | ||
Document metadata: | ||
- Filename: {filename} | ||
- Page number: {page_number} | ||
Document content: | ||
### | ||
{page_content} | ||
### | ||
""" | ||
|
||
|
||
class Question(BaseModel): | ||
"""Question to be answered.""" | ||
|
||
question: str | ||
|
||
|
||
class Documents(BaseModel): | ||
"""Text documents.""" | ||
|
||
documents: str | ||
|
||
|
||
def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence: | ||
"""Creates a chain that retrieves and formats text documents. | ||
This chain uses the provided retriever to fetch text documents and then combines | ||
them into a single string formatted according to a predefined template. The | ||
resulting string includes metadata and content for each document. | ||
Args: | ||
retriever (BaseRetriever): Retriever that fetches documents. | ||
Returns: | ||
RunnableSequence: Langchain sequence. | ||
""" | ||
relevant_documents = retriever | RunnableLambda(_combine_documents) | ||
typed_chain = relevant_documents.with_types( | ||
input_type=Question, output_type=Documents | ||
) | ||
return typed_chain | ||
|
||
|
||
def _combine_documents(docs: list[Document], document_separator: str = "\n\n") -> str: | ||
r"""Combine a list of text documents into a single string. | ||
Args: | ||
docs (list[Document]): List of documents. | ||
document_separator (str, optional): String to insert between each formatted | ||
document. Defaults to "\n\n". | ||
Returns: | ||
str: Single string containing all formatted documents | ||
""" | ||
document_prompt = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE) | ||
doc_strings = [format_document(doc, document_prompt) for doc in docs] | ||
return document_separator.join(doc_strings) |