diff --git a/libs/colbert/ragstack_colbert/__init__.py b/libs/colbert/ragstack_colbert/__init__.py index 0ed4fc0a4..a267af4bc 100644 --- a/libs/colbert/ragstack_colbert/__init__.py +++ b/libs/colbert/ragstack_colbert/__init__.py @@ -1,4 +1,5 @@ -""" +"""Ragstack Colbert: A ColBERT-based text retrieval system. + This package provides a suite of tools for encoding and retrieving text using the ColBERT model, integrated with a Cassandra database for scalable storage and retrieval operations. It includes classes for token embeddings, managing the vector store, and diff --git a/libs/colbert/ragstack_colbert/base_database.py b/libs/colbert/ragstack_colbert/base_database.py index 7c5282ac6..3371a26a7 100644 --- a/libs/colbert/ragstack_colbert/base_database.py +++ b/libs/colbert/ragstack_colbert/base_database.py @@ -1,6 +1,7 @@ -""" -This module defines abstract base classes for implementing storage mechanisms for text -chunk embeddings, specifically designed to work with ColBERT or similar embedding +"""Base Database module. + +This module defines abstract base classes for implementing storage mechanisms for +text chunk embeddings, specifically designed to work with ColBERT or similar embedding models. """ @@ -11,7 +12,8 @@ class BaseDatabase(ABC): - """ + """Base Database abstract class for ColBERT. + Abstract base class (ABC) for a storage system designed to hold vector representations of text chunks, typically generated by a ColBERT model or similar embedding model. @@ -23,10 +25,9 @@ class BaseDatabase(ABC): @abstractmethod def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. Returns: @@ -35,10 +36,9 @@ def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: @abstractmethod def delete_chunks(self, doc_ids: List[str]) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. @@ -50,10 +50,9 @@ def delete_chunks(self, doc_ids: List[str]) -> bool: async def aadd_chunks( self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100 ) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. concurrent_inserts (Optional[int]): How many concurrent inserts to make to the database. Defaults to 100. @@ -66,10 +65,9 @@ async def aadd_chunks( async def adelete_chunks( self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100 ) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. concurrent_deletes (Optional[int]): How many concurrent deletes to make @@ -81,8 +79,7 @@ async def adelete_chunks( @abstractmethod async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: - """ - Retrieves 'n' ANN results for an embedded token vector. + """Retrieves 'n' ANN results for an embedded token vector. Returns: A list of Chunks with only `doc_id` and `chunk_id` set. @@ -91,8 +88,7 @@ async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: @abstractmethod async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: - """ - Retrieve the embedding data for a chunk. + """Retrieve the embedding data for a chunk. Returns: A chunk with `doc_id`, `chunk_id`, and `embedding` set. @@ -102,8 +98,7 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: async def get_chunk_data( self, doc_id: str, chunk_id: int, include_embedding: Optional[bool] ) -> Chunk: - """ - Retrieve the text and metadata for a chunk. + """Retrieve the text and metadata for a chunk. Returns: A chunk with `doc_id`, `chunk_id`, `text`, `metadata`, and optionally @@ -112,6 +107,4 @@ async def get_chunk_data( @abstractmethod def close(self) -> None: - """ - Cleans up any open resources. - """ + """Cleans up any open resources.""" diff --git a/libs/colbert/ragstack_colbert/base_embedding_model.py b/libs/colbert/ragstack_colbert/base_embedding_model.py index 8a2fbc342..afe7c54f1 100644 --- a/libs/colbert/ragstack_colbert/base_embedding_model.py +++ b/libs/colbert/ragstack_colbert/base_embedding_model.py @@ -1,6 +1,7 @@ -""" -This module defines an abstract base class (ABC) for generating token-based embeddings -for text. +"""Base embedding for ColBERT. + +This module defines an abstract base class (ABC) for generating token-based +embeddings for text. """ from abc import ABC, abstractmethod @@ -10,8 +11,7 @@ class BaseEmbeddingModel(ABC): - """ - Abstract base class (ABC) for token-based embedding models. + """Abstract base class (ABC) for token-based embedding models. This class defines the interface for models that generate embeddings for text chunks and queries. @@ -22,11 +22,9 @@ class BaseEmbeddingModel(ABC): @abstractmethod def embed_texts(self, texts: List[str]) -> List[Embedding]: - """ - Embeds a list of texts into their corresponding vector embedding - representations. + """Embeds a list of texts into their vector embedding representations. - Parameters: + Args: texts (List[str]): A list of string texts. Returns: @@ -40,13 +38,12 @@ def embed_query( full_length_search: Optional[bool] = False, query_maxlen: int = -1, ) -> Embedding: - """ - Embeds a single query text into its vector representation. + """Embeds a single query text into its vector representation. If the query has fewer than query_maxlen tokens it will be padded with BERT special [mast] tokens. - Parameters: + Args: query (str): The query text to encode. full_length_search (Optional[bool]): Indicates whether to encode the query for a full-length search. Defaults to False. diff --git a/libs/colbert/ragstack_colbert/base_retriever.py b/libs/colbert/ragstack_colbert/base_retriever.py index 6c19726b0..ff50383af 100644 --- a/libs/colbert/ragstack_colbert/base_retriever.py +++ b/libs/colbert/ragstack_colbert/base_retriever.py @@ -1,4 +1,5 @@ -""" +"""Base retriever module. + This module defines abstract base classes for implementing retrieval mechanisms for text chunk embeddings, specifically designed to work with ColBERT or similar embedding models. @@ -11,9 +12,10 @@ class BaseRetriever(ABC): - """ - Abstract base class (ABC) for a retrieval system that operates on a ColBERT vector - store, facilitating the search and retrieval of text chunks based on query + """Base Retriever abstract class for ColBERT. + + Abstract base class (ABC) for a retrieval system that operates on a ColBERT + vector store, facilitating the search and retrieval of text chunks based on query embeddings. """ @@ -26,11 +28,12 @@ def embedding_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Search for relevant text chunks based on a query embedding. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_embedding (Embedding): The query embedding to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. @@ -54,11 +57,12 @@ async def aembedding_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Search for relevant text chunks based on a query embedding. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_embedding (Embedding): The query embedding to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. @@ -83,11 +87,12 @@ def text_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Search for relevant text chunks based on a query text. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_text (str): The query text to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. query_maxlen (Optional[int]): The maximum length of the query to consider. @@ -113,11 +118,12 @@ async def atext_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Search for relevant text chunks based on a query text. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_text (str): The query text to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. query_maxlen (Optional[int]): The maximum length of the query to consider. diff --git a/libs/colbert/ragstack_colbert/base_vector_store.py b/libs/colbert/ragstack_colbert/base_vector_store.py index 7f8dafeed..1f95ccaf6 100644 --- a/libs/colbert/ragstack_colbert/base_vector_store.py +++ b/libs/colbert/ragstack_colbert/base_vector_store.py @@ -1,4 +1,5 @@ -""" +"""Base Vector Store module for ColBERT. + This module defines the abstract base class for a standard vector store specifically designed to work with ColBERT or similar dense embedding models, and can be used to create a LangChain or LlamaIndex ColBERT vector store. @@ -23,7 +24,8 @@ class BaseVectorStore(ABC): - """ + """Base Vector Store abstract class for ColBERT. + Abstract base class (ABC) for a storage system designed to hold vector representations of text chunks, typically generated by a ColBERT model or similar embedding model. @@ -36,10 +38,9 @@ class BaseVectorStore(ABC): # handles LlamaIndex add @abstractmethod def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. Returns: @@ -54,11 +55,12 @@ def add_texts( metadatas: Optional[List[Metadata]], doc_id: Optional[str] = None, ) -> List[Tuple[str, int]]: - """ + """Adds text chunks to the vector store. + Embeds and stores a list of text chunks and optional metadata into the vector store. - Parameters: + Args: texts (List[str]): The list of text chunks to be embedded metadatas (Optional[List[Metadata]])): An optional list of Metadata to be stored. If provided, these are set 1 to 1 with the texts list. @@ -72,10 +74,9 @@ def add_texts( # handles LangChain and LlamaIndex delete @abstractmethod def delete_chunks(self, doc_ids: List[str]) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. @@ -88,10 +89,9 @@ def delete_chunks(self, doc_ids: List[str]) -> bool: async def aadd_chunks( self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100 ) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. concurrent_inserts (Optional[int]): How many concurrent inserts to make to the database. Defaults to 100. @@ -109,11 +109,12 @@ async def aadd_texts( doc_id: Optional[str] = None, concurrent_inserts: Optional[int] = 100, ) -> List[Tuple[str, int]]: - """ + """Adds text chunks to the vector store. + Embeds and stores a list of text chunks and optional metadata into the vector store. - Parameters: + Args: texts (List[str]): The list of text chunks to be embedded metadatas (Optional[List[Metadata]])): An optional list of Metadata to be stored. If provided, these are set 1 to 1 with the texts list. @@ -131,10 +132,9 @@ async def aadd_texts( async def adelete_chunks( self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100 ) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. concurrent_deletes (Optional[int]): How many concurrent deletes to make to @@ -147,6 +147,4 @@ async def adelete_chunks( # handles LangChain as_retriever @abstractmethod def as_retriever(self) -> BaseRetriever: - """ - Gets a retriever using the vector store. - """ + """Gets a retriever using the vector store.""" diff --git a/libs/colbert/ragstack_colbert/cassandra_database.py b/libs/colbert/ragstack_colbert/cassandra_database.py index c14d65c1e..790468227 100644 --- a/libs/colbert/ragstack_colbert/cassandra_database.py +++ b/libs/colbert/ragstack_colbert/cassandra_database.py @@ -1,4 +1,5 @@ -""" +"""Casandra Database. + This module provides an implementation of the BaseVectorStore abstract class, specifically designed for use with a Cassandra database backend. It allows for the efficient storage and management of text embeddings generated by a ColBERT model, @@ -21,11 +22,12 @@ class CassandraDatabaseError(Exception): - pass + """Exception raised for errors in the CassandraDatabase class.""" class CassandraDatabase(BaseDatabase): - """ + """Casandra Database. + An implementation of the BaseDatabase abstract base class using Cassandra as the backend storage system. This class provides methods to store, retrieve, and manage text embeddings within a Cassandra database, specifically designed for handling @@ -37,7 +39,7 @@ class CassandraDatabase(BaseDatabase): _table: ClusteredMetadataVectorCassandraTable - def __new__(cls): + def __new__(cls): # noqa: D102 raise ValueError( "This class cannot be instantiated directly. " "Please use the `from_astra()` or `from_session()` class methods." @@ -52,6 +54,7 @@ def from_astra( table_name: Optional[str] = "colbert", timeout: Optional[int] = 300, ): + """Creates a CassandraVectorStore using AstraDB connection info.""" cassio.init(token=astra_token, database_id=database_id, keyspace=keyspace) session = cassio.config.resolve_session() session.default_timeout = timeout @@ -67,6 +70,7 @@ def from_session( keyspace: Optional[str] = "default_keyspace", table_name: Optional[str] = "colbert", ): + """Creates a CassandraVectorStore using an existing session.""" instance = super().__new__(cls) instance._initialize(session=session, keyspace=keyspace, table_name=table_name) # noqa: SLF001 return instance @@ -77,10 +81,9 @@ def _initialize( keyspace: str, table_name: str, ): - """ - Initializes a new instance of the CassandraVectorStore. + """Initializes a new instance of the CassandraVectorStore. - Parameters: + Args: session (Session): The Cassandra session to use. keyspace (str): The keyspace in which the table exists or will be created. table_name (str): The name of the table to use or create for storing @@ -88,7 +91,6 @@ def _initialize( timeout (int, optional): The default timeout in seconds for Cassandra operations. Defaults to 180. """ - try: is_astra = session.cluster.cloud except AttributeError: @@ -126,16 +128,14 @@ def _log_insert_error( ) def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. Returns: a list of tuples: (doc_id, chunk_id) """ - failed_chunks: List[Tuple[str, int]] = [] success_chunks: List[Tuple[str, int]] = [] @@ -212,10 +212,9 @@ async def _limited_put( async def aadd_chunks( self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100 ) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. concurrent_inserts (Optional[int]): How many concurrent inserts to make to the database. Defaults to 100. @@ -284,17 +283,15 @@ async def aadd_chunks( return outputs def delete_chunks(self, doc_ids: List[str]) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. Returns: True if the all the deletes were successful. """ - failed_docs: List[str] = [] for doc_id in doc_ids: @@ -327,10 +324,9 @@ async def _limited_delete( async def adelete_chunks( self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100 ) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids (List[str]): A list of document identifiers specifying the chunks to be deleted. concurrent_deletes (Optional[int]): How many concurrent deletes to make @@ -339,7 +335,6 @@ async def adelete_chunks( Returns: True if the all the deletes were successful. """ - semaphore = asyncio.Semaphore(concurrent_deletes) all_tasks = [ self._limited_delete( @@ -369,14 +364,12 @@ async def adelete_chunks( return success async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: - """ - Retrieves 'n' ANN results for an embedded token vector. + """Retrieves 'n' ANN results for an embedded token vector. Returns: A list of Chunks with only `doc_id` and `chunk_id` set. Fewer than 'n' results may be returned. """ - chunks: Set[Chunk] = set() # TODO: only return partition_id and row_id after cassio supports this @@ -391,13 +384,11 @@ async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: return list(chunks) async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: - """ - Retrieve the embedding data for a chunk. + """Retrieve the embedding data for a chunk. Returns: A chunk with `doc_id`, `chunk_id`, and `embedding` set. """ - row_id = (chunk_id, Predicate(PredicateOperator.GT, -1)) rows = await self._table.aget_partition(partition_id=doc_id, row_id=row_id) @@ -408,13 +399,11 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: async def get_chunk_data( self, doc_id: str, chunk_id: int, include_embedding: Optional[bool] = False ) -> Chunk: - """ - Retrieve the text and metadata for a chunk. + """Retrieve the text and metadata for a chunk. Returns: A chunk with `doc_id`, `chunk_id`, `text`, and `metadata` set. """ - row_id = (chunk_id, Predicate(PredicateOperator.EQ, -1)) row = await self._table.aget(partition_id=doc_id, row_id=row_id) @@ -435,6 +424,4 @@ async def get_chunk_data( ) def close(self) -> None: - """ - Cleans up any open resources. - """ + """Cleans up any open resources.""" diff --git a/libs/colbert/ragstack_colbert/colbert_embedding_model.py b/libs/colbert/ragstack_colbert/colbert_embedding_model.py index d283ca6e5..42fcf2179 100644 --- a/libs/colbert/ragstack_colbert/colbert_embedding_model.py +++ b/libs/colbert/ragstack_colbert/colbert_embedding_model.py @@ -1,4 +1,5 @@ -""" +"""ColBERT Embedding Model. + This module integrates the ColBERT model with token embedding functionalities, offering tools for efficiently encoding queries and text chunks into dense vector representations. It facilitates semantic search and retrieval by providing optimized @@ -20,9 +21,10 @@ class ColbertEmbeddingModel(BaseEmbeddingModel): - """ - A class for generating token embeddings using a ColBERT model. This class provides - functionalities for encoding queries and document chunks into dense vector + """ColBERT embedding model. + + A class for generating token embeddings using a ColBERT model. This class + provides functionalities for encoding queries and document chunks into dense vector representations, facilitating semantic search and retrieval tasks. It leverages a pre-trained ColBERT model and supports distributed computing environments. @@ -44,12 +46,13 @@ def __init__( verbose: Optional[int] = 3, # 3 is the default on ColBERT checkpoint chunk_batch_size: Optional[int] = 640, ): - """ - Initializes a new instance of the ColbertEmbeddingModel class, setting up the + """Initializes a new instance of the ColbertEmbeddingModel class. + + Initializes a new instance of the ColbertEmbeddingModel class setting up the model configuration, loading the necessary checkpoints, and preparing the tokenizer and encoder. - Parameters: + Args: checkpoint (Optional[str]): Path or URL to the Colbert model checkpoint. Default is a pre-defined model. doc_maxlen (Optional[int]): Maximum number of tokens for document chunks. @@ -65,7 +68,6 @@ def __init__( embedding. Defaults to 640. **kwargs: Additional keyword arguments for future extensions. """ - if query_maxlen is None: query_maxlen = -1 @@ -83,17 +85,14 @@ def __init__( # implements the Abstract Class Method def embed_texts(self, texts: List[str]) -> List[Embedding]: - """ - Embeds a list of texts into their corresponding vector embedding - representations. + """Embeds a list of texts into their vector embedding representations. - Parameters: + Args: texts (List[str]): A list of string texts. Returns: List[Embedding]: A list of embeddings, in the order of the input list """ - chunks = [ Chunk(doc_id="dummy", chunk_id=i, text=t) for i, t in enumerate(texts) ] @@ -115,13 +114,12 @@ def embed_query( full_length_search: Optional[bool] = False, query_maxlen: Optional[int] = None, ) -> Embedding: - """ - Embeds a single query text into its vector representation. + """Embeds a single query text into its vector representation. If the query has fewer than query_maxlen tokens it will be padded with BERT special [mast] tokens. - Parameters: + Args: query (str): The query string to encode. full_length_search (Optional[bool]): Indicates whether to encode the query for a full-length search. Defaults to False. @@ -131,7 +129,6 @@ def embed_query( Returns: Embedding: A vector embedding representation of the query text """ - if query_maxlen is None: query_maxlen = -1 diff --git a/libs/colbert/ragstack_colbert/colbert_retriever.py b/libs/colbert/ragstack_colbert/colbert_retriever.py index 82a4db706..6676c9b91 100644 --- a/libs/colbert/ragstack_colbert/colbert_retriever.py +++ b/libs/colbert/ragstack_colbert/colbert_retriever.py @@ -1,4 +1,5 @@ -""" +"""ColBERT Retriever Module. + This module integrates text embedding retrieval and similarity computation functionalities with a vector store backend, optimized for high-performance operations in large-scale text retrieval applications. @@ -25,8 +26,7 @@ def all_gpus_support_fp16(is_cuda: Optional[bool] = False): - """ - Check if all available GPU devices support FP16 (half-precision) operations. + """Check if all available GPU devices support FP16 (half-precision) operations. Returns: bool: True if all GPUs support FP16, False otherwise. @@ -58,11 +58,12 @@ def max_similarity_torch( is_cuda: Optional[bool] = False, is_fp16: Optional[bool] = False, ) -> float: - """ - Calculates the maximum similarity (dot product) between a query vector and a chunk - embedding, leveraging PyTorch for efficient computation. + """Calculates the maximum similarity for a query vector and a chunk embedding. + + Calculates the maximum similarity (dot product) between a query vector and a + chunk embedding, leveraging PyTorch for efficient computation. - Parameters: + Args: query_vector (Vector): A list of float representing the query text. chunk_embedding (Embedding): A list of Vector, each representing an chunk embedding vector. @@ -80,7 +81,6 @@ def max_similarity_torch( This function is designed to run on GPU for enhanced performance but can also execute on CPU. """ - # Convert inputs to tensors query_tensor = torch.Tensor(query_vector) embedding_tensor = torch.stack([torch.Tensor(v) for v in chunk_embedding]) @@ -106,6 +106,7 @@ def max_similarity_torch( def get_trace(e: Exception) -> str: + """Extracts the traceback information from an exception.""" trace = "" tb = e.__traceback__ while tb is not None: @@ -115,20 +116,18 @@ def get_trace(e: Exception) -> str: class ColbertRetriever(BaseRetriever): - """ - A retriever class that implements the retrieval of text chunks from a vector store, - based on their semantic similarity to a given query. + """ColBERT Retriever. + + A retriever class that implements the retrieval of text chunks from a vector + store, based on their semantic similarity to a given query. This implementation leverages the ColBERT model for generating embeddings of the query. - Attributes: - vector_store (BaseVectorStore): The vector store instance where chunks are - stored. - embedding_model (BaseEmbeddingModel): The ColBERT embeddings model for - encoding queries. - is_cuda (bool): A flag indicating whether to use CUDA (GPU) for computation. - is_fp16 (bool): A flag indicating whether to half-precision floating point - operations on CUDA (GPU). Has no effect on CPU computation. + Args: + database (BaseDatabase): The data store to be used for retrieving + embeddings. + embedding_model (BaseEmbeddingModel): The ColBERT embeddings model to be + used for encoding queries. Note: The class is designed to work with a GPU for optimal performance but will @@ -141,6 +140,8 @@ class ColbertRetriever(BaseRetriever): _is_fp16: bool class Config: + """Pydantic configuration for the ColbertRetriever class.""" + arbitrary_types_allowed = True def __init__( @@ -148,34 +149,18 @@ def __init__( database: BaseDatabase, embedding_model: BaseEmbeddingModel, ): - """ - Initializes the retriever with a specific vector store and Colbert embeddings - model. - - Parameters: - database (BaseDatabase): The data store to be used for retrieving - embeddings. - embedding_model (BaseEmbeddingModel): The ColBERT embeddings model to be - used for encoding queries. - """ - self._database = database self._embedding_model = embedding_model self._is_cuda = torch.cuda.is_available() self._is_fp16 = all_gpus_support_fp16(self._is_cuda) def close(self) -> None: - """ - Closes any open resources held by the retriever. - """ + """Closes any open resources held by the retriever.""" async def _query_relevant_chunks( self, query_embedding: Embedding, top_k: int ) -> Set[Chunk]: - """ - Retrieves the top_k ANN Chunks (`doc_id` and `chunk_id` only) for each embedded - query token. - """ + """Queries for the top_k most relevant chunks for each query token.""" chunks: Set[Chunk] = set() # Collect all tasks tasks = [ @@ -198,9 +183,7 @@ async def _query_relevant_chunks( return chunks async def _get_chunk_embeddings(self, chunks: Set[Chunk]) -> List[Chunk]: - """ - Retrieves Chunks with `doc_id`, `chunk_id`, and `embedding` set. - """ + """Retrieves Chunks with `doc_id`, `chunk_id`, and `embedding` set.""" # Collect all tasks tasks = [ self._database.get_chunk_embedding(doc_id=c.doc_id, chunk_id=c.chunk_id) @@ -222,9 +205,7 @@ async def _get_chunk_embeddings(self, chunks: Set[Chunk]) -> List[Chunk]: def _score_chunks( self, query_embedding: Embedding, chunk_embeddings: List[Chunk] ) -> Dict[Chunk, float]: - """ - Process the retrieved chunk data to calculate scores. - """ + """Process the retrieved chunk data to calculate scores.""" chunk_scores = {} for chunk in chunk_embeddings: chunk_scores[chunk] = sum( @@ -243,14 +224,12 @@ async def _get_chunk_data( chunks: List[Chunk], include_embedding: Optional[bool] = False, ) -> List[Chunk]: - """ - Fetches text and metadata for each chunk. + """Fetches text and metadata for each chunk. Returns: List[Chunk]: A list of chunks with `doc_id`, `chunk_id`, `text`, `metadata`, and optionally `embedding` set. """ - # Collect all tasks tasks = [ self._database.get_chunk_data( @@ -281,11 +260,12 @@ async def atext_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Searches for relevant text chunks based on a given query text. + Retrieves a list of text chunks most relevant to the given query, using semantic similarity as the criteria. - Parameters: + Args: query_text (str): The query text to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. Default 5. query_maxlen (Optional[int]): The maximum length of the query to consider. @@ -300,7 +280,6 @@ async def atext_search( each representing a text chunk that is relevant to the query, along with its similarity score. """ - query_embedding = self._embedding_model.embed_query( query=query_text, query_maxlen=query_maxlen ) @@ -320,14 +299,15 @@ async def aembedding_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Searches for relevant text chunks based on a given query embedding. + Retrieves a list of text chunks most relevant to the given query, using semantic similarity as the criteria. - Parameters: + Args: query_embedding (Embedding): The query embedding to search for relevant text chunks. - k (Optional[int]): The number of top results to retrieve. + k (Optional[int]): The number of top results to retrieve. Default 5. include_embedding (Optional[bool]): Optional (default False) flag to include the embedding vectors in the returned chunks **kwargs (Any): Additional parameters that implementations might require @@ -338,7 +318,6 @@ async def aembedding_search( each representing a text chunk that is relevant to the query, along with its similarity score. """ - top_k = max(math.floor(len(query_embedding) / 2), 16) logging.debug( "based on query length of %s tokens, retrieving %s results per " @@ -384,11 +363,12 @@ def text_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Searches for relevant text chunks based on a given query text. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_text (str): The query text to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. Default 5. query_maxlen (Optional[int]): The maximum length of the query to consider. @@ -403,7 +383,6 @@ def text_search( each representing a text chunk that is relevant to the query, along with its similarity score. """ - return asyncio.run( self.atext_search( query_text=query_text, @@ -422,11 +401,12 @@ def embedding_search( include_embedding: Optional[bool] = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """ + """Searches for relevant text chunks based on a given query embedding. + Retrieves a list of text chunks relevant to a given query from the vector store, ranked by relevance or other metrics. - Parameters: + Args: query_embedding (Embedding): The query embedding to search for relevant text chunks. k (Optional[int]): The number of top results to retrieve. Default 5. @@ -440,7 +420,6 @@ def embedding_search( each representing a text chunk that is relevant to the query, along with its similarity score. """ - return asyncio.run( self.aembedding_search( query_embedding=query_embedding, diff --git a/libs/colbert/ragstack_colbert/colbert_vector_store.py b/libs/colbert/ragstack_colbert/colbert_vector_store.py index 95ea10ea5..49c38572e 100644 --- a/libs/colbert/ragstack_colbert/colbert_vector_store.py +++ b/libs/colbert/ragstack_colbert/colbert_vector_store.py @@ -1,4 +1,5 @@ -""" +"""ColBERT Vector Store. + This module provides an implementation of the BaseVectorStore abstract class, specifically designed for use with a Cassandra database backend. It allows for the efficient storage and management of text embeddings @@ -18,8 +19,12 @@ class ColbertVectorStore(BaseVectorStore): - """ - An implementation of the BaseVectorStore abstract base class. + """A vector store implementation for ColBERT. + + Args: + database (BaseDatabase): The database to use for storage + embedding_model (Optional[BaseEmbeddingModel]): The embedding model to use + for embedding text and queries. """ _database: BaseDatabase @@ -30,15 +35,6 @@ def __init__( database: BaseDatabase, embedding_model: Optional[BaseEmbeddingModel] = None, ): - """ - Initializes a new instance of the ColbertVectorStore. - - Parameters: - database (BaseDatabase): The database to use for storage - embedding_model (Optional[BaseEmbeddingModel]): The embedding model to use - for embedding text and queries. - """ - self._database = database self._embedding_model = embedding_model @@ -79,16 +75,14 @@ def _build_chunks( # implements the abc method to handle LlamaIndex add def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks (List[Chunk]): A list of `Chunk` instances to be stored. Returns: a list of tuples: (doc_id, chunk_id) """ - return self._database.add_chunks(chunks=chunks) # implements the abc method to handle LangChain add @@ -98,11 +92,12 @@ def add_texts( metadatas: Optional[List[Metadata]] = None, doc_id: Optional[str] = None, ) -> List[Tuple[str, int]]: - """ + """Adds text chunks to the vector store. + Embeds and stores a list of text chunks and optional metadata into the vector store. - Parameters: + Args: texts: The list of text chunks to be embedded metadatas: An optional list of Metadata to be stored. If provided, these are set 1 to 1 with the texts list. @@ -117,26 +112,23 @@ def add_texts( # implements the abc method to handle LangChain and LlamaIndex delete def delete_chunks(self, doc_ids: List[str]) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids: A list of document identifiers specifying the chunks to be deleted. Returns: True if the all the deletes were successful. """ - return self._database.delete_chunks(doc_ids=doc_ids) # implements the abc method to handle LlamaIndex add async def aadd_chunks( self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100 ) -> List[Tuple[str, int]]: - """ - Stores a list of embedded text chunks in the vector store + """Stores a list of embedded text chunks in the vector store. - Parameters: + Args: chunks: A list of `Chunk` instances to be stored. concurrent_inserts: How many concurrent inserts to make to the database. Defaults to 100. @@ -144,7 +136,6 @@ async def aadd_chunks( Returns: a list of tuples: (doc_id, chunk_id) """ - return await self._database.aadd_chunks( chunks=chunks, concurrent_inserts=concurrent_inserts ) @@ -157,11 +148,12 @@ async def aadd_texts( doc_id: Optional[str] = None, concurrent_inserts: Optional[int] = 100, ) -> List[Tuple[str, int]]: - """ + """Adds text chunks to the vector store. + Embeds and stores a list of text chunks and optional metadata into the vector store. - Parameters: + Args: texts (List[str]): The list of text chunks to be embedded metadatas: An optional list of Metadata to be stored. If provided, these are set 1 to 1 with the texts list. @@ -182,10 +174,9 @@ async def aadd_texts( async def adelete_chunks( self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100 ) -> bool: - """ - Deletes chunks from the vector store based on their document id. + """Deletes chunks from the vector store based on their document id. - Parameters: + Args: doc_ids: A list of document identifiers specifying the chunks to be deleted. concurrent_deletes: How many concurrent deletes to make to the database. Defaults to 100. @@ -198,10 +189,7 @@ async def adelete_chunks( ) def as_retriever(self) -> BaseRetriever: - """ - Gets a retriever using the vector store. - """ - + """Gets a retriever using the vector store.""" self._validate_embedding_model() return ColbertRetriever( database=self._database, embedding_model=self._embedding_model diff --git a/libs/colbert/ragstack_colbert/constant.py b/libs/colbert/ragstack_colbert/constant.py index fe053aaaf..5af6dcd93 100644 --- a/libs/colbert/ragstack_colbert/constant.py +++ b/libs/colbert/ragstack_colbert/constant.py @@ -1,4 +1,5 @@ -""" +"""Constants for configuring and operating the ColBERT model. + Defines constants used across the system for configuring and operating the ColBERT model for semantic search and retrieval tasks. diff --git a/libs/colbert/ragstack_colbert/objects.py b/libs/colbert/ragstack_colbert/objects.py index ffcd543d7..150dc3e1f 100644 --- a/libs/colbert/ragstack_colbert/objects.py +++ b/libs/colbert/ragstack_colbert/objects.py @@ -1,6 +1,7 @@ -""" -This module defines a set of data classes for handling chunks of text in various stages -of processing within the ColBERT retrieval system. +"""Objects for handling chunks of text in the ColBERT retrieval system. + +This module defines a set of data classes for handling chunks of text in various +stages of processing within the ColBERT retrieval system. """ from typing import Any, Dict, List, Optional @@ -25,6 +26,8 @@ class Chunk(BaseModel): + """A chunk of text with associated metadata and embedding.""" + doc_id: str = Field(..., description="id of the parent document", frozen=True) chunk_id: int = Field(..., description="id of the chunk", frozen=True, ge=0) text: str = Field(default=None, description="text of the chunk") @@ -36,6 +39,8 @@ class Chunk(BaseModel): ) class Config: + """Pydantic configuration for the Chunk class.""" + validate_assignment = True # Define equality based on doc_id and chunk_id only diff --git a/libs/colbert/ragstack_colbert/text_encoder.py b/libs/colbert/ragstack_colbert/text_encoder.py index e1d319b0d..5445c0fb5 100644 --- a/libs/colbert/ragstack_colbert/text_encoder.py +++ b/libs/colbert/ragstack_colbert/text_encoder.py @@ -1,4 +1,5 @@ -""" +"""Text encoder for ColBERT. + This module provides functionalities to encode text chunks into dense vector representations using a ColBERT model. It supports encoding chunks in batches to efficiently manage memory usage and prevent out-of-memory errors when processing large @@ -18,11 +19,12 @@ def calculate_query_maxlen(tokens: List[List[str]]) -> int: - """ + """Calculates maximum query length. + Calculates an appropriate maximum query length for token embeddings, based on the length of the tokenized input. - Parameters: + Args: tokens (List[List[str]]): A nested list where each sublist contains tokens from a single query or chunk. @@ -30,7 +32,6 @@ def calculate_query_maxlen(tokens: List[List[str]]) -> int: int: The calculated maximum length for query tokens, adhering to the specified minimum and maximum bounds, and adjusted to the nearest power of two. """ - max_token_length = max(len(inner_list) for inner_list in tokens) # tokens from the query tokenizer does not include the SEP, CLS @@ -41,22 +42,18 @@ def calculate_query_maxlen(tokens: List[List[str]]) -> int: class TextEncoder: - """ + """Text encoder for ColBERT. + Encapsulates the logic for encoding text chunks and queries into dense vector representations using a specified ColBERT model configuration and checkpoint. This class is optimized for batch processing to manage GPU memory usage efficiently. + + Args: + config (ColBERTConfig): The configuration for the Colbert model. + verbose (int): The level of logging to use """ def __init__(self, config: ColBERTConfig, verbose: Optional[int] = 3) -> None: - """ - Initializes the ChunkEncoder with a given ColBERT model configuration and - checkpoint. - - Parameters: - config (ColBERTConfig): The configuration for the Colbert model. - verbose (int): The level of logging to use - """ - logging.info("Cuda enabled GPU available: %s", torch.cuda.is_available()) self._checkpoint = Checkpoint( @@ -65,12 +62,13 @@ def __init__(self, config: ColBERTConfig, verbose: Optional[int] = 3) -> None: self._use_cpu = config.total_visible_gpus == 0 def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chunk]: - """ - Encodes a list of chunks into embeddings, processing in batches to efficiently - manage memory. + """Encodes a list of chunks into embeddings. + + Encodes a list of chunks into embeddings, processing in batches to + efficiently manage memory. - Parameters: - texts (List[str]): The text chunks to encode. + Args: + chunks (List[str]): The text chunks to encode. batch_size (int): The size of batches for processing to avoid memory overflow. Defaults to 64. @@ -78,7 +76,6 @@ def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chun A tuple containing the concatenated tensor of embeddings and a list of document lengths. """ - logging.debug("#> Encoding %s chunks..", len(chunks)) embedded_chunks: List[Chunk] = [] @@ -112,6 +109,7 @@ def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chun def encode_query( self, text: str, query_maxlen: int, full_length_search: Optional[bool] = False ) -> Embedding: + """Encodes a query into an embedding.""" if query_maxlen < 0: tokens = self._checkpoint.query_tokenizer.tokenize([text]) query_maxlen = calculate_query_maxlen(tokens) diff --git a/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py b/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py index c5a437bf3..2a700ef52 100644 --- a/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py +++ b/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py @@ -51,9 +51,7 @@ def __init__(self, retriever): self.retriever = retriever async def rag_using_lc(self, context: dict, llm: BaseLLM) -> ActionResult: - """ - Defines the custom rag action - """ + """Defines the custom rag action""" user_message = context.get("last_user_message") context_updates = {} diff --git a/libs/e2e-tests/e2e_tests/langchain/trulens.py b/libs/e2e-tests/e2e_tests/langchain/trulens.py index 0c1825c70..6d8923b88 100644 --- a/libs/e2e-tests/e2e_tests/langchain/trulens.py +++ b/libs/e2e-tests/e2e_tests/langchain/trulens.py @@ -54,9 +54,7 @@ def _create_chain(retriever: VectorStoreRetriever, llm: BaseLanguageModel) -> Ru def run_trulens_evaluation(vector_store: VectorStore, llm: BaseLanguageModel): - """ - Executes the TruLens evaluation process. - """ + """Executes the TruLens evaluation process.""" vector_store.add_texts(SAMPLE_DATA) _initialize_tru() retriever = vector_store.as_retriever() diff --git a/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py b/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py index e96194383..ffa60bde7 100644 --- a/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py +++ b/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py @@ -54,15 +54,11 @@ def __init__(self, delete_function: Callable, max_workers=5): self.semaphore = threading.Semaphore(max_workers) def get_current_deletions(self): - """ - Returns the number of ongoing deletions. - """ + """Returns the number of ongoing deletions.""" return self.max_workers - self.semaphore._value # noqa: SLF001 def await_ongoing_deletions_completed(self): - """ - Blocks until all ongoing deletions are completed. - """ + """Blocks until all ongoing deletions are completed.""" pending_deletions = self.max_workers - self.semaphore._value # noqa: SLF001 while pending_deletions >= 0: logging.debug( @@ -72,9 +68,8 @@ def await_ongoing_deletions_completed(self): return def run_delete(self, collection: str): - """ - Runs a delete_collection in the background, blocking if max_workers are already - running. + """Runs a delete_collection in the background, blocking if max_workers are + already running. """ self.semaphore.acquire() # Wait for a free thread return self.executor.submit( @@ -82,9 +77,8 @@ def run_delete(self, collection: str): ) def _run_and_release(self, collection: str): - """ - Internal wrapper to run the delete function and release the semaphore once done. - """ + """Internal wrapper to run the delete function and release the semaphore once + done.""" try: logging.info("deleting collection %s", collection) self.delete_function(collection) @@ -93,9 +87,7 @@ def _run_and_release(self, collection: str): self.semaphore.release() def shutdown(self, wait=True): - """ - Shuts down the executor, waiting for tasks to complete if specified. - """ + """Shuts down the executor, waiting for tasks to complete if specified.""" self.executor.shutdown(wait=wait) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py b/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py index 9d08edc50..4218007df 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py @@ -26,6 +26,8 @@ def _node(node: LangChainNode) -> Node: class CassandraGraphStore(GraphStore): + """A Cassandra-based graph store.""" + def __init__( self, node_table: str = "entities", @@ -34,8 +36,7 @@ def __init__( session: Optional[Session] = None, keyspace: Optional[str] = None, ) -> None: - """ - Create a Cassandra Graph Store. + """Create a Cassandra Graph Store. Before calling this, you must initialize cassio with `cassio.init`, or provide valid session and keyspace values. @@ -62,6 +63,7 @@ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: # noqa: @override @property + @override def get_schema(self) -> str: raise NotImplementedError @@ -75,13 +77,14 @@ def refresh_schema(self) -> None: raise NotImplementedError def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = ()) -> Runnable: - """ - Return a runnable that retrieves the sub-graph near the - input entity or entities. + """Convert to a runnable. + + Returns a runnable that retrieves the sub-graph near the input entity or + entities. - Parameters: - - steps: The maximum distance to follow from the starting points. - - edge_filters: Predicates to use for filtering the edges. + Args: + steps: The maximum distance to follow from the starting points. + edge_filters: Predicates to use for filtering the edges. """ return RunnableLambda( func=self.graph.traverse, afunc=self.graph.atraverse diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py b/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py index a4b9da164..d65bf876f 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py @@ -31,6 +31,8 @@ def _format_example(idx: int, example: Example) -> str: class KnowledgeSchemaExtractor: + """Extracts knowledge graphs from documents.""" + def __init__( self, llm: BaseChatModel, @@ -89,6 +91,7 @@ def _process_response( return document def extract(self, documents: List[Document]) -> List[GraphDocument]: + """Extract knowledge graphs from a list of documents.""" # TODO: Define an async version of extraction? responses = self._chain.batch_as_completed( [{"input": doc.page_content} for doc in documents] diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py index 2aecd7592..adcc18b3a 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py @@ -30,6 +30,20 @@ def _parse_node(row) -> Node: class CassandraKnowledgeGraph: + """Cassandra Knowledge Graph. + + Args: + node_table: Name of the table containing nodes. Defaults to `"entities"`. + edge_table: Name of the table containing edges. Defaults to ` + "relationships"`. + text_embeddings: Name of the embeddings to use, if any. + session: The Cassandra `Session` to use. If not specified, uses the default + `cassio` session, which requires `cassio.init` has been called. + keyspace: The Cassandra keyspace to use. If not specified, uses the default + `cassio` keyspace, which requires `cassio.init` has been called. + apply_schema: If true, the node table and edge table are created. + """ + def __init__( self, node_table: str = "entities", @@ -39,21 +53,6 @@ def __init__( keyspace: Optional[str] = None, apply_schema: bool = True, ) -> None: - """ - Create a Cassandra Knowledge Graph. - - Args: - node_table: Name of the table containing nodes. Defaults to `"entities"`. - edge_table: Name of the table containing edges. Defaults to ` - "relationships"`. - text_embeddings: Name of the embeddings to use, if any. - session: The Cassandra `Session` to use. If not specified, uses the default - `cassio` session, which requires `cassio.init` has been called. - keyspace: The Cassandra keyspace to use. If not specified, uses the default - `cassio` keyspace, which requires `cassio.init` has been called. - apply_schema: If true, the node table and edge table are created. - """ - session = check_resolve_session(session) keyspace = check_resolve_keyspace(keyspace) @@ -164,8 +163,7 @@ def _send_query_nearest_node(self, node: str, k: int = 1) -> ResponseFuture: # TODO: Allow filtering by node predicates and/or minimum similarity. def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node]: - """ - For each node, return the nearest nodes in the table. + """For each node, return the nearest nodes in the table. Args: nodes: The strings to search for in the list of nodes. @@ -188,6 +186,7 @@ def insert( self, elements: Iterable[Union[Node, Relation]], ) -> None: + """Insert the given elements into the graph.""" for batch in batched(elements, n=4): from yaml import dump @@ -237,9 +236,7 @@ def subgraph( edge_filters: Sequence[str] = (), steps: int = 3, ) -> Tuple[Iterable[Node], Iterable[Relation]]: - """ - Retrieve the sub-graph from the given starting nodes. - """ + """Retrieve the sub-graph from the given starting nodes.""" edges = self.traverse(start, edge_filters, steps) # Create the set of nodes. @@ -266,9 +263,9 @@ def traverse( edge_filters: Sequence[str] = (), steps: int = 3, ) -> Iterable[Relation]: - """ - Traverse the graph from the given starting nodes and return the resulting - sub-graph. + """Traverse the graph from the given starting nodes. + + Returns the resulting sub-graph. Args: start: The starting node or nodes. @@ -298,9 +295,9 @@ async def atraverse( edge_filters: Sequence[str] = (), steps: int = 3, ) -> Iterable[Relation]: - """ - Traverse the graph from the given starting nodes and return the resulting - sub-graph. + """Traverse the graph from the given starting nodes. + + Returns the resulting sub-graph. Args: start: The starting node or nodes. diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py index 82e32f74b..8070a4851 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py @@ -8,6 +8,8 @@ class NodeSchema(BaseModel): + """Schema for a node.""" + type: str """The name of the node type.""" @@ -16,6 +18,8 @@ class NodeSchema(BaseModel): class EdgeSchema(BaseModel): + """Schema for an edge.""" + type: str """The name of the edge type.""" @@ -24,6 +28,8 @@ class EdgeSchema(BaseModel): class RelationshipSchema(BaseModel): + """Schema for a relationship.""" + edge_type: str """The name of the edge type for the relationhsip.""" @@ -38,6 +44,8 @@ class RelationshipSchema(BaseModel): class Example(BaseModel): + """An example of a graph.""" + input: str """The source input.""" @@ -49,6 +57,8 @@ class Example(BaseModel): class KnowledgeSchema(BaseModel): + """Schema for a knowledge graph.""" + nodes: List[NodeSchema] """Allowed node types for the knowledge schema.""" @@ -59,20 +69,23 @@ class KnowledgeSchema(BaseModel): def from_file(cls, path: Union[str, Path]) -> "KnowledgeSchema": """Load a KnowledgeSchema from a JSON or YAML file. - Parameters: - - path: The path to the file to load. + Args: + path: The path to the file to load. """ from pydantic_yaml import parse_yaml_file_as return parse_yaml_file_as(cls, path) def to_yaml_str(self) -> str: + """Convert the schema to a YAML string.""" from pydantic_yaml import to_yaml_str return to_yaml_str(self) class KnowledgeSchemaValidator: + """Validates graph documents against a knowledge schema.""" + def __init__(self, schema: KnowledgeSchema) -> None: self._schema = schema @@ -86,6 +99,7 @@ def __init__(self, schema: KnowledgeSchema) -> None: # source/target type should exist in nodes, edge_type should exist in edges def validate_graph_document(self, document: GraphDocument): + """Validate a graph document against the schema.""" e = ValueError("Invalid graph document for schema") for node_type in {node.type for node in document.nodes}: if node_type not in self._nodes: diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/render.py b/libs/knowledge-graph/ragstack_knowledge_graph/render.py index 15b56d059..583103d6c 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/render.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/render.py @@ -13,6 +13,7 @@ def _node_label(node: Node) -> str: def print_graph_documents( graph_documents: Union[GraphDocument, Iterable[GraphDocument]], ): + """Prints the relationships in the graph documents.""" if isinstance(graph_documents, GraphDocument): graph_documents = [graph_documents] @@ -26,6 +27,7 @@ def print_graph_documents( def render_graph_documents( graph_documents: Union[GraphDocument, Iterable[GraphDocument]], ) -> graphviz.Digraph: + """Renders the relationships in the graph documents.""" if isinstance(graph_documents, GraphDocument): graph_documents = [GraphDocument] @@ -52,6 +54,7 @@ def _node_id(node: Node) -> int: def render_knowledge_schema(knowledge_schema: KnowledgeSchema) -> graphviz.Digraph: + """Renders the knowledge schema as a graph.""" dot = graphviz.Digraph() for node in knowledge_schema.nodes: diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py b/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py index e1e152250..4393657fb 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py @@ -27,8 +27,7 @@ def extract_entities( keyword_extraction_prompt: str = QUERY_ENTITY_EXTRACT_PROMPT, node_types: Optional[List[str]] = None, ) -> Runnable: - """ - Return a keyword-extraction runnable. + """Return a keyword-extraction runnable. This will expect a dictionary containing the `"question"` to extract keywords from. diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py b/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py index 6ab9c1908..6c13c84c8 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py @@ -13,6 +13,8 @@ class KnowledgeSchemaInferer: + """Infers knowledge schemas from documents.""" + def __init__(self, llm: BaseChatModel) -> None: prompt = ChatPromptTemplate.from_messages( [ @@ -29,6 +31,7 @@ def __init__(self, llm: BaseChatModel) -> None: def infer_schemas_from( self, documents: Sequence[Document] ) -> Sequence[KnowledgeSchema]: + """Infer knowledge schemas from a sequence of documents.""" responses = self._chain.batch( [{"input": doc.page_content} for doc in documents] ) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/templates.py b/libs/knowledge-graph/ragstack_knowledge_graph/templates.py index 70cf8ca67..f7ebc5a94 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/templates.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/templates.py @@ -9,6 +9,7 @@ def load_template( filename: str, **kwargs: Union[str, Callable[[], str]] ) -> PromptTemplate: + """Load a template from a file.""" template = PromptTemplate.from_file(path.join(TEMPLATE_PATH, filename)) if kwargs: template = template.partial(**kwargs) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py b/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py index 35d0d464c..c47d60222 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py @@ -13,9 +13,11 @@ class _Node(NamedTuple): class Node(_Node): + """A node in the graph.""" + __slots__ = () - def __new__(cls, name, type, properties=None): # noqa: A002 + def __new__(cls, name, type, properties=None): # noqa: A002, D102 if properties is None: properties = {} return super().__new__(cls, name, type, properties) @@ -34,6 +36,8 @@ def __eq__(self, value) -> bool: class Relation(NamedTuple): + """A relation between two nodes.""" + source: Node target: Node type: str @@ -90,8 +94,9 @@ def traverse( session: Optional[Session] = None, keyspace: Optional[str] = None, ) -> Iterable[Relation]: - """ - Traverse the graph from the given starting nodes and return the resulting sub-graph. + """Traverse the graph from the given starting nodes. + + Returns the resulting sub-graph. Args: start: The starting node or nodes. @@ -162,8 +167,7 @@ def handle_error(e): condition.notify() def fetch_relationships(distance: int, source: Node) -> None: - """ - Fetch relationships from node `source` is found at `distance`. + """Fetch relationships from node `source` is found at `distance`. This will retrieve the edges from `source`, and visit the resulting nodes at distance `distance + 1`. @@ -200,6 +204,8 @@ def fetch_relationships(distance: int, source: Node) -> None: class AsyncPagedQuery: + """An async iterator over the results of a paged query.""" + def __init__(self, depth: int, response_future: ResponseFuture): self.loop = asyncio.get_running_loop() self.depth = depth @@ -214,6 +220,7 @@ def _handle_error(self, error): self.loop.call_soon_threadsafe(self.current_page_future.set_exception, error) async def next(self): + """Fetch the next page of results.""" page = [_parse_relation(r) for r in await self.current_page_future] if self.response_future.has_more_pages: @@ -236,11 +243,11 @@ async def atraverse( session: Optional[Session] = None, keyspace: Optional[str] = None, ) -> Iterable[Relation]: - """ - Async traversal of the graph from the given starting nodes and return the resulting - sub-graph. + """Async traversal of the graph from the given starting nodes. + + Returns the resulting sub-graph. - Parameters: + Args: start: The starting node or nodes. edge_table: The table containing the edges. edge_source_name: The name of the column containing edge source names. @@ -261,7 +268,6 @@ async def atraverse( Returns: An iterable over relations in the traversed sub-graph. """ - session = check_resolve_session(session) keyspace = check_resolve_keyspace(keyspace) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/utils.py b/libs/knowledge-graph/ragstack_knowledge_graph/utils.py index 2de461cc1..e6a703881 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/utils.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/utils.py @@ -11,6 +11,7 @@ # This is equivalent to `itertools.batched`, but that is only available in 3.12 def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]: + """Emulate itertools.batched.""" if n < 1: raise ValueError("n must be at least one") it = iter(iterable) diff --git a/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py b/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py index cbf2208a9..04ef3258a 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py +++ b/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py @@ -33,6 +33,17 @@ def update_redundancy(self, new_weighted_redundancy: float): class MmrHelper: + """Helper for executing an MMR traversal query. + + Args: + query_embedding: The embedding of the query to use for scoring. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + """ + dimensions: int """Dimensions of the embedding.""" @@ -76,16 +87,6 @@ def __init__( lambda_mult: float = 0.5, score_threshold: float = NEG_INF, ) -> None: - """Create a helper for executing an MMR traversal query. - - Args: - query_embedding: The embedding of the query to use for scoring. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding to maximum - diversity and 1 to minimum diversity. Defaults to 0.5. - score_threshold: Only documents with a score greater than or equal - this threshold will be chosen. Defaults to -infinity. - """ self.query_embedding = _emb_to_ndarray(query_embedding) self.dimensions = self.query_embedding.shape[1] @@ -109,6 +110,7 @@ def __init__( self.best_id = None def candidate_ids(self) -> Iterable[str]: + """Return the IDs of the candidates.""" return self.candidate_id_to_index.keys() def _already_selected_embeddings(self) -> np.ndarray: @@ -186,7 +188,6 @@ def pop_best(self) -> Optional[str]: def add_candidates(self, candidates: Dict[str, List[float]]): """Add candidates to the consideration set.""" - # Determine the keys to actually include. # These are the candidates that aren't already selected # or under consideration. diff --git a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py index a13563f75..16cc04a8c 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py +++ b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py @@ -54,8 +54,7 @@ def execute( parameters: Optional[Tuple] = None, callback: Optional[Callable[[Sequence[NamedTuple]], Any]] = None, ): - """ - Execute a query concurrently. + """Execute a query concurrently. Because this is done concurrently, it expects a callback if you need to inspect the results. @@ -65,7 +64,6 @@ def execute( parameters: Parameter tuple for the query. Defaults to `None`. callback: Callback to apply to the results. Defaults to `None`. """ - # TODO: We could have some form of throttling, where we track the number # of pending calls and queue things if it exceed some threshold. diff --git a/libs/knowledge-store/ragstack_knowledge_store/content.py b/libs/knowledge-store/ragstack_knowledge_store/content.py index f6c62d961..7d3c3a65f 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/content.py +++ b/libs/knowledge-store/ragstack_knowledge_store/content.py @@ -2,6 +2,8 @@ class Kind(str, Enum): + """The kind of content in a document.""" + document = "document" """A root document (PDF, HTML, etc.). diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index cfd918a31..6260b2ba8 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -29,7 +29,7 @@ @dataclass class Node: - """Node in the GraphStore""" + """Node in the GraphStore.""" id: Optional[str] = None """Unique ID for the node. Will be generated by the GraphStore if not set.""" @@ -42,6 +42,8 @@ class Node: class SetupMode(Enum): + """Mode used to create the Cassandra table.""" + SYNC = 1 ASYNC = 2 OFF = 3 @@ -105,6 +107,17 @@ class _Edge: class GraphStore: + """A hybrid vector-and-graph store backed by Cassandra. + + Document chunks support vector-similarity search as well as edges linking + documents based on structural and semantic properties. + + Args: + embedding: The embeddings to use for the document content. + setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). + """ + def __init__( self, embedding: EmbeddingModel, @@ -115,16 +128,6 @@ def __init__( keyspace: Optional[str] = None, setup_mode: SetupMode = SetupMode.SYNC, ): - """A hybrid vector-and-graph store backed by Cassandra. - - Document chunks support vector-similarity search as well as edges linking - documents based on structural and semantic properties. - - Args: - embedding: The embeddings to use for the document content. - setup_mode: Mode used to create the Cassandra table (SYNC, - ASYNC or OFF). - """ session = check_resolve_session(session) keyspace = check_resolve_keyspace(keyspace) @@ -292,6 +295,7 @@ def add_nodes( self, nodes: Iterable[Node], ) -> Iterable[str]: + """Add nodes to the graph store.""" node_ids = [] texts = [] metadatas = [] @@ -473,10 +477,10 @@ def traversal_search( k: The number of Documents to return from the initial vector search. Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. + Returns: Collection of retrieved documents. """ - # Depth 0: # Query for `k` nodes similar to the question. # Retrieve `content_id` and `link_to_tags`. @@ -571,6 +575,7 @@ def similarity_search( embedding: List[float], k: int = 4, ) -> Iterable[Node]: + """Retrieve nodes similar to the given embedding.""" for row in self._session.execute(self._query_by_embedding, (embedding, k)): yield _row_to_node(row) @@ -592,7 +597,6 @@ def _get_adjacent( Returns: List of adjacent edges. """ - targets = {} def add_sources(rows): diff --git a/libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py b/libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py index 7266cf22a..06d39bdb0 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/knowledge_store.py @@ -1,4 +1,4 @@ -"""Temporary backward-compatibility for KnowledgeStore""" +"""Temporary backward-compatibility for KnowledgeStore.""" from .graph_store import ( EmbeddingModel, diff --git a/libs/knowledge-store/ragstack_knowledge_store/links.py b/libs/knowledge-store/ragstack_knowledge_store/links.py index 6dad7fa22..8a27a0d98 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/links.py +++ b/libs/knowledge-store/ragstack_knowledge_store/links.py @@ -4,18 +4,23 @@ @dataclass(frozen=True) class Link: + """A link to a tag in the graph.""" + kind: str direction: Literal["in", "out", "bidir"] tag: str @staticmethod def incoming(kind: str, tag: str) -> "Link": + """Create an incoming link.""" return Link(kind=kind, direction="in", tag=tag) @staticmethod def outgoing(kind: str, tag: str) -> "Link": + """Create an outgoing link.""" return Link(kind=kind, direction="out", tag=tag) @staticmethod def bidir(kind: str, tag: str) -> "Link": + """Create a bidirectional link.""" return Link(kind=kind, direction="bidir", tag=tag) diff --git a/libs/knowledge-store/ragstack_knowledge_store/math.py b/libs/knowledge-store/ragstack_knowledge_store/math.py index b3b5413ec..6f058fdad 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/math.py +++ b/libs/knowledge-store/ragstack_knowledge_store/math.py @@ -1,6 +1,7 @@ -"""Copied from langchain_community.utils.math -See https://github.com/langchain-ai/langchain/blob/langchain-community%3D%3D0.0.38/libs/community/langchain_community/utils/math.py -""" +"""Copied from langchain_community.utils.math. + +See https://github.com/langchain-ai/langchain/blob/langchain-community%3D%3D0.0.38/libs/community/langchain_community/utils/math.py . +""" # noqa: E501 import logging from typing import List, Union diff --git a/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py b/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py index eca0f3cd8..ca2b735f2 100644 --- a/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py +++ b/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py @@ -34,8 +34,7 @@ def angular_embedding(angle: float) -> List[float]: def test_mmr_helper_added_documetns(): - """ - Test end to end construction and MMR search. + """Test end to end construction and MMR search. The embedding function used here ensures `texts` become the following vectors on a circle (numbered v0 through v3): @@ -53,7 +52,6 @@ def test_mmr_helper_added_documetns(): Both v2 and v3 are discovered after v0. """ - helper = MmrHelper(5, angular_embedding(0.0)) # Fetching the 2 nearest neighbors to 0.0 diff --git a/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py b/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py index 50e9d6124..03e2deb2a 100644 --- a/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py +++ b/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py @@ -56,6 +56,7 @@ def _get_relevant_documents( run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: """Get documents relevant to a query. + Args: query: String to find relevant documents for run_manager: The callbacks handler to use @@ -79,6 +80,7 @@ async def _aget_relevant_documents( run_manager: AsyncCallbackManagerForRetrieverRun, ) -> List[Document]: """Asynchronously get documents relevant to a query. + Args: query: String to find relevant documents for run_manager: The callbacks handler to use diff --git a/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py b/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py index fc409a885..71e189fba 100644 --- a/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py +++ b/libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py @@ -19,6 +19,8 @@ class ColbertVectorStore(VectorStore): + """VectorStore for ColBERT.""" + _vector_store: ColbertBaseVectorStore _retriever: ColbertBaseRetriever @@ -52,6 +54,7 @@ def add_texts( Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + doc_id: Optional document ID to associate with the texts. kwargs: vectorstore specific parameters Returns: @@ -75,6 +78,7 @@ async def aadd_texts( Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + doc_id: Optional document ID to associate with the texts. concurrent_inserts: How many concurrent inserts to make to the database. Defaults to 100. kwargs: vectorstore specific parameters @@ -255,7 +259,6 @@ def from_texts( **kwargs: Any, ) -> CVS: """Return VectorStore initialized from texts and embeddings.""" - instance = cls(database=database, embedding_model=embedding_model, **kwargs) instance.add_texts(texts=texts, metadatas=metadatas) return instance diff --git a/libs/ragulate/colbert_chunk_size_and_k.py b/libs/ragulate/colbert_chunk_size_and_k.py index 6b238ca21..29f6faf45 100644 --- a/libs/ragulate/colbert_chunk_size_and_k.py +++ b/libs/ragulate/colbert_chunk_size_and_k.py @@ -1,4 +1,4 @@ -# ruff: noqa: INP001, T201 +# ruff: noqa: D103, INP001, T201 import logging import os import time diff --git a/libs/ragulate/open_ai_chunk_size_and_k.py b/libs/ragulate/open_ai_chunk_size_and_k.py index 2630dc77a..27f5563ca 100644 --- a/libs/ragulate/open_ai_chunk_size_and_k.py +++ b/libs/ragulate/open_ai_chunk_size_and_k.py @@ -1,4 +1,4 @@ -# ruff: noqa: INP001, T201 +# ruff: noqa: D103, INP001, T201 import os from langchain_astradb import AstraDBVectorStore diff --git a/libs/ragulate/ragstack_ragulate/analysis.py b/libs/ragulate/ragstack_ragulate/analysis.py index 2ed18033b..42de00564 100644 --- a/libs/ragulate/ragstack_ragulate/analysis.py +++ b/libs/ragulate/ragstack_ragulate/analysis.py @@ -12,7 +12,10 @@ class Analysis: + """Analysis class.""" + def get_all_data(self, recipes: List[str]) -> DataFrame: + """Get all data from the recipes.""" df_all = pd.DataFrame() all_metrics: List[str] = [] @@ -53,6 +56,7 @@ def get_all_data(self, recipes: List[str]) -> DataFrame: return reset_df, list(set(all_metrics)) def calculate_statistics(self, df: pd.DataFrame, metrics: list): + """Calculate statistics.""" stats = {} for recipe in df["recipe"].unique(): stats[recipe] = {} @@ -73,6 +77,7 @@ def calculate_statistics(self, df: pd.DataFrame, metrics: list): return stats def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]): + """Output box plots by dataset.""" stats = self.calculate_statistics(df, metrics) recipes = sorted(df["recipe"].unique(), key=lambda x: x.lower()) datasets = sorted(df["dataset"].unique(), key=lambda x: x.lower()) @@ -155,6 +160,7 @@ def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]): write_image(fig, f"./{dataset}_box_plot.png") def output_histograms_by_dataset(self, df: pd.DataFrame, metrics: List[str]): + """Output histograms by dataset.""" # Append "latency" to the metrics list metrics.append("latency") @@ -246,6 +252,7 @@ def custom_hist(data, **kws): plt.close() def compare(self, recipes: List[str], output: str): + """Compare results from 2 (or more) recipes.""" df, metrics = self.get_all_data(recipes=recipes) if output == "box-plots": self.output_box_plots_by_dataset(df=df, metrics=metrics) diff --git a/libs/ragulate/ragstack_ragulate/cli.py b/libs/ragulate/ragstack_ragulate/cli.py index 8abb250ca..75d8b3bdf 100644 --- a/libs/ragulate/ragstack_ragulate/cli.py +++ b/libs/ragulate/ragstack_ragulate/cli.py @@ -14,6 +14,7 @@ def main() -> None: + """Main function for the CLI.""" parser = argparse.ArgumentParser(description="RAGu-late CLI tool.") # Subparsers for the main commands diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/compare.py b/libs/ragulate/ragstack_ragulate/cli_commands/compare.py index 7b46f6a6d..7525d94fe 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/compare.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/compare.py @@ -4,6 +4,7 @@ def setup_compare(subparsers): + """Setup the compare command.""" compare_parser = subparsers.add_parser( "compare", help="Compare results from 2 (or more) recipes" ) @@ -26,6 +27,7 @@ def setup_compare(subparsers): def remove_sqlite_extension(s): + """Remove the .sqlite extension from a string.""" if s.endswith(".sqlite"): return s[:-7] return s @@ -36,6 +38,7 @@ def call_compare( output: Optional[str] = "box-plots", **_, ): + """Compare results from 2 (or more) recipes.""" analysis = Analysis() recipes = [remove_sqlite_extension(r) for r in recipe] diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/download.py b/libs/ragulate/ragstack_ragulate/cli_commands/download.py index 19b31c392..a8885a602 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/download.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/download.py @@ -2,6 +2,7 @@ def setup_download(subparsers): + """Setup the download command.""" download_parser = subparsers.add_parser("download", help="Download a dataset") download_parser.add_argument( "dataset_name", @@ -22,5 +23,6 @@ def setup_download(subparsers): def call_download(dataset_name: str, kind: str, **_): + """Download a dataset.""" dataset = get_dataset(name=dataset_name, kind=kind) dataset.download_dataset() diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py b/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py index 6e611ee81..ef2e5e88d 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py @@ -6,6 +6,7 @@ def setup_ingest(subparsers): + """Setup the ingest command.""" ingest_parser = subparsers.add_parser("ingest", help="Run an ingest pipeline") ingest_parser.add_argument( "-n", @@ -59,6 +60,7 @@ def call_ingest( dataset: List[str], **_, ): + """Run an ingest pipeline.""" datasets = [find_dataset(name=name) for name in dataset] ingredients = convert_vars_to_ingredients( diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/query.py b/libs/ragulate/ragstack_ragulate/cli_commands/query.py index 034373f65..8cc55faa5 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/query.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/query.py @@ -6,6 +6,7 @@ def setup_query(subparsers): + """Setup the query command.""" query_parser = subparsers.add_parser("query", help="Run a query pipeline") query_parser.add_argument( "-n", @@ -111,6 +112,7 @@ def call_query( model: str, **_, ): + """Run a query pipeline.""" if sample <= 0.0 or sample > 1.0: raise ValueError("Sample percent must be between 0 and 1") diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/run.py b/libs/ragulate/ragstack_ragulate/cli_commands/run.py index ebe94604b..593e53b9e 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/run.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/run.py @@ -7,6 +7,7 @@ def setup_run(subparsers): + """Setup the run command.""" run_parser = subparsers.add_parser( "run", help="Run an experiment from a config file" ) @@ -22,6 +23,7 @@ def setup_run(subparsers): def call_run(config_file: str, **_): + """Run an experiment from a config file.""" config_parser = ConfigParser.from_file(file_path=config_file) config = config_parser.get_config() diff --git a/libs/ragulate/ragstack_ragulate/config/base_config_schema.py b/libs/ragulate/ragstack_ragulate/config/base_config_schema.py index 894d68028..681e9bd2b 100644 --- a/libs/ragulate/ragstack_ragulate/config/base_config_schema.py +++ b/libs/ragulate/ragstack_ragulate/config/base_config_schema.py @@ -5,14 +5,16 @@ class BaseConfigSchema(ABC): + """Base config schema.""" + @abstractmethod - def version() -> float: - """returns the config file version""" + def version(self) -> float: + """Returns the config file version.""" @abstractmethod def schema(self) -> Dict[str, Any]: - """returns the config file schema""" + """Returns the config file schema.""" @abstractmethod def parse_document(self, document: Dict[str, Any]) -> Config: - """parses a validated config file and returns a Config object""" + """Parses a validated config file and returns a Config object.""" diff --git a/libs/ragulate/ragstack_ragulate/config/config_parser.py b/libs/ragulate/ragstack_ragulate/config/config_parser.py index 10095b726..d2ad491bb 100644 --- a/libs/ragulate/ragstack_ragulate/config/config_parser.py +++ b/libs/ragulate/ragstack_ragulate/config/config_parser.py @@ -9,6 +9,8 @@ class ConfigParser: + """Config parser.""" + _config_schema: BaseConfigSchema _valid: bool _errors: Any @@ -22,12 +24,14 @@ def __init__(self, config_schema: BaseConfigSchema, config: Dict[str, Any]): self._document = validator.document def get_config(self) -> Config: + """Return the config.""" if not self.is_valid: return None return self._config_schema.parse_document(self._document) @classmethod def from_file(cls, file_path: str) -> "ConfigParser": + """Create a ConfigParser from a file.""" with open(file_path) as file: config = yaml.safe_load(file) diff --git a/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py b/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py index 6d2256963..d228bc685 100644 --- a/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py +++ b/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py @@ -1,5 +1,7 @@ from typing import Any, Dict +from typing_extensions import override + from ragstack_ragulate.datasets import BaseDataset, find_dataset, get_dataset from .base_config_schema import BaseConfigSchema @@ -8,9 +10,13 @@ class ConfigSchema0Dot1(BaseConfigSchema): + """Config schema for version 0.1.""" + + @override def version(self): return 0.1 + @override def schema(self) -> Dict[str, Any]: step_list = { "type": "list", @@ -133,6 +139,7 @@ def schema(self) -> Dict[str, Any]: "metrics": metrics, } + @override def parse_document(self, document: Dict[str, Any]) -> Config: ingest_steps: Dict[str, Step] = {} query_steps: Dict[str, Step] = {} diff --git a/libs/ragulate/ragstack_ragulate/config/objects.py b/libs/ragulate/ragstack_ragulate/config/objects.py index c40b998d8..8c44d7e39 100644 --- a/libs/ragulate/ragstack_ragulate/config/objects.py +++ b/libs/ragulate/ragstack_ragulate/config/objects.py @@ -6,12 +6,16 @@ class Step(BaseModel): + """Step of a recipe.""" + name: str script: str method: str class Recipe(BaseModel): + """Recipe object.""" + name: str ingest: Step | None query: Step @@ -20,7 +24,11 @@ class Recipe(BaseModel): class Config(BaseModel): + """Config object.""" + class Config: + """Pydantic configuration.""" + arbitrary_types_allowed = True recipes: Dict[str, Recipe] = {} diff --git a/libs/ragulate/ragstack_ragulate/config/utils.py b/libs/ragulate/ragstack_ragulate/config/utils.py index ad14945c5..e8f2c4923 100644 --- a/libs/ragulate/ragstack_ragulate/config/utils.py +++ b/libs/ragulate/ragstack_ragulate/config/utils.py @@ -2,6 +2,7 @@ def dict_to_string(d: Dict[str, Any]) -> str: + """Convert dictionary to string.""" parts = [] for key, value in d.items(): diff --git a/libs/ragulate/ragstack_ragulate/framework.py b/libs/ragulate/ragstack_ragulate/framework.py index 9c9e7c832..85b2ce1fb 100644 --- a/libs/ragulate/ragstack_ragulate/framework.py +++ b/libs/ragulate/ragstack_ragulate/framework.py @@ -2,5 +2,7 @@ class Framework(Enum): + """Frameworks supported by RagStack Ragulate.""" + LANG_CHAIN = "langChain" LLAMA_INDEX = "llamaIndex" diff --git a/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py index 0b14e9e2f..7e5cab700 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py @@ -7,8 +7,8 @@ from ragstack_ragulate.datasets import BaseDataset -# Function to dynamically load a module def load_module(file_path, name): + """Load a module from a file path dynamically.""" spec = importlib.util.spec_from_file_location(name, file_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -16,11 +16,13 @@ def load_module(file_path, name): def get_method(script_path: str, pipeline_type: str, method_name: str): + """Return the method from the script.""" module = load_module(script_path, name=pipeline_type) return getattr(module, method_name) def get_method_params(method: Any) -> List[str]: + """Return the parameters of a method.""" signature = inspect.signature(method) return signature.parameters.keys() @@ -30,6 +32,7 @@ def get_ingredients( reserved_params: List[str], passed_ingredients: Dict[str, Any], ) -> Dict[str, Any]: + """Return ingredients for the given method params.""" ingredients = {} for method_param in method_params: if method_param in reserved_params or method_param in ["kwargs"]: @@ -44,6 +47,8 @@ def get_ingredients( class BasePipeline(ABC): + """Base class for all pipelines.""" + recipe_name: str script_path: str method_name: str @@ -56,12 +61,12 @@ class BasePipeline(ABC): @property @abstractmethod def pipeline_type(self): - """type of pipeline (ingest, query, cleanup)""" + """Type of pipeline (ingest, query, cleanup).""" @property @abstractmethod def get_reserved_params(self) -> List[str]: - """get the list of reserved parameter names for this pipeline type""" + """Get the list of reserved parameter names for this pipeline type.""" def __init__( self, @@ -101,12 +106,15 @@ def __init__( exit(1) def get_method(self): + """Return the pipeline method.""" return self._method def dataset_names(self) -> List[str]: + """Return the names of the datasets.""" return [d.name for d in self.datasets] def key(self) -> str: + """Return the pipeline key.""" key_parts = [ self.pipeline_type, self.script_path, diff --git a/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py b/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py index fa662cfff..5b21fb161 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py @@ -9,6 +9,8 @@ class Feedbacks: + """Pipeline feedbacks.""" + _context: Lens _llm_provider: LLMProvider @@ -17,6 +19,7 @@ def __init__(self, llm_provider: LLMProvider, pipeline: Any) -> None: self._llm_provider = llm_provider def groundedness(self) -> Feedback: + """Return groundedness feedback.""" return ( Feedback( self._llm_provider.groundedness_measure_with_cot_reasons, @@ -27,11 +30,13 @@ def groundedness(self) -> Feedback: ) def answer_relevance(self) -> Feedback: + """Return answer relevance feedback.""" return Feedback( self._llm_provider.relevance_with_cot_reasons, name="answer_relevance" ).on_input_output() def context_relevance(self) -> Feedback: + """Return context relevance feedback.""" return ( Feedback( self._llm_provider.qs_relevance_with_cot_reasons, @@ -43,6 +48,7 @@ def context_relevance(self) -> Feedback: ) def answer_correctness(self, golden_set: List[Dict[str, str]]) -> Feedback: + """Return answer correctness feedback.""" # GroundTruth for comparing the Answer to the Ground-Truth Answer ground_truth_collection = GroundTruthAgreement( ground_truth=golden_set, provider=self._llm_provider diff --git a/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py index 1f1baac93..928d3da7a 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py @@ -2,6 +2,7 @@ from typing import List from tqdm import tqdm +from typing_extensions import override from ragstack_ragulate.logging_config import logger @@ -9,15 +10,20 @@ class IngestPipeline(BasePipeline): + """Ingest pipeline.""" + @property + @override def pipeline_type(self): return "ingest" @property + @override def get_reserved_params(self) -> List[str]: return ["file_path"] def ingest(self): + """Run the ingest pipeline.""" logger.info( f"Starting ingest {self.recipe_name} " f"on {self.script_path}/{self.method_name} " diff --git a/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py index 1085f7666..8b97e2991 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py @@ -15,6 +15,7 @@ ) from trulens_eval.feedback.provider.base import LLMProvider from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus +from typing_extensions import override from ragstack_ragulate.datasets import BaseDataset from ragstack_ragulate.logging_config import logger @@ -25,6 +26,8 @@ class QueryPipeline(BasePipeline): + """Query pipeline.""" + _sigint_received = False _tru: Tru @@ -39,10 +42,12 @@ class QueryPipeline(BasePipeline): _evaluation_running = False @property + @override def pipeline_type(self): return "query" @property + @override def get_reserved_params(self) -> List[str]: return [] @@ -110,14 +115,17 @@ def __init__( self._total_feedbacks = self._total_queries * metric_count def signal_handler(self, _, __): + """Handle SIGINT signal.""" self._sigint_received = True self.stop_evaluation("sigint") def start_evaluation(self): + """Start evaluation.""" self._tru.start_evaluator(disable_tqdm=True) self._evaluation_running = True def stop_evaluation(self, loc: str): + """Stop evaluation.""" if self._evaluation_running: try: logger.debug(f"Stopping evaluation from: {loc}") @@ -130,6 +138,7 @@ def stop_evaluation(self, loc: str): self._progress.close() def update_progress(self, query_change: int = 0): + """Update progress bar.""" self._finished_queries += query_change status = self._tru.db.get_feedback_count_by_status() @@ -152,6 +161,7 @@ def update_progress(self, query_change: int = 0): self._finished_feedbacks = done def get_provider(self) -> LLMProvider: + """Get the LLM provider.""" llm_provider = self.llm_provider.lower() model_name = self.model_name @@ -170,6 +180,7 @@ def get_provider(self) -> LLMProvider: raise ValueError(f"Unsupported provider: {llm_provider}") def query(self): + """Run the query pipeline.""" query_method = self.get_method() pipeline = query_method(**self.ingredients) diff --git a/libs/ragulate/ragstack_ragulate/utils.py b/libs/ragulate/ragstack_ragulate/utils.py index 6345cca34..fdf6a9ea9 100644 --- a/libs/ragulate/ragstack_ragulate/utils.py +++ b/libs/ragulate/ragstack_ragulate/utils.py @@ -5,6 +5,7 @@ def get_tru(recipe_name: str) -> Tru: + """Return Tru for given recipe name.""" Tru.RETRY_FAILED_SECONDS = 60 Tru.RETRY_RUNNING_SECONDS = 30 return Tru( @@ -15,13 +16,14 @@ def get_tru(recipe_name: str) -> Tru: def convert_vars_to_ingredients( var_names: List[str], var_values: List[str] ) -> Dict[str, Any]: + """Convert variables to ingredients.""" params: Dict[str, Any] = {} for i, name in enumerate(var_names): - params[name] = convert_string(var_values[i]) + params[name] = _convert_string(var_values[i]) return params -def convert_string(s): +def _convert_string(s): s = s.strip() if re.match(r"^\d+$", s): return int(s) diff --git a/libs/ragulate/scripts/test_integration_runner.py b/libs/ragulate/scripts/test_integration_runner.py index 238aceff3..59824d171 100755 --- a/libs/ragulate/scripts/test_integration_runner.py +++ b/libs/ragulate/scripts/test_integration_runner.py @@ -5,6 +5,7 @@ def main(): + """Run the integration tests.""" sys.exit(pytest.main(["tests/integration_tests"])) diff --git a/libs/ragulate/scripts/test_unit_runner.py b/libs/ragulate/scripts/test_unit_runner.py index b1ea137ac..c38a3a989 100755 --- a/libs/ragulate/scripts/test_unit_runner.py +++ b/libs/ragulate/scripts/test_unit_runner.py @@ -5,6 +5,7 @@ def main(): + """Run the unit tests.""" sys.exit(pytest.main(["tests/unit_tests"])) diff --git a/libs/tests-utils/ragstack_tests_utils/test_data.py b/libs/tests-utils/ragstack_tests_utils/test_data.py index f49d56cac..a9a433d14 100644 --- a/libs/tests-utils/ragstack_tests_utils/test_data.py +++ b/libs/tests-utils/ragstack_tests_utils/test_data.py @@ -30,22 +30,17 @@ def save_csv_embedding(csv_file_name: str, embedding: Embedding): @staticmethod def climate_change_text() -> str: - """ - Returns: A short, highly-technical text on climate change. - """ + """Returns: A short, highly-technical text on climate change.""" return TestData._get_text_file("climate_change.txt") @staticmethod def climate_change_embedding() -> Embedding: - """ - Returns: An embedding for the `climate_change_text()` text - """ + """Returns: An embedding for the `climate_change_text()` text""" return TestData._get_csv_embedding("climate_change.csv") @staticmethod def marine_animals_text() -> str: - """ - Returns: + """Returns: A story of approx 350 words about marine animals. Potential queries on the text: @@ -54,13 +49,11 @@ def marine_animals_text() -> str: - How do anglerfish adapt to the deep ocean's darkness? - What role do coral reefs play in marine ecosystems? """ - return TestData._get_text_file("marine_animals.txt") @staticmethod def nebula_voyager_text() -> str: - """ - Returns: + """Returns: A story of approx 2500 words about a theoretical spaceship. Includes very technical names and terms that can be difficult for standard retrieval systems. @@ -80,13 +73,9 @@ def nebula_voyager_text() -> str: @staticmethod def renewable_energy_text() -> str: - """ - Returns: A short, highly-technical text on renewable energy - """ + """Returns: A short, highly-technical text on renewable energy""" return TestData._get_text_file("renewable_energy.txt") def renewable_energy_embedding() -> Embedding: - """ - Returns: An embedding for the `renewable_energy_text()` text - """ + """Returns: An embedding for the `renewable_energy_text()` text""" return TestData._get_csv_embedding("renewable_energy.csv") diff --git a/pyproject.toml b/pyproject.toml index 3777c302d..81d73466c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,8 +58,13 @@ warn_unused_ignores = true extend-include = ["*.ipynb"] [tool.ruff.lint] +pydocstyle.convention = "google" ignore = [ "COM812", # Messes with the formatter + "D100", # Do we want to activate (docstring in module) ? + "D104", # Do we want to activate (docstring in package) ? + "D105", # Do we want to activate (docstring in magic method) ? + "D107", # Do we want to activate (docstring in __init__) ? "ERA", # Do we want to activate (no commented code) ? "ISC001", # Messes with the formatter "PERF203", # Incorrect detection @@ -76,6 +81,7 @@ select = [ "BLE", "C4", "COM", + "D", "DTZ", "E", "EXE", @@ -120,15 +126,20 @@ select = [ "UP006", # Incompatible with Pydantic v1 "UP007", # Incompatible with Pydantic v1 ] -"**/{examples,notebooks,tests,e2e-tests}/*" = [ +"**/{examples,notebooks,tests,e2e-tests,tests-utils}/*" = [ + "D", "T20", ] "scripts/*" = [ + "D", "T20", ] "docker/examples/*" = [ "INP001", ] +"libs/langchain/ragstack_langchain/graph_store/*" = [ + "D", +] [build-system] requires = ["poetry-core"]