feat: Process large docs in batches #7

Merged · 3 commits · May 28, 2024
96 changes: 75 additions & 21 deletions semantic_chunkers/chunkers/statistical.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import List
+from typing import Any, List

import numpy as np

@@ -9,6 +9,8 @@
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger

+from tqdm.auto import tqdm


@dataclass
class ChunkStatistics:
@@ -62,19 +64,82 @@ def __init__(
        self.enable_statistics = enable_statistics
        self.statistics: ChunkStatistics

-    def __call__(self, docs: List[str]) -> List[List[Chunk]]:
-        """Chunk documents into smaller chunks based on semantic similarity.
+    def _chunk(
+        self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
+    ) -> List[Chunk]:
+        """Merge splits into chunks using semantic similarity, with optional enforcement of maximum token limits per chunk.
+
+        :param splits: Splits to be merged into chunks.
+        :param batch_size: Number of splits to process in one batch.
+        :param enforce_max_tokens: If True, further split chunks that exceed the maximum token limit.
+
+        :return: List of chunks.
+        """
+        # Split the docs that already exceed max_split_tokens to smaller chunks
+        if enforce_max_tokens:
+            new_splits = []
+            for split in splits:
+                token_count = tiktoken_length(split)
+                if token_count > self.max_split_tokens:
+                    logger.info(
+                        f"Single document exceeds the maximum token limit "
+                        f"of {self.max_split_tokens}. "
+                        "Splitting to sentences before semantically merging."
+                    )
+                    _splits = self._split(split)
+                    new_splits.extend(_splits)
+                else:
+                    new_splits.append(split)
+
+            splits = [split for split in new_splits if split and split.strip()]
+
+        chunks = []
+        last_split = None
+        for i in tqdm(range(0, len(splits), batch_size)):
+            batch_splits = splits[i : i + batch_size]
+            if last_split is not None:
+                batch_splits = last_split.splits + batch_splits
+
+            encoded_splits = self._encode_documents(batch_splits)
+            similarities = self._calculate_similarity_scores(encoded_splits)
+            if self.dynamic_threshold:
+                self._find_optimal_threshold(batch_splits, similarities)
+            else:
+                self.calculated_threshold = self.encoder.score_threshold
+            split_indices = self._find_split_indices(similarities=similarities)
+            doc_chunks = self._split_documents(
+                batch_splits, split_indices, similarities
+            )
+
+            if len(doc_chunks) > 1:
+                chunks.extend(doc_chunks[:-1])
+                last_split = doc_chunks[-1]
+            else:
+                last_split = doc_chunks[0]
+
+            if self.plot_chunks:
+                self.plot_similarity_scores(similarities, split_indices, doc_chunks)
+
+            if self.enable_statistics:
+                print(self.statistics)
+
+        if last_split:
+            chunks.append(last_split)
+
+        return chunks

+    def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
+        """Split documents into smaller chunks based on semantic similarity.

        :param docs: list of text documents to be split; if you only want to
            split a single document, pass it as a list with a single element.

-        :return: list of DocumentChunk objects containing the split documents.
+        :return: list of Chunk objects containing the split documents.
        """
        if not docs:
            raise ValueError("At least one document is required for splitting.")

        all_chunks = []

        for doc in docs:
            token_count = tiktoken_length(doc)
            if token_count > self.max_split_tokens:
@@ -83,23 +148,12 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]:
f"of {self.max_split_tokens}. "
"Splitting to sentences before semantically merging."
)
splits = self._split(doc)
encoded_splits = self._encode_documents(splits)
similarities = self._calculate_similarity_scores(encoded_splits)
if self.dynamic_threshold:
self._find_optimal_threshold(splits, similarities)
if isinstance(doc, str):
splits = self._split(doc)
doc_chunks = self._chunk(splits, batch_size=batch_size)
all_chunks.append(doc_chunks)
else:
self.calculated_threshold = self.encoder.score_threshold
split_indices = self._find_split_indices(similarities=similarities)
doc_chunks = self._split_documents(splits, split_indices, similarities)

if self.plot_chunks:
self.plot_similarity_scores(similarities, split_indices, doc_chunks)

if self.enable_statistics:
print(self.statistics)
all_chunks.append(doc_chunks)

raise ValueError("The document must be a string.")
return all_chunks

    def _encode_documents(self, docs: List[str]) -> np.ndarray:
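
For readers who want to see how the batched interface introduced here is meant to be called, below is a minimal usage sketch. Only the batch_size parameter, the list-of-docs in / List[List[Chunk]] out call signature, and the Chunk.splits attribute come from the diff above; the import paths, the StatisticalChunker constructor arguments, and the OpenAIEncoder are assumptions about the surrounding library and may differ.

# Usage sketch only: import paths, encoder choice, and constructor arguments are
# assumed; batch_size and the call/return shapes are taken from the diff above.
from semantic_router.encoders import OpenAIEncoder  # assumed encoder class
from semantic_chunkers import StatisticalChunker    # assumed import path

encoder = OpenAIEncoder()                      # any encoder exposing score_threshold
chunker = StatisticalChunker(encoder=encoder)  # assumed constructor signature

with open("large_document.txt") as f:
    doc = f.read()

# __call__ now accepts batch_size: the document is split into sentences, then the
# splits are encoded and merged 64 at a time, with the trailing chunk of each
# batch carried into the next batch so chunk boundaries can cross batch edges.
chunks_per_doc = chunker(docs=[doc], batch_size=64)

for chunk in chunks_per_doc[0]:
    print(chunk.splits)  # each Chunk holds the splits that were merged into it

The default batch_size of 64 matches the defaults in both _chunk and __call__ above; a larger value means fewer encoding passes per document at the cost of a larger similarity computation per batch.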