✨ add config in chain call
baptiste-pasquier committed May 2, 2024
1 parent bef92dd commit 476759e
Showing 2 changed files with 35 additions and 8 deletions.
backend/rag_components/ingest.py (22 additions, 4 deletions)
@@ -73,7 +73,10 @@ def add_elements_to_multivector_retriever(


 async def apply_summarize_text(
-    text_list: list[Text], config: DictConfig, prompt_template: str
+    text_list: list[Text],
+    config: DictConfig,
+    prompt_template: str,
+    chain_config: dict | None = None,
 ) -> None:
     """Apply text summarization to a list of Text elements.
@@ -83,14 +86,18 @@ async def apply_summarize_text(
         text_list (list[Text]): List of Text elements.
         config (DictConfig): Configuration object.
         prompt_template (str): Prompt template for the summarization.
+        chain_config (dict, optional): Configuration for the chain. Defaults to None.
     """
     if config.ingest.summarize_text:
         str_list = [text.text for text in text_list]

         model = get_text_llm(config)

         text_summaries = await generate_text_summaries(
-            str_list, prompt_template=prompt_template, model=model
+            str_list,
+            prompt_template=prompt_template,
+            model=model,
+            chain_config=chain_config,
         )

         for text in text_list:
@@ -103,7 +110,10 @@ async def apply_summarize_text(


 async def apply_summarize_table(
-    table_list: list[Table], config: DictConfig, prompt_template: str
+    table_list: list[Table],
+    config: DictConfig,
+    prompt_template: str,
+    chain_config: dict | None = None,
 ) -> None:
     """Apply table summarization to a list of Table elements.
@@ -113,6 +123,7 @@ async def apply_summarize_table(
         table_list (list[Table]): List of Table elements.
         config (DictConfig): Configuration object.
         prompt_template (str): Prompt template for the summarization.
+        chain_config (dict, optional): Configuration for the chain. Defaults to None.

     Raises:
         ValueError: If the table format is "image" and summarize_table is False.
@@ -129,6 +140,7 @@ async def apply_summarize_table(
                 str_list,
                 prompt_template=prompt_template,
                 model=model,
+                chain_config=chain_config,
             )
         elif config.ingest.table_format == "image":
             img_base64_list = [table.base64 for table in table_list]
@@ -140,6 +152,7 @@ async def apply_summarize_table(
                 img_mime_type_list,
                 prompt=prompt_template,
                 model=model,
+                chain_config=chain_config,
             )
         else:
             raise ValueError(f"Invalid table format: {table_format}")
@@ -154,7 +167,10 @@ async def apply_summarize_table(


 async def apply_summarize_image(
-    image_list: list[Image], config: DictConfig, prompt_template: str
+    image_list: list[Image],
+    config: DictConfig,
+    prompt_template: str,
+    chain_config: dict | None = None,
 ) -> None:
     """Apply image summarization to a list of Image elements.
@@ -164,6 +180,7 @@ async def apply_summarize_image(
         image_list (list[Image]): List of Image elements.
         config (DictConfig): Configuration object.
         prompt_template (str): Prompt template for the summarization.
+        chain_config (dict, optional): Configuration for the chain. Defaults to None.
     """
     img_base64_list = [image.base64 for image in image_list]
     img_mime_type_list = [image.mime_type for image in image_list]
@@ -175,6 +192,7 @@ async def apply_summarize_image(
         img_mime_type_list,
         prompt=prompt_template,
         model=model,
+        chain_config=chain_config,
     )

     for image in image_list:
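With this change, an ingestion script can pass a LangChain RunnableConfig dict down to the summarization chains. A minimal sketch of a call site (the texts, config, and prompt_template variables and the chosen config keys are illustrative assumptions, not part of this commit):

# Hypothetical call site: forwards a LangChain RunnableConfig dict through the
# new chain_config argument; max_concurrency and tags are standard RunnableConfig keys.
chain_config = {
    "max_concurrency": 5,                # cap parallel LLM calls inside abatch
    "tags": ["ingestion", "summarize"],  # tags picked up by LangChain tracing/callbacks
}

await apply_summarize_text(
    text_list=texts,                     # list[Text] from the parsing step (assumed)
    config=config,                       # the project's DictConfig
    prompt_template=prompt_template,
    chain_config=chain_config,
)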
backend/rag_components/summarization.py (13 additions, 4 deletions)
@@ -31,7 +31,9 @@
     # after=after_log(logger, logging.INFO),
     before_sleep=before_sleep_log(logger, logging.INFO),
 )
-async def abatch_with_retry(chain: Runnable, batch: list) -> list:
+async def abatch_with_retry(
+    chain: Runnable, batch: list, config: dict | None = None
+) -> list:
     """Process a batch of items, applying retries on failure.

     This function is designed to handle exceptions such as rate limits or bad requests
@@ -40,18 +42,20 @@ async def abatch_with_retry(chain: Runnable, batch: list) -> list:
     Args:
         chain (Runnable): Langchain chain.
         batch (list): List of items to be processed by the chain.
+        config (dict, optional): Configuration for the chain. Defaults to None.

     Returns:
         list: List of results.
     """
-    return await chain.abatch(batch)
+    return await chain.abatch(batch, config=config)


 async def generate_text_summaries(
     text_list: list[str],
     prompt_template: str,
     model: BaseChatModel,
     batch_size: int = 10,
+    chain_config: dict | None = None,
 ) -> list[str]:
     """Generate summaries for a list of texts.
@@ -61,6 +65,7 @@ async def generate_text_summaries(
         model (BaseChatModel): Language model used for generating summaries.
         batch_size (int, optional): Number of texts to process simultaneously in the API
             request. Defaults to 50.
+        chain_config (dict, optional): Configuration for the chain. Defaults to None.

     Returns:
         list[str]: List of summaries for the texts.
@@ -81,7 +86,9 @@ async def generate_text_summaries(
     # Process texts in batches
     for i in range(0, len(text_list), batch_size):
         batch = text_list[i : i + batch_size]
-        batch_summaries = await abatch_with_retry(summarize_chain, batch)
+        batch_summaries = await abatch_with_retry(
+            summarize_chain, batch, config=chain_config
+        )
         text_summaries.extend(batch_summaries)

     return text_summaries
@@ -93,6 +100,7 @@ async def generate_image_summaries(
     prompt: str,
     model: BaseChatModel,
     batch_size: int = 10,
+    chain_config: dict | None = None,
 ) -> list[str]:
     """Generate summaries for a list of images encoded in base64.
@@ -104,6 +112,7 @@ async def generate_image_summaries(
         model (BaseChatModel): Language model used for generating summaries.
         batch_size (int, optional): Number of images to process simultaneously in the
             API request. Defaults to 50.
+        chain_config (dict, optional): Configuration for the chain. Defaults to None.

     Returns:
         list[str]: List of summaries for the images.
@@ -150,7 +159,7 @@ def _get_messages_from_url(_dict: dict[str, str]) -> Sequence[BaseMessage]:
                 strict=False,
             )
         ]
-        batch_summaries = await abatch_with_retry(chain, batch)
+        batch_summaries = await abatch_with_retry(chain, batch, config=chain_config)
         image_summaries.extend(batch_summaries)

     return image_summaries
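
Since abatch_with_retry hands the dict straight to Runnable.abatch(batch, config=...), any standard RunnableConfig key applies. A self-contained sketch of that call pattern, using a placeholder chain rather than the repository's own summarize_chain (the model choice and prompt below are assumptions):

# Standalone illustration of Runnable.abatch(config=...); the chain below is a
# stand-in, not the summarize_chain built in summarization.py.
import asyncio

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumes the langchain-openai package is installed

prompt = ChatPromptTemplate.from_template("Summarize the following text:\n\n{text}")
chain = prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()

summaries = asyncio.run(
    chain.abatch(
        [{"text": "First document ..."}, {"text": "Second document ..."}],
        config={"max_concurrency": 2, "metadata": {"source": "demo"}},
    )
)
print(summaries)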
