Skip to content

Commit

Permalink
♻️ refactor apply summarization
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Mar 25, 2024
1 parent 77db4e8 commit fdbaca0
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 237 deletions.
133 changes: 21 additions & 112 deletions backend/rag_2/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@

from backend.rag_2 import prompts
from backend.rag_2.config import validate_config
from backend.rag_components.elements import Image, Table, Text
from backend.rag_components.ingest import add_elements_to_multivector_retriever
from backend.rag_components.llm import get_text_llm, get_vision_llm
from backend.rag_components.retriever import get_retriever
from backend.rag_components.summarization import (
generate_image_summaries,
generate_text_summaries,
from backend.rag_components.ingest import (
add_elements_to_multivector_retriever,
apply_summarize_image,
apply_summarize_table,
apply_summarize_text,
)
from backend.rag_components.retriever import get_retriever
from backend.rag_components.unstructured import (
load_chunking_func,
select_images,
Expand All @@ -30,108 +29,6 @@
logger = logging.getLogger(__name__)


async def apply_summarize_text(text_list: list[Text], config: DictConfig) -> None:
    """Apply text summarization to a list of Text elements.

    The function directly modifies the Text elements inplace. Summarization is
    skipped (and logged) when ``config.ingest.summarize_text`` is falsy.

    Args:
        text_list (list[Text]): List of Text elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries does not match the
            number of Text elements.
    """
    if not config.ingest.summarize_text:
        logger.info("Skipping text summarization")
        return

    if not text_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    text_summaries = await generate_text_summaries(
        [text.text for text in text_list],
        prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,
        model=get_text_llm(config),
    )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(text_summaries) != len(text_list):
        raise ValueError(
            f"Expected {len(text_list)} text summaries, got {len(text_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for text, summary in zip(text_list, text_summaries):
        text.set_summary(summary)


async def apply_summarize_table(table_list: list[Table], config: DictConfig) -> None:
    """Apply table summarization to a list of Table elements.

    The function directly modifies the Table elements inplace. Summarization is
    skipped (and logged) when ``config.ingest.summarize_table`` is falsy; in that
    case the table format is not validated.

    Tables in "text" or "html" format are summarized from their ``text``
    attribute with the text LLM; tables in "image" format are summarized from
    their ``base64`` / ``mime_type`` attributes with the vision LLM.

    Args:
        table_list (list[Table]): List of Table elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the table format is invalid.
        ValueError: If the number of generated summaries does not match the
            number of Table elements.
    """
    if not config.ingest.summarize_table:
        logger.info("Skipping table summarization")
        return

    table_format = config.ingest.table_format
    if table_format not in ("text", "html", "image"):
        raise ValueError(f"Invalid table format: {table_format}")

    if not table_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    if table_format == "image":
        table_summaries = await generate_image_summaries(
            [table.base64 for table in table_list],
            [table.mime_type for table in table_list],
            prompt=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=get_vision_llm(config),
        )
    else:
        # "text" and "html" tables are both summarized from their text content.
        table_summaries = await generate_text_summaries(
            [table.text for table in table_list],
            prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=get_text_llm(config),
        )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(table_summaries) != len(table_list):
        raise ValueError(
            f"Expected {len(table_list)} table summaries, got {len(table_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for table, summary in zip(table_list, table_summaries):
        table.set_summary(summary)


async def apply_summarize_image(image_list: list[Image], config: DictConfig) -> None:
    """Apply image summarization to a list of Image elements.

    The function directly modifies the Image elements inplace. Each image is
    summarized from its ``base64`` and ``mime_type`` attributes with the vision
    LLM.

    Args:
        image_list (list[Image]): List of Image elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries does not match the
            number of Image elements.
    """
    if not image_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    image_summaries = await generate_image_summaries(
        [image.base64 for image in image_list],
        [image.mime_type for image in image_list],
        prompt=prompts.IMAGE_SUMMARIZATION_PROMPT,
        model=get_vision_llm(config),
    )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(image_summaries) != len(image_list):
        raise ValueError(
            f"Expected {len(image_list)} image summaries, got {len(image_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for image, summary in zip(image_list, image_summaries):
        image.set_summary(summary)


async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
"""Ingest a PDF file.
Expand Down Expand Up @@ -173,13 +70,25 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
)

# Summarize text
await apply_summarize_text(texts, config)
await apply_summarize_text(
text_list=texts,
config=config,
prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,
)

# Summarize tables
await apply_summarize_table(tables, config)
await apply_summarize_table(
table_list=tables,
config=config,
prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
)

# Summarize images
await apply_summarize_image(images, config)
await apply_summarize_image(
image_list=images,
config=config,
prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,
)

retriever = get_retriever(config)

Expand Down
25 changes: 19 additions & 6 deletions backend/rag_2/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,16 @@
"from hydra import compose, initialize\n",
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"from backend.rag_2 import prompts\n",
"from backend.rag_2.chain import get_chain\n",
"from backend.rag_2.config import validate_config\n",
"from backend.rag_2.ingest import (\n",
"from backend.rag_components.elements import convert_documents_to_elements\n",
"from backend.rag_components.ingest import (\n",
" add_elements_to_multivector_retriever,\n",
" apply_summarize_image,\n",
" apply_summarize_table,\n",
" apply_summarize_text,\n",
")\n",
"from backend.rag_components.elements import convert_documents_to_elements\n",
"from backend.rag_components.ingest import add_elements_to_multivector_retriever\n",
"from backend.rag_components.retriever import get_retriever\n",
"from backend.rag_components.unstructured import (\n",
" load_chunking_func,\n",
Expand Down Expand Up @@ -295,7 +296,11 @@
"outputs": [],
"source": [
"# Summarize text\n",
"await apply_summarize_text(texts, config)\n",
"await apply_summarize_text(\n",
" text_list=texts,\n",
" config=config,\n",
" prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,\n",
")\n",
"for text in texts[:N_DISPLAY]:\n",
" display(text)"
]
Expand All @@ -307,7 +312,11 @@
"outputs": [],
"source": [
"# Summarize tables\n",
"await apply_summarize_table(tables, config)\n",
"await apply_summarize_table(\n",
" table_list=tables,\n",
" config=config,\n",
" prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,\n",
")\n",
"for table in tables[:N_DISPLAY]:\n",
" display(table)"
]
Expand All @@ -319,7 +328,11 @@
"outputs": [],
"source": [
"# Summarize images\n",
"await apply_summarize_image(images, config)\n",
"await apply_summarize_image(\n",
" image_list=images,\n",
" config=config,\n",
" prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,\n",
")\n",
"for image in images[:N_DISPLAY]:\n",
" display(image)"
]
Expand Down
133 changes: 21 additions & 112 deletions backend/rag_3/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@

from backend.rag_3 import prompts
from backend.rag_3.config import validate_config
from backend.rag_components.elements import Image, Table, Text
from backend.rag_components.ingest import add_elements_to_multivector_retriever
from backend.rag_components.llm import get_text_llm, get_vision_llm
from backend.rag_components.retriever import get_retriever
from backend.rag_components.summarization import (
generate_image_summaries,
generate_text_summaries,
from backend.rag_components.ingest import (
add_elements_to_multivector_retriever,
apply_summarize_image,
apply_summarize_table,
apply_summarize_text,
)
from backend.rag_components.retriever import get_retriever
from backend.rag_components.unstructured import (
load_chunking_func,
select_images,
Expand All @@ -30,108 +29,6 @@
logger = logging.getLogger(__name__)


async def apply_summarize_text(text_list: list[Text], config: DictConfig) -> None:
    """Apply text summarization to a list of Text elements.

    The function directly modifies the Text elements inplace. Summarization is
    skipped (and logged) when ``config.ingest.summarize_text`` is falsy.

    Args:
        text_list (list[Text]): List of Text elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries does not match the
            number of Text elements.
    """
    if not config.ingest.summarize_text:
        logger.info("Skipping text summarization")
        return

    if not text_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    text_summaries = await generate_text_summaries(
        [text.text for text in text_list],
        prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,
        model=get_text_llm(config),
    )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(text_summaries) != len(text_list):
        raise ValueError(
            f"Expected {len(text_list)} text summaries, got {len(text_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for text, summary in zip(text_list, text_summaries):
        text.set_summary(summary)


async def apply_summarize_table(table_list: list[Table], config: DictConfig) -> None:
    """Apply table summarization to a list of Table elements.

    The function directly modifies the Table elements inplace. Summarization is
    skipped (and logged) when ``config.ingest.summarize_table`` is falsy; in that
    case the table format is not validated.

    Tables in "text" or "html" format are summarized from their ``text``
    attribute with the text LLM; tables in "image" format are summarized from
    their ``base64`` / ``mime_type`` attributes with the vision LLM.

    Args:
        table_list (list[Table]): List of Table elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the table format is invalid.
        ValueError: If the number of generated summaries does not match the
            number of Table elements.
    """
    if not config.ingest.summarize_table:
        logger.info("Skipping table summarization")
        return

    table_format = config.ingest.table_format
    if table_format not in ("text", "html", "image"):
        raise ValueError(f"Invalid table format: {table_format}")

    if not table_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    if table_format == "image":
        table_summaries = await generate_image_summaries(
            [table.base64 for table in table_list],
            [table.mime_type for table in table_list],
            prompt=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=get_vision_llm(config),
        )
    else:
        # "text" and "html" tables are both summarized from their text content.
        table_summaries = await generate_text_summaries(
            [table.text for table in table_list],
            prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=get_text_llm(config),
        )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(table_summaries) != len(table_list):
        raise ValueError(
            f"Expected {len(table_list)} table summaries, got {len(table_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for table, summary in zip(table_list, table_summaries):
        table.set_summary(summary)


async def apply_summarize_image(image_list: list[Image], config: DictConfig) -> None:
    """Apply image summarization to a list of Image elements.

    The function directly modifies the Image elements inplace. Each image is
    summarized from its ``base64`` and ``mime_type`` attributes with the vision
    LLM.

    Args:
        image_list (list[Image]): List of Image elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries does not match the
            number of Image elements.
    """
    if not image_list:
        # Nothing to summarize: skip the model setup and summarization call.
        return

    image_summaries = await generate_image_summaries(
        [image.base64 for image in image_list],
        [image.mime_type for image in image_list],
        prompt=prompts.IMAGE_SUMMARIZATION_PROMPT,
        model=get_vision_llm(config),
    )

    # Check the count up front so a mismatch never leaves the list half-updated.
    if len(image_summaries) != len(image_list):
        raise ValueError(
            f"Expected {len(image_list)} image summaries, got {len(image_summaries)}"
        )

    # Pair elements with summaries in order; avoids the O(n^2) ``pop(0)`` loop
    # and leaves the returned list untouched.
    for image, summary in zip(image_list, image_summaries):
        image.set_summary(summary)


async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
"""Ingest a PDF file.
Expand Down Expand Up @@ -173,13 +70,25 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
)

# Summarize text
await apply_summarize_text(texts, config)
await apply_summarize_text(
text_list=texts,
config=config,
prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,
)

# Summarize tables
await apply_summarize_table(tables, config)
await apply_summarize_table(
table_list=tables,
config=config,
prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
)

# Summarize images
await apply_summarize_image(images, config)
await apply_summarize_image(
image_list=images,
config=config,
prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,
)

retriever = get_retriever(config)

Expand Down
Loading

0 comments on commit fdbaca0

Please sign in to comment.