From f929603ea5457ddf8103bf1e79e036f1acf4924c Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 30 Jan 2025 14:34:13 -0800 Subject: [PATCH 1/4] Add more airtable logging --- .../onyx/connectors/airtable/airtable_connector.py | 9 +++++++++ backend/onyx/connectors/connector_runner.py | 13 ++++++++++++- backend/onyx/connectors/models.py | 10 ++++++++++ backend/onyx/indexing/indexing_pipeline.py | 9 +++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index 777f2137fc7..fc534799fe3 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -274,6 +274,11 @@ def _process_record( field_val = fields.get(field_name) field_type = field_schema.type + logger.debug( + f"Processing field '{field_name}' of type '{field_type}' " + f"for record '{record_id}'." + ) + field_sections, field_metadata = self._process_field( field_id=field_schema.id, field_name=field_name, @@ -327,8 +332,12 @@ def load_from_state(self) -> GenerateDocumentsOutput: primary_field_name = field.name break + logger.info(f"Starting to process Airtable records for {table.name}.") + record_documents: list[Document] = [] for record in records: + logger.info(f"Processing record {record['id']} of {table.name}.") + document = self._process_record( record=record, table_schema=table_schema, diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 650aa76b127..f053371ea74 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -1,4 +1,5 @@ import sys +import time from datetime import datetime from onyx.connectors.interfaces import BaseConnector @@ -45,7 +46,17 @@ def __init__( def run(self) -> GenerateDocumentsOutput: """Adds additional exception logging to the connector.""" try: - yield from self.doc_batch_generator 
+ start = time.monotonic() + for batch in self.doc_batch_generator: + # to know how long connector is taking + end = time.monotonic() + logger.debug( + f"Connector tool in {end - start} seconds to build a batch." + ) + start = end + + yield batch + except Exception: exc_type, _, exc_traceback = sys.exc_info() diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index ee66d4b50a9..41123318ada 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -150,6 +150,16 @@ class Document(DocumentBase): id: str # This must be unique or during indexing/reindexing, chunks will be overwritten source: DocumentSource + def get_total_char_length(self) -> int: + """Calculate the total character length of the document including sections, metadata, and identifiers.""" + section_length = sum(len(section.text) for section in self.sections) + identifier_length = len(self.semantic_identifier) + len(self.title or "") + metadata_length = sum( + len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v) + for k, v in self.metadata.items() + ) + return section_length + identifier_length + metadata_length + def to_short_descriptor(self) -> str: """Used when logging the identity of a document""" return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'" diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index ea7228a97cb..e965a5ca4a8 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -380,6 +380,15 @@ def index_doc_batch( new_docs=0, total_docs=len(filtered_documents), total_chunks=0 ) + doc_descriptors = [ + { + "doc_id": doc.id, + "doc_length": doc.get_total_char_length(), + } + for doc in ctx.updatable_docs + ] + logger.debug(f"Starting indexing process for documents: {doc_descriptors}") + logger.debug("Starting chunking") chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs) From 
8a8c048d5d052e96642aab28a3f0f87519034325 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 30 Jan 2025 16:56:45 -0800 Subject: [PATCH 2/4] Add multithreading --- backend/onyx/configs/app_configs.py | 6 +++ .../connectors/airtable/airtable_connector.py | 48 ++++++++++++++----- backend/onyx/connectors/connector_runner.py | 6 +-- .../search_nlp_models.py | 46 ++++++++++++++++-- 4 files changed, 86 insertions(+), 20 deletions(-) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 3235f6127b5..d121d0517ea 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -478,6 +478,12 @@ # 0 disables this behavior and is the default. INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0) +# Enable multi-threaded embedding model calls for parallel processing +# Note: only applies for API-based embedding models +INDEXING_EMBEDDING_MODEL_NUM_THREADS = int( + os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1 +) + # During an indexing attempt, specifies the number of batches which are allowed to # exception without aborting the attempt. INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index fc534799fe3..5fcf8585454 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -1,3 +1,5 @@ +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor from io import BytesIO from typing import Any @@ -312,7 +314,7 @@ def _process_record( def load_from_state(self) -> GenerateDocumentsOutput: """ - Fetch all records from the table. + Fetch all records from the table in parallel batches. NOTE: Airtable does not support filtering by time updated, so we have to fetch all records every time. 
@@ -334,21 +336,43 @@ def load_from_state(self) -> GenerateDocumentsOutput: logger.info(f"Starting to process Airtable records for {table.name}.") - record_documents: list[Document] = [] - for record in records: - logger.info(f"Processing record {record['id']} of {table.name}.") - - document = self._process_record( - record=record, - table_schema=table_schema, - primary_field_name=primary_field_name, - ) - if document: - record_documents.append(document) + # Process records in parallel batches using ThreadPoolExecutor + PARALLEL_BATCH_SIZE = 16 + max_workers = min(PARALLEL_BATCH_SIZE, len(records)) + + # Process records in batches + for i in range(0, len(records), PARALLEL_BATCH_SIZE): + batch_records = records[i : i + PARALLEL_BATCH_SIZE] + record_documents: list[Document] = [] + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit batch tasks + future_to_record = { + executor.submit( + self._process_record, + record=record, + table_schema=table_schema, + primary_field_name=primary_field_name, + ): record + for record in batch_records + } + + # Wait for all tasks in this batch to complete + for future in as_completed(future_to_record): + record = future_to_record[future] + try: + document = future.result() + if document: + record_documents.append(document) + except Exception as e: + logger.exception(f"Failed to process record {record['id']}") + raise e + # After batch is complete, yield if we've hit the batch size if len(record_documents) >= self.batch_size: yield record_documents record_documents = [] + # Yield any remaining records if record_documents: yield record_documents diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index f053371ea74..ffb35f4e64a 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -49,14 +49,14 @@ def run(self) -> GenerateDocumentsOutput: start = time.monotonic() for batch in self.doc_batch_generator: # to 
know how long connector is taking - end = time.monotonic() logger.debug( - f"Connector tool in {end - start} seconds to build a batch." + f"Connector took {time.monotonic() - start} seconds to build a batch." ) - start = end yield batch + start = time.monotonic() + except Exception: exc_type, _, exc_traceback = sys.exc_info() diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index b7e54e81aff..cfeb660cfdf 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -1,6 +1,8 @@ import threading import time from collections.abc import Callable +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor from functools import wraps from typing import Any @@ -11,6 +13,7 @@ from requests import Response from retry import retry +from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS from onyx.configs.app_configs import LARGE_CHUNK_RATIO from onyx.configs.app_configs import SKIP_WARM_UP from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS @@ -155,6 +158,7 @@ def _batch_encode_texts( text_type: EmbedTextType, batch_size: int, max_seq_length: int, + num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS, ) -> list[Embedding]: text_batches = batch_list(texts, batch_size) @@ -163,12 +167,15 @@ def _batch_encode_texts( ) embeddings: list[Embedding] = [] - for idx, text_batch in enumerate(text_batches, start=1): + + def process_batch( + batch_idx: int, text_batch: list[str] + ) -> tuple[int, list[Embedding]]: if self.callback: if self.callback.should_stop(): raise RuntimeError("_batch_encode_texts detected stop signal") - logger.debug(f"Encoding batch {idx} of {len(text_batches)}") + logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}") embed_request = EmbedRequest( model_name=self.model_name, texts=text_batch, @@ -185,10 
+192,39 @@ def _batch_encode_texts( ) response = self._make_model_server_request(embed_request) - embeddings.extend(response.embeddings) + return batch_idx, response.embeddings + + if num_threads >= 1 and self.provider_type and len(text_batches) > 1: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + future_to_batch = { + executor.submit(process_batch, idx, batch): idx + for idx, batch in enumerate(text_batches, start=1) + } + + # Collect results as they complete (re-ordered by batch index below) + batch_results: list[tuple[int, list[Embedding]]] = [] + for future in as_completed(future_to_batch): + try: + result = future.result() + batch_results.append(result) + if self.callback: + self.callback.progress("_batch_encode_texts", 1) + except Exception as e: + logger.exception("Embedding model failed to process batch") + raise e + + # Sort by batch index and extend embeddings + batch_results.sort(key=lambda x: x[0]) + for _, batch_embeddings in batch_results: + embeddings.extend(batch_embeddings) + else: + # Original sequential processing + for idx, text_batch in enumerate(text_batches, start=1): + _, batch_embeddings = process_batch(idx, text_batch) + embeddings.extend(batch_embeddings) + if self.callback: + self.callback.progress("_batch_encode_texts", 1) - if self.callback: - self.callback.progress("_batch_encode_texts", 1) return embeddings def encode( From 5adc56d8830c06d0d8c58b1426d84eb5dc8fb32f Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 30 Jan 2025 17:02:16 -0800 Subject: [PATCH 3/4] Add multithreading --- backend/onyx/connectors/airtable/airtable_connector.py | 2 +- .../onyx/natural_language_processing/search_nlp_models.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index 5fcf8585454..211fe3b44ce 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -314,7 +314,7 @@
def _process_record( def load_from_state(self) -> GenerateDocumentsOutput: """ - Fetch all records from the table in parallel batches. + Fetch all records from the table. NOTE: Airtable does not support filtering by time updated, so we have to fetch all records every time. diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index cfeb660cfdf..c8b473643f5 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -194,6 +194,10 @@ def process_batch( response = self._make_model_server_request(embed_request) return batch_idx, response.embeddings + # only multi thread if: + # 1. num_threads is at least 1 (the check below is `>= 1`) + # 2. we are using an API-based embedding model (provider_type is not None) + # 3. there is more than one batch (no point in threading if only 1) if num_threads >= 1 and self.provider_type and len(text_batches) > 1: with ThreadPoolExecutor(max_workers=num_threads) as executor: future_to_batch = { @@ -292,7 +296,7 @@ def from_db_model( ) -class RerankingModel: +class RerankingModel: # def __init__( self, model_name: str, From cbce73474d48f2d7a7605dae7af7bd235aa990d1 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 30 Jan 2025 17:32:37 -0800 Subject: [PATCH 4/4] Remove empty comment --- backend/onyx/natural_language_processing/search_nlp_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index c8b473643f5..2f6c6d30635 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -296,7 +296,7 @@ def from_db_model( ) -class RerankingModel: # +class RerankingModel: def __init__( self, model_name: str,