Skip to content

Commit

Permalink
ingest refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 13, 2024
1 parent f1efcc4 commit 66295d5
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 102 deletions.
9 changes: 0 additions & 9 deletions backend/rag_1/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,6 @@ def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
metadatas=table_metadata,
)

# Add tables to vectorstore
table_contents = [table.get_content() for table in tables]
table_metadata = [table.get_metadata() for table in tables]

vectorstore.add_texts(
texts=table_contents,
metadatas=table_metadata,
)

# Add images to retriever
image_path = [image.get_local_path() for image in images]
image_metadata = [image.get_metadata() for image in images]
Expand Down
2 changes: 2 additions & 0 deletions backend/rag_1/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import shutil\n",
"import time\n",
"from pathlib import Path\n",
Expand All @@ -83,6 +84,7 @@
"from backend.utils.utils import format_time_delta\n",
"from backend.utils.vectorstore import get_vectorstore\n",
"\n",
"logging.getLogger(\"backend\").setLevel(logging.INFO)\n",
"t = time.time()"
]
},
Expand Down
72 changes: 54 additions & 18 deletions backend/rag_3/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ class PathConfig:
export_extracted: str


@dataclass(config=ConfigDict(extra="forbid"))
class SourceConfig:
"""Configuration for the vectorstore or docstore source."""

text: Literal["content", "summary"]
table: Literal["content", "summary"]
image: Literal["content", "summary"]


@dataclass(config=ConfigDict(extra="forbid"))
class IngestConfig:
"""Configuration for PDF ingestion."""
Expand All @@ -39,36 +48,63 @@ class IngestConfig:
chunking_enable: bool
chunking_func: HydraObject

metadata_keys: list[str]
table_format: Literal["text", "html", "image"]

summarize_text: bool
summarize_table: bool

metadata_keys: list[str]
vectorstore_source: SourceConfig
docstore_source: SourceConfig

export_extracted: bool

@root_validator(pre=False)
def validate_table_format(cls, values: dict) -> dict:
"""Validate the 'table_format' field in relation to 'summarize_table'.
This validator ensures that if the 'table_format' is set to 'image',
then 'summarize_table' must also be set to True. It enforces the rule
that image tables require summarization.
@root_validator(pre=True)
def validate_fields(cls, values: dict) -> dict:
"""Various checks on the fields.
Args:
values (dict): Dictionnary of field values for the IngestConfig class.
Raises:
ValueError: If 'table_format' is 'image' and 'summarize_table' is not True.
values (dict): Field values.
Returns:
dict: The validated field values.
dict: Validated field values.
"""
table_format = values.get("table_format")
summarize_table = values.get("summarize_table")

if table_format == "image" and not summarize_table:
raise ValueError("summarize_table must be True for table_format=image")
table_format = values["table_format"]
summarize_text = values["summarize_text"]
summarize_table = values["summarize_table"]
vectorstore_source = values["vectorstore_source"]
docstore_source = values["docstore_source"]

# Check that summary is enabled when the source is set to "summary"
if vectorstore_source["text"] == "summary" and not summarize_text:
raise ValueError(
"vectorstore_source.text cannot be 'summary' when summarize_text is"
" False"
)
if vectorstore_source["table"] == "summary" and not summarize_table:
raise ValueError(
"vectorstore_source.table cannot be 'summary' when summarize_table is"
" False"
)
if docstore_source["text"] == "summary" and not summarize_text:
raise ValueError(
"docstore_source.text cannot be 'summary' when summarize_text is False"
)
if docstore_source["table"] == "summary" and not summarize_table:
raise ValueError(
"docstore_source.table cannot be 'summary' when summarize_table is"
" False"
)

# Check that the source of vectorstore is not set to "content" when the content
# is an image
if vectorstore_source["image"] == "content":
raise ValueError("vectorstore_source.image cannot be 'content'")
if table_format == "image" and vectorstore_source["table"] == "content":
raise ValueError(
"vectorstore_source.table cannot be 'content' when table_format is"
" 'image'"
)

return values

Expand Down
17 changes: 13 additions & 4 deletions backend/rag_3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,21 @@ ingest:
new_after_n_chars: 3800
combine_text_under_n_chars: 2000

metadata_keys:
- filename
- page_number
table_format: "image" # "text" or "html" or "image"

summarize_text: False
summarize_table: True

export_extracted: True
vectorstore_source: # retrieval step
text: "content"
table: "summary"
image: "summary"
docstore_source: # RAG step
text: "content"
table: "content"
image: "content"

metadata_keys:
- filename
- page_number
export_extracted: True
42 changes: 17 additions & 25 deletions backend/rag_3/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from backend.rag_3 import prompts
from backend.rag_3.config import validate_config
from backend.utils.elements import Image, Table, Text
from backend.utils.ingest import add_elements_to_multivector_retriever
from backend.utils.llm import get_text_llm, get_vision_llm
from backend.utils.retriever import add_documents_multivector, get_retriever
from backend.utils.retriever import get_retriever
from backend.utils.summarization import (
generate_image_summaries,
generate_text_summaries,
Expand Down Expand Up @@ -176,39 +177,30 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
retriever = get_retriever(config)

# Add texts to retriever
text_summaries = [text.get_summary() for text in texts]
text_contents = [text.get_content() for text in texts]
text_metadata = [text.get_metadata() for text in texts]

add_documents_multivector(
logger.info("Adding texts to retriever")
add_elements_to_multivector_retriever(
elements=texts,
retriever=retriever,
summary_list=text_summaries,
content_list=text_contents,
metadata_list=text_metadata,
vectorstore_source=config.ingest.vectorstore_source.text,
docstore_source=config.ingest.docstore_source.text,
)

# Add tables to retriever
table_summaries = [table.get_summary() for table in tables]
table_contents = [table.get_content() for table in tables]
table_metadata = [table.get_metadata() for table in tables]

add_documents_multivector(
logger.info("Adding tables to retriever")
add_elements_to_multivector_retriever(
elements=tables,
retriever=retriever,
summary_list=table_summaries,
content_list=table_contents,
metadata_list=table_metadata,
vectorstore_source=config.ingest.vectorstore_source.table,
docstore_source=config.ingest.docstore_source.table,
)

# Add images to retriever
image_summaries = [image.get_summary() for image in images]
image_contents = [image.get_content() for image in images]
image_metadata = [image.get_metadata() for image in images]

add_documents_multivector(
logger.info("Adding images to retriever")
add_elements_to_multivector_retriever(
elements=images,
retriever=retriever,
summary_list=image_summaries,
content_list=image_contents,
metadata_list=image_metadata,
vectorstore_source=config.ingest.vectorstore_source.image,
docstore_source=config.ingest.docstore_source.image,
)

# Export extracted elements
Expand Down
41 changes: 16 additions & 25 deletions backend/rag_3/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import shutil\n",
"import time\n",
"from pathlib import Path\n",
Expand All @@ -79,7 +80,8 @@
" apply_summarize_text,\n",
")\n",
"from backend.utils.elements import convert_documents_to_elements\n",
"from backend.utils.retriever import add_documents_multivector, get_retriever\n",
"from backend.utils.ingest import add_elements_to_multivector_retriever\n",
"from backend.utils.retriever import get_retriever\n",
"from backend.utils.unstructured import (\n",
" load_chunking_func,\n",
" select_images,\n",
Expand All @@ -88,6 +90,7 @@
")\n",
"from backend.utils.utils import format_time_delta\n",
"\n",
"logging.getLogger(\"backend\").setLevel(logging.INFO)\n",
"t = time.time()"
]
},
Expand Down Expand Up @@ -343,15 +346,11 @@
"outputs": [],
"source": [
"# Add texts to retriever\n",
"text_summaries = [text.get_summary() for text in texts]\n",
"text_contents = [text.get_content() for text in texts]\n",
"text_metadata = [text.get_metadata() for text in texts]\n",
"\n",
"add_documents_multivector(\n",
"add_elements_to_multivector_retriever(\n",
" elements=texts,\n",
" retriever=retriever,\n",
" summary_list=text_summaries,\n",
" content_list=text_contents,\n",
" metadata_list=text_metadata,\n",
" vectorstore_source=config.ingest.vectorstore_source.text,\n",
" docstore_source=config.ingest.docstore_source.text,\n",
")"
]
},
Expand All @@ -362,15 +361,11 @@
"outputs": [],
"source": [
"# Add tables to retriever\n",
"table_summaries = [table.get_summary() for table in tables]\n",
"table_contents = [table.get_content() for table in tables]\n",
"table_metadata = [table.get_metadata() for table in tables]\n",
"\n",
"add_documents_multivector(\n",
"add_elements_to_multivector_retriever(\n",
" elements=tables,\n",
" retriever=retriever,\n",
" summary_list=table_summaries,\n",
" content_list=table_contents,\n",
" metadata_list=table_metadata,\n",
" vectorstore_source=config.ingest.vectorstore_source.table,\n",
" docstore_source=config.ingest.docstore_source.table,\n",
")"
]
},
Expand All @@ -381,15 +376,11 @@
"outputs": [],
"source": [
"# Add images to retriever\n",
"image_summaries = [image.get_summary() for image in images]\n",
"image_contents = [image.get_content() for image in images]\n",
"image_metadata = [image.get_metadata() for image in images]\n",
"\n",
"add_documents_multivector(\n",
"add_elements_to_multivector_retriever(\n",
" elements=images,\n",
" retriever=retriever,\n",
" summary_list=image_summaries,\n",
" content_list=image_contents,\n",
" metadata_list=image_metadata,\n",
" vectorstore_source=config.ingest.vectorstore_source.image,\n",
" docstore_source=config.ingest.docstore_source.image,\n",
")"
]
},
Expand Down
66 changes: 66 additions & 0 deletions backend/utils/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Ingest utility functions."""

import logging
from collections.abc import Sequence

from langchain.retrievers.multi_vector import MultiVectorRetriever

from .elements import Element
from .retriever import add_documents_multivector

logger = logging.getLogger(__name__)


def get_attr_from_elements(elements: Sequence[Element], attr: str) -> list:
"""Get a specific attribute from a list of elements.
Args:
elements (list[Element]): List of elements.
attr (str): Attribute to get from the elements.
Raises:
ValueError: If the attribute is not supported.
Returns:
list: List of the specified attribute from the elements.
"""
match attr:
case "content":
return [element.get_content() for element in elements]
case "summary":
return [element.get_summary() for element in elements]
case "metadata":
return [element.get_metadata() for element in elements]
case other:
raise ValueError(f"Unsupported attribute: {other}")


def add_elements_to_multivector_retriever(
elements: Sequence[Element],
retriever: MultiVectorRetriever,
vectorstore_source: str,
docstore_source: str,
) -> None:
"""Add a list of elements to the multi-vector retriever.
Args:
elements (Sequence[Element]): List of elements to add.
retriever (MultiVectorRetriever): Multi-vector retriever.
vectorstore_source (str): Attribute of the elements to add to the vectorstore.
docstore_source (str): Attribute of the elements to add to the docstore.
"""
vectorstore_content = get_attr_from_elements(elements, vectorstore_source)
docstore_content = get_attr_from_elements(elements, docstore_source)
metadata_list = get_attr_from_elements(elements, "metadata")

logging.info(f"Adding {vectorstore_source} to vectorstore.")
logging.info(f"Adding {docstore_source} to docstore.")

add_documents_multivector(
retriever=retriever,
vectorstore_content=vectorstore_content,
docstore_content=docstore_content,
metadata_list=metadata_list,
vectorstore_source=vectorstore_source,
docstore_source=docstore_source,
)
Loading

0 comments on commit 66295d5

Please sign in to comment.