diff --git a/backend/rag_1/chain.py b/backend/rag_1/chain.py index eb488bc..a0c5103 100644 --- a/backend/rag_1/chain.py +++ b/backend/rag_1/chain.py @@ -1,7 +1,5 @@ """RAG chain for Option 1.""" -from langchain.schema import format_document -from langchain_core.documents import Document from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.output_parsers.string import StrOutputParser from langchain_core.runnables import RunnableLambda, RunnablePassthrough @@ -12,56 +10,19 @@ from backend.rag_components.chain_links.rag_with_history import ( construct_rag_with_history, ) -from backend.utils.image import resize_base64_image +from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import ( + fetch_docs_chain, +) from backend.utils.llm import get_vision_llm from backend.utils.retriever import get_retriever from . import prompts -def split_image_text_types(docs: list[Document]) -> dict[str, list]: - """Split base64-encoded images and texts. - - Args: - docs (list[Document]): List of documents. - - Returns: - dict[str, list]: Dictionary containing lists of images, mime types, and texts. 
- """ - img_base64_list = [] - img_mime_type_list = [] - text_list = [] - for doc in docs: - match doc.metadata["type"]: - case "text": - formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE) - text_list.append(formatted_doc) - case "image": - img = doc.page_content - img = resize_base64_image(img) - img_base64_list.append(resize_base64_image(img)) - img_mime_type_list.append(doc.metadata["mime_type"]) - case "table": - if doc.metadata["format"] == "image": - img = doc.page_content - img = resize_base64_image(img) - img_base64_list.append(img) - img_mime_type_list.append(doc.metadata["mime_type"]) - else: - formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE) - text_list.append(formatted_doc) - - return { - "images": img_base64_list, - "mime_types": img_mime_type_list, - "texts": text_list, - } - - def img_prompt_func( data_dict: dict, document_separator: str = "\n\n" ) -> list[BaseMessage]: - r"""Join the context into a single string. + r"""Join the context into a single string with images and the question. Args: data_dict (dict): Dictionary containing the context and question. @@ -127,7 +88,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence: # Define the RAG pipeline chain = ( { - "context": retriever | RunnableLambda(split_image_text_types), + "context": fetch_docs_chain(retriever), "question": RunnablePassthrough(), } | RunnableLambda(img_prompt_func) diff --git a/backend/rag_1/prompts.py b/backend/rag_1/prompts.py index d993f28..0114551 100644 --- a/backend/rag_1/prompts.py +++ b/backend/rag_1/prompts.py @@ -1,15 +1,5 @@ """Prompts for RAG Option 1.""" -DOCUMENT_TEMPLATE = """ -Document metadata: -- Filename: {filename} -- Page number: {page_number} -Document content: -### -{page_content} -### -""" - RAG_PROMPT = """You will be given a mixed of text, tables, and images usually of \ charts or graphs. Use this information to provide an answer to the user question. 
User-provided question: diff --git a/backend/rag_2/chain.py b/backend/rag_2/chain.py index 1b7727c..3eb7b6e 100644 --- a/backend/rag_2/chain.py +++ b/backend/rag_2/chain.py @@ -1,11 +1,8 @@ """RAG chain for Option 2.""" -from langchain.prompts import PromptTemplate -from langchain.schema import format_document -from langchain_core.documents import Document from langchain_core.output_parsers.string import StrOutputParser from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import RunnableLambda, RunnablePassthrough +from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables.base import ( RunnableSequence, RunnableSerializable, @@ -16,27 +13,15 @@ from backend.rag_components.chain_links.rag_with_history import ( construct_rag_with_history, ) +from backend.rag_components.chain_links.retrieve_and_format_text_docs import ( + fetch_docs_chain, +) from backend.utils.llm import get_text_llm from backend.utils.retriever import get_retriever from . import prompts -def _combine_documents(docs: list[Document], document_separator: str = "\n\n") -> str: - r"""_summary_. - - Args: - docs (list[Document]): List of documents. - document_separator (str, optional): _description_. Defaults to "\n\n". 
- - Returns: - str: _description_ - """ - document_prompt = PromptTemplate.from_template(template=prompts.DOCUMENT_TEMPLATE) - doc_strings = [format_document(doc, document_prompt) for doc in docs] - return document_separator.join(doc_strings) - - class Question(BaseModel): """Question to be answered.""" @@ -73,7 +58,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence: # Define the RAG pipeline chain = ( { - "context": retriever | RunnableLambda(_combine_documents), + "context": fetch_docs_chain(retriever), "question": RunnablePassthrough(), } | prompt diff --git a/backend/rag_2/prompts.py b/backend/rag_2/prompts.py index 0679bdf..8a33025 100644 --- a/backend/rag_2/prompts.py +++ b/backend/rag_2/prompts.py @@ -18,16 +18,6 @@ for retrieval. These summaries will be embedded and used to retrieve the raw image. \ Give a concise summary of the image that is well optimized for retrieval.""" -DOCUMENT_TEMPLATE = """ -Document metadata: -- Filename: {filename} -- Page number: {page_number} -Document content: -### -{page_content} -### -""" - RAG_PROMPT = """You will be given a mixed of text, tables, and image summaries usually \ of charts or graphs. Use this information to provide an answer to the user question. 
User-provided question: diff --git a/backend/rag_3/chain.py b/backend/rag_3/chain.py index 266610a..58ed3ad 100644 --- a/backend/rag_3/chain.py +++ b/backend/rag_3/chain.py @@ -1,7 +1,5 @@ """RAG chain for Option 3.""" -from langchain.schema import format_document -from langchain_core.documents import Document from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.output_parsers.string import StrOutputParser from langchain_core.runnables import RunnableLambda, RunnablePassthrough @@ -15,56 +13,19 @@ from backend.rag_components.chain_links.rag_with_history import ( construct_rag_with_history, ) -from backend.utils.image import resize_base64_image +from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import ( + fetch_docs_chain, +) from backend.utils.llm import get_vision_llm from backend.utils.retriever import get_retriever from . import prompts -def split_image_text_types(docs: list[Document]) -> dict[str, list]: - """Split base64-encoded images and texts. - - Args: - docs (list[Document]): List of documents. - - Returns: - dict[str, list]: Dictionary containing lists of images, mime types, and texts. 
- """ - img_base64_list = [] - img_mime_type_list = [] - text_list = [] - for doc in docs: - match doc.metadata["type"]: - case "text": - formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE) - text_list.append(formatted_doc) - case "image": - img = doc.page_content - img = resize_base64_image(img) - img_base64_list.append(resize_base64_image(img)) - img_mime_type_list.append(doc.metadata["mime_type"]) - case "table": - if doc.metadata["format"] == "image": - img = doc.page_content - img = resize_base64_image(img) - img_base64_list.append(img) - img_mime_type_list.append(doc.metadata["mime_type"]) - else: - formatted_doc = format_document(doc, prompts.DOCUMENT_TEMPLATE) - text_list.append(formatted_doc) - - return { - "images": img_base64_list, - "mime_types": img_mime_type_list, - "texts": text_list, - } - - def img_prompt_func( data_dict: dict, document_separator: str = "\n\n" ) -> list[BaseMessage]: - r"""Join the context into a single string. + r"""Join the context into a single string with images and the question. Args: data_dict (dict): Dictionary containing the context and question. @@ -130,7 +91,7 @@ def get_base_chain(config: DictConfig) -> RunnableSequence: # Define the RAG pipeline chain = ( { - "context": retriever | RunnableLambda(split_image_text_types), + "context": fetch_docs_chain(retriever), "question": RunnablePassthrough(), } | RunnableLambda(img_prompt_func) diff --git a/backend/rag_3/prompts.py b/backend/rag_3/prompts.py index a39e744..bda0666 100644 --- a/backend/rag_3/prompts.py +++ b/backend/rag_3/prompts.py @@ -18,15 +18,6 @@ for retrieval. These summaries will be embedded and used to retrieve the raw image. 
def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence:
    """Create a chain that retrieves documents and splits them by modality.

    The retriever output is piped into ``_split_image_text_types``, so the chain
    yields a dictionary with the keys ``images``, ``mime_types`` and ``texts``.
    The chain is annotated with ``Question`` as its input type and ``Documents``
    as its output type.

    Args:
        retriever (BaseRetriever): Retriever that fetches documents.

    Returns:
        RunnableSequence: Langchain sequence.
    """
    relevant_documents = retriever | RunnableLambda(_split_image_text_types)
    return relevant_documents.with_types(input_type=Question, output_type=Documents)


def _split_image_text_types(docs: list[Document]) -> dict[str, list]:
    """Split documents into base64-encoded images and formatted texts.

    Documents of type ``"image"`` and tables whose ``format`` is ``"image"`` are
    resized and collected as images together with their mime type. Documents of
    type ``"text"`` and all other tables are rendered with ``DOCUMENT_PROMPT``.
    Documents of any other type are ignored.

    Args:
        docs (list[Document]): List of documents. Each one must carry a ``type``
            metadata entry; tables also need ``format`` and images ``mime_type``.

    Returns:
        dict[str, list]: Dictionary containing lists of images, mime types, and texts.

    Raises:
        KeyError: If a required metadata entry is missing.
    """
    img_base64_list: list[str] = []
    img_mime_type_list: list[str] = []
    text_list: list[str] = []
    for doc in docs:
        doc_type = doc.metadata["type"]
        # "format" is only looked up for tables, so plain text and image
        # documents do not need that metadata entry.
        is_image = doc_type == "image" or (
            doc_type == "table" and doc.metadata["format"] == "image"
        )
        if is_image:
            # Resize once, the same way for images and image tables.
            img_base64_list.append(resize_base64_image(doc.page_content))
            img_mime_type_list.append(doc.metadata["mime_type"])
        elif doc_type in ("text", "table"):
            text_list.append(format_document(doc, DOCUMENT_PROMPT))
        # Any other document type is skipped on purpose.

    return {
        "images": img_base64_list,
        "mime_types": img_mime_type_list,
        "texts": text_list,
    }
def fetch_docs_chain(retriever: BaseRetriever) -> RunnableSequence:
    """Create a chain that retrieves text documents and merges them into one string.

    The retriever output is piped into ``_combine_documents``, which renders every
    document with ``DOCUMENT_TEMPLATE`` and joins the results. The chain is
    annotated with ``Question`` as its input type and ``Documents`` as its output
    type.

    Args:
        retriever (BaseRetriever): Retriever that fetches documents.

    Returns:
        RunnableSequence: Langchain sequence.
    """
    # NOTE(review): `_combine_documents` returns a plain `str`, while `Documents`
    # declares a `documents: str` field -- confirm the declared output type is the
    # intended one.
    return (retriever | RunnableLambda(_combine_documents)).with_types(
        input_type=Question, output_type=Documents
    )


def _combine_documents(docs: list[Document], document_separator: str = "\n\n") -> str:
    r"""Render each document with the document template and join the results.

    Args:
        docs (list[Document]): List of documents.
        document_separator (str, optional): String placed between two formatted
            documents. Defaults to "\n\n".

    Returns:
        str: All formatted documents as one string.
    """
    prompt = PromptTemplate.from_template(template=DOCUMENT_TEMPLATE)
    return document_separator.join(format_document(doc, prompt) for doc in docs)