From 5982a0796718b46b5475d92470b9fac6d07060c8 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Fri, 8 Nov 2024 17:30:27 +0800 Subject: [PATCH] refine functions Signed-off-by: ChengZi --- .../langchain_milvus/vectorstores/milvus.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/libs/milvus/langchain_milvus/vectorstores/milvus.py b/libs/milvus/langchain_milvus/vectorstores/milvus.py index 408da6e..636eb23 100644 --- a/libs/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/milvus/langchain_milvus/vectorstores/milvus.py @@ -19,7 +19,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from pymilvus import RRFRanker, WeightedRanker +from pymilvus import MilvusClient, RRFRanker, WeightedRanker from langchain_milvus import MilvusCollectionHybridSearchRetriever from langchain_milvus.utils.sparse import BaseSparseEmbedding @@ -409,7 +409,7 @@ def embeddings(self) -> Union[EmbeddingType, List[EmbeddingType]]: # type: igno return self.embedding_func @property - def client(self) -> Any: + def client(self) -> MilvusClient: """Get client.""" return self._milvus_client @@ -419,15 +419,11 @@ def _is_multi_vector(self) -> bool: @property def _is_sparse(self) -> bool: - if self.index_params is None: - return False - indexes_params = self._as_list(self.index_params) - if len(indexes_params) > 1: - return False - index_type = indexes_params[0]["index_type"] - if "SPARSE" in index_type: + embedding_func: List[EmbeddingType] = self._as_list(self.embedding_func) + if self._is_sparse_embedding(embedding_func[0]): return True - return False + else: + return False @staticmethod def _is_sparse_embedding(embeddings_function: EmbeddingType) -> bool: @@ -1396,12 +1392,20 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: - etc. """ - if self.index_params is None: - raise ValueError("No index params provided.") + if not self.col or not self.col.indexes: + raise ValueError( + "No index params provided. Could not determine relevance function." + ) if self._is_multi_vector: - raise ValueError("No supported normalization function for multi vectors.") + raise ValueError( + "No supported normalization function for multi vectors. " + "Could not determine relevance function." + ) if self._is_sparse: - raise ValueError("No supported normalization function for sparse indexes.") + raise ValueError( + "No supported normalization function for sparse indexes. " + "Could not determine relevance function." + ) def _map_l2_to_similarity(l2_distance: float) -> float: """Return a similarity score on a scale [0, 1]. @@ -1423,6 +1427,12 @@ def _map_ip_to_similarity(ip_score: float) -> float: """ return (ip_score + 1) / 2.0 + if self.index_params is None: + logger.warning( + "No index params provided. Could not determine relevance function. " + "Use L2 distance as default." + ) + return _map_l2_to_similarity indexes_params = self._as_list(self.index_params) metric_type = indexes_params[0]["metric_type"] if metric_type == "L2":