Skip to content

Commit

Permalink
refactor convert_documents_to_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 13, 2024
1 parent 5bd2cdb commit 89dbbbf
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 26 deletions.
95 changes: 71 additions & 24 deletions backend/utils/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from langchain_core.documents import Document
from pydantic import BaseModel, PrivateAttr, validator

from .image import local_image_to_base64


class Element(BaseModel):
"""Abstract base class representing an element with a type, format, and metadata.
Expand Down Expand Up @@ -311,6 +313,69 @@ class TableImage(Table, Image):
format: Literal["image"] = "image"


def _create_text_element(doc: Document, element_class: type[Text]) -> Element:
    """Build a text-based element out of a Langchain Document.

    When the document metadata flags it as a summary (``is_summary`` is
    truthy), the element is created with a placeholder text and the page
    content is attached through ``set_summary``. Otherwise the page content
    becomes the element text.

    Args:
        doc (Document): Langchain Document to convert.
        element_class (Type[Text]): Class to instantiate (Text, TableText).

    Returns:
        Element: The newly created text element.
    """
    meta = doc.metadata
    summarized = meta.get("is_summary", False)

    # A summary document carries the summary in page_content, so the element
    # text itself is only a placeholder.
    body = "No content available" if summarized else doc.page_content

    created = element_class(
        type=meta["type"],
        format=meta["format"],
        text=body,
        metadata=meta,
    )
    if summarized:
        created.set_summary(doc.page_content)
    return created


# Base64-encoded placeholder picture, used as the image payload of elements
# built from summary documents.
# NOTE: the file is read from disk when this module is imported, and the
# relative path "img/no_image.png" is resolved against the current working
# directory, so the import fails if the file is not reachable from there.
NO_IMAGE = local_image_to_base64("img/no_image.png")


def _create_image_element(doc: Document, element_class: type[Image]) -> Element:
    """Build an image-based element out of a Langchain Document.

    When the document metadata flags it as a summary (``is_summary`` is
    truthy), the element is created with the ``NO_IMAGE`` placeholder as a
    PNG and the page content is attached through ``set_summary``. Otherwise
    the page content is used as the Base64 payload, with the MIME type read
    from the metadata.

    Args:
        doc (Document): Langchain Document to convert.
        element_class (Type[Image]): Class to instantiate (Image, TableImage).

    Returns:
        Element: The newly created image element.
    """
    meta = doc.metadata
    summarized = meta.get("is_summary", False)

    # Keyword arguments that are identical for both kinds of document.
    shared = {
        "type": meta["type"],
        "format": meta["format"],
        "metadata": meta,
    }

    if not summarized:
        return element_class(
            base64=doc.page_content,
            mime_type=meta["mime_type"],
            **shared,
        )

    # A summary document has no picture of its own: use the placeholder and
    # keep the summary text alongside it.
    placeholder = element_class(
        base64=NO_IMAGE,
        mime_type="image/png",
        **shared,
    )
    placeholder.set_summary(doc.page_content)
    return placeholder


def convert_documents_to_elements(docs: list[Document]) -> list:
"""Convert a list of Langchain Document objects to a list of Element objects.
Expand All @@ -327,37 +392,19 @@ def convert_documents_to_elements(docs: list[Document]) -> list:
for doc in docs:
match doc.metadata["type"]:
case "text":
text_format = doc.metadata["format"]
element = Text(
text=doc.page_content, format=text_format, metadata=doc.metadata
)
element = _create_text_element(doc, Text)
case "image":
element = Image(
base64=doc.page_content,
mime_type=doc.metadata["mime_type"],
metadata=doc.metadata,
)

element = _create_image_element(doc, Image)
case "table":
table_format = doc.metadata["format"]
match table_format:
match doc.metadata["format"]:
case "text" | "html" | "markdown":
element = TableText(
text=doc.page_content,
format=table_format,
metadata=doc.metadata,
)

element = _create_text_element(doc, TableText)
case "image":
element = TableImage(
base64=doc.page_content,
mime_type=doc.metadata["mime_type"],
metadata=doc.metadata,
)
element = _create_image_element(doc, TableImage)
case other:
raise ValueError(f"Unsupported table format: {other}")
case other:
raise ValueError(f"Unsupported document type: {other['type']}")
raise ValueError(f"Unsupported document type: {other}")
elements.append(element)

return elements
14 changes: 14 additions & 0 deletions backend/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import io
from pathlib import Path

from PIL import Image

Expand Down Expand Up @@ -32,3 +33,16 @@ def resize_base64_image(

# Encode the resized image to Base64
return base64.b64encode(buffered.getvalue()).decode("utf-8")


def local_image_to_base64(image_path: str) -> str:
"""Convert a local image to a Base64 string.
Args:
image_path (str): Path to the image.
Returns:
str: Base64 string.
"""
with Path(image_path).open("rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
4 changes: 2 additions & 2 deletions backend/utils/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def add_documents_multivector(
summary_docs = [
Document(
page_content=s,
metadata=metadata_list[i],
metadata={**metadata_list[i], "is_summary": True},
)
for i, s in enumerate(summary_list)
]
Expand All @@ -88,7 +88,7 @@ def add_documents_multivector(
content_docs = [
Document(
page_content=c if isinstance(c, str) else c.page_content,
metadata=metadata_list[i],
metadata={**metadata_list[i], "is_summary": False},
)
for i, c in enumerate(content_list)
]
Expand Down
Binary file added img/no_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 89dbbbf

Please sign in to comment.