diff --git a/semantic_chunkers/chunkers/statistical.py b/semantic_chunkers/chunkers/statistical.py
index a6997ba..7cf9f9b 100644
--- a/semantic_chunkers/chunkers/statistical.py
+++ b/semantic_chunkers/chunkers/statistical.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List
+from typing import Any, List
 
 import numpy as np
 
@@ -9,6 +9,8 @@
 from semantic_chunkers.utils.text import tiktoken_length
 from semantic_chunkers.utils.logger import logger
 
+from tqdm.auto import tqdm
+
 
 @dataclass
 class ChunkStatistics:
@@ -62,19 +64,82 @@ def __init__(
         self.enable_statistics = enable_statistics
         self.statistics: ChunkStatistics
 
-    def __call__(self, docs: List[str]) -> List[List[Chunk]]:
-        """Chunk documents into smaller chunks based on semantic similarity.
+    def _chunk(
+        self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
+    ) -> List[Chunk]:
+        """Merge splits into chunks using semantic similarity, with optional enforcement of maximum token limits per chunk.
+
+        :param splits: Splits to be merged into chunks.
+        :param batch_size: Number of splits to process in one batch.
+        :param enforce_max_tokens: If True, further split chunks that exceed the maximum token limit.
+
+        :return: List of chunks.
+        """
+        # Split the docs that already exceed max_split_tokens to smaller chunks
+        if enforce_max_tokens:
+            new_splits = []
+            for split in splits:
+                token_count = tiktoken_length(split)
+                if token_count > self.max_split_tokens:
+                    logger.info(
+                        f"Single document exceeds the maximum token limit "
+                        f"of {self.max_split_tokens}. "
+                        "Splitting to sentences before semantically merging."
+                    )
+                    _splits = self._split(split)
+                    new_splits.extend(_splits)
+                else:
+                    new_splits.append(split)
+
+            splits = [split for split in new_splits if split and split.strip()]
+
+        chunks = []
+        last_split = None
+        for i in tqdm(range(0, len(splits), batch_size)):
+            batch_splits = splits[i : i + batch_size]
+            if last_split is not None:
+                batch_splits = last_split.splits + batch_splits
+
+            encoded_splits = self._encode_documents(batch_splits)
+            similarities = self._calculate_similarity_scores(encoded_splits)
+            if self.dynamic_threshold:
+                self._find_optimal_threshold(batch_splits, similarities)
+            else:
+                self.calculated_threshold = self.encoder.score_threshold
+            split_indices = self._find_split_indices(similarities=similarities)
+            doc_chunks = self._split_documents(
+                batch_splits, split_indices, similarities
+            )
+
+            if len(doc_chunks) > 1:
+                chunks.extend(doc_chunks[:-1])
+                last_split = doc_chunks[-1]
+            else:
+                last_split = doc_chunks[0]
+
+            if self.plot_chunks:
+                self.plot_similarity_scores(similarities, split_indices, doc_chunks)
+
+            if self.enable_statistics:
+                print(self.statistics)
+
+        if last_split:
+            chunks.append(last_split)
+
+        return chunks
+
+    def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
+        """Split documents into smaller chunks based on semantic similarity.
 
         :param docs: list of text documents to be split, if only wanted to
             split a single document, pass it as a list with a single element.
 
-        :return: list of DocumentChunk objects containing the split documents.
+        :return: list of Chunk objects containing the split documents.
         """
         if not docs:
             raise ValueError("At least one document is required for splitting.")
 
         all_chunks = []
-
         for doc in docs:
             token_count = tiktoken_length(doc)
             if token_count > self.max_split_tokens:
@@ -83,23 +148,12 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]:
                     f"of {self.max_split_tokens}. "
                     "Splitting to sentences before semantically merging."
                 )
-            splits = self._split(doc)
-            encoded_splits = self._encode_documents(splits)
-            similarities = self._calculate_similarity_scores(encoded_splits)
-            if self.dynamic_threshold:
-                self._find_optimal_threshold(splits, similarities)
+            if isinstance(doc, str):
+                splits = self._split(doc)
+                doc_chunks = self._chunk(splits, batch_size=batch_size)
+                all_chunks.append(doc_chunks)
             else:
-                self.calculated_threshold = self.encoder.score_threshold
-            split_indices = self._find_split_indices(similarities=similarities)
-            doc_chunks = self._split_documents(splits, split_indices, similarities)
-
-            if self.plot_chunks:
-                self.plot_similarity_scores(similarities, split_indices, doc_chunks)
-
-            if self.enable_statistics:
-                print(self.statistics)
-            all_chunks.append(doc_chunks)
-
+                raise ValueError("The document must be a string.")
         return all_chunks
 
     def _encode_documents(self, docs: List[str]) -> np.ndarray: