Skip to content

Commit

Permalink
refactor convert_documents_to_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 13, 2024
1 parent 66295d5 commit afee458
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 24 deletions.
99 changes: 75 additions & 24 deletions backend/utils/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from langchain_core.documents import Document
from pydantic import BaseModel, PrivateAttr, validator

from .image import local_image_to_base64


class Element(BaseModel):
"""Abstract base class representing an element with a type, format, and metadata.
Expand Down Expand Up @@ -311,6 +313,73 @@ class TableImage(Table, Image):
format: Literal["image"] = "image"


def _create_text_element(doc: Document, element_class: type[Text]) -> Element:
"""Create a text element from a Langchain Document object.
Args:
doc (Document): Langchain Document object.
element_class (Type[Text]): Text element class to create (Text, TableText).
Returns:
Element: Text element created from the Document object.
"""
match doc.metadata["source"]:
case "content":
element = element_class(
type=doc.metadata["type"],
format=doc.metadata["format"],
text=doc.page_content,
metadata=doc.metadata,
)
case "summary":
element = element_class(
type=doc.metadata["type"],
format=doc.metadata["format"],
text="No content available",
metadata=doc.metadata,
)
element.set_summary(doc.page_content)
case other:
raise ValueError(f"Unsupported element source: {other}")
return element


NO_IMAGE = local_image_to_base64("img/no_image.png")


def _create_image_element(doc: Document, element_class: type[Image]) -> Element:
"""Create an image element from a Langchain Document object.
Args:
doc (Document): Langchain Document object.
element_class (Type[Image]): Image element class to create (Image, TableImage).
Returns:
Element: Image element created from the Document object.
"""
match doc.metadata["source"]:
case "content":
element = element_class(
type=doc.metadata["type"],
format=doc.metadata["format"],
base64=doc.page_content,
mime_type=doc.metadata["mime_type"],
metadata=doc.metadata,
)
case "summary":
element = element_class(
type=doc.metadata["type"],
format=doc.metadata["format"],
base64=NO_IMAGE,
mime_type="image/png",
metadata=doc.metadata,
)
element.set_summary(doc.page_content)
case other:
raise ValueError(f"Unsupported element source: {other}")
return element


def convert_documents_to_elements(docs: list[Document]) -> list:
"""Convert a list of Langchain Document objects to a list of Element objects.
Expand All @@ -327,37 +396,19 @@ def convert_documents_to_elements(docs: list[Document]) -> list:
for doc in docs:
match doc.metadata["type"]:
case "text":
text_format = doc.metadata["format"]
element = Text(
text=doc.page_content, format=text_format, metadata=doc.metadata
)
element = _create_text_element(doc, Text)
case "image":
element = Image(
base64=doc.page_content,
mime_type=doc.metadata["mime_type"],
metadata=doc.metadata,
)

element = _create_image_element(doc, Image)
case "table":
table_format = doc.metadata["format"]
match table_format:
match doc.metadata["format"]:
case "text" | "html" | "markdown":
element = TableText(
text=doc.page_content,
format=table_format,
metadata=doc.metadata,
)

element = _create_text_element(doc, TableText)
case "image":
element = TableImage(
base64=doc.page_content,
mime_type=doc.metadata["mime_type"],
metadata=doc.metadata,
)
element = _create_image_element(doc, TableImage)
case other:
raise ValueError(f"Unsupported table format: {other}")
case other:
raise ValueError(f"Unsupported document type: {other['type']}")
raise ValueError(f"Unsupported document type: {other}")
elements.append(element)

return elements
14 changes: 14 additions & 0 deletions backend/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import io
from pathlib import Path

from PIL import Image

Expand Down Expand Up @@ -32,3 +33,16 @@ def resize_base64_image(

# Encode the resized image to Base64
return base64.b64encode(buffered.getvalue()).decode("utf-8")


def local_image_to_base64(image_path: str) -> str:
"""Convert a local image to a Base64 string.
Args:
image_path (str): Path to the image.
Returns:
str: Base64 string.
"""
with Path(image_path).open("rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
Binary file added img/no_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit afee458

Please sign in to comment.